From cf89213e1ef55939f5a956510b4ff94e354e24ec Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sat, 26 May 2018 17:38:35 -0400 Subject: [PATCH] Add hook support. --- _examples/shell/main.go | 14 ++++-- build.go | 20 ++------ context.go | 59 ++++++++--------------- context_test.go | 19 -------- help.go | 31 ++++++------ kong.go | 104 +++++++++++++++++++++++++++++++++++----- kong_test.go | 75 +++++++++++++++++++++++++---- model.go | 1 + options.go | 6 +-- options_test.go | 6 +-- 10 files changed, 213 insertions(+), 122 deletions(-) diff --git a/_examples/shell/main.go b/_examples/shell/main.go index 9a6e7d0..ebefeae 100644 --- a/_examples/shell/main.go +++ b/_examples/shell/main.go @@ -3,25 +3,29 @@ package main import ( "encoding/json" "fmt" + "os" "github.com/alecthomas/kong" ) var CLI struct { - Rm struct { + Help bool `kong:"help='Display help.'"` + Rm struct { Force bool `kong:"help='Force removal.'"` Recursive bool `kong:"help='Recursively remove files.'"` - Paths []string `kong:"help='Paths to remove.',type='path'"` - } `kong:"help='Remove files.'"` + Paths []string `kong:"arg,help='Paths to remove.',type='path'"` + } `kong:"cmd,help='Remove files.'"` Ls struct { Paths []string `kong:"help='Paths to list.',type='path'"` - } `kong:"help='List paths.'"` + } `kong:"cmd,help='List paths.'"` } func main() { - cmd := kong.Parse(&CLI) + app := kong.Must(&CLI).Hook(&CLI.Help, kong.Help(nil, nil)) + cmd, err := app.Parse(os.Args[1:]) + app.FatalIfErrorf(err) s, _ := json.Marshal(&CLI) fmt.Println(cmd) fmt.Println(string(s)) diff --git a/build.go b/build.go index f59ef7e..8b84ce7 100644 --- a/build.go +++ b/build.go @@ -14,23 +14,11 @@ func build(ast interface{}) (app *Application, err error) { return nil, fmt.Errorf("expected a pointer to a struct but got %T", ast) } - app = &Application{ - // Synthesize a --help flag. - HelpFlag: &Flag{ - Value: Value{ - Name: "help", - Help: "Show context-sensitive help.", - Flag: true, - Value: reflect.New(reflect.TypeOf(false)).Elem(), - Decoder: kindDecoders[reflect.Bool], - }}, - } - node := buildNode(iv, map[string]bool{"help": true}) + app = &Application{} + node := buildNode(iv, map[string]bool{}) if len(node.Positional) > 0 && len(node.Children) > 0 { return nil, fmt.Errorf("can't mix positional arguments and branching arguments on %T", ast) } - // Prepend --help flag. - node.Flags = append([]*Flag{app.HelpFlag}, node.Flags...) app.Node = *node return app, nil } @@ -40,7 +28,9 @@ func dashedString(s string) string { } func buildNode(v reflect.Value, seenFlags map[string]bool) *Node { - node := &Node{} + node := &Node{ + Target: v, + } for i := 0; i < v.NumField(); i++ { ft := v.Type().Field(i) if strings.ToLower(ft.Name[0:1]) == ft.Name[0:1] { diff --git a/context.go b/context.go index 108873a..9b082d5 100644 --- a/context.go +++ b/context.go @@ -2,12 +2,13 @@ package kong import ( "fmt" + "io" "reflect" "strings" ) -// ParseTrace records the nodes and parsed values from the current command-line. -type ParseTrace struct { +// Trace records the nodes and parsed values from the current command-line. +type Trace struct { // One of these will be non-nil. App *Application Positional *Value @@ -22,9 +23,12 @@ type ParseTrace struct { Value reflect.Value } -type ParseContext struct { - Trace []*ParseTrace // A trace through parsed nodes. - Error error // Error that occurred during trace, if any. +type Context struct { + Trace []*Trace // A trace through parsed nodes. + Error error // Error that occurred during trace, if any. + + Stdout io.Writer + Stderr io.Writer node *Node // Current node being parsed. @@ -34,7 +38,7 @@ type ParseContext struct { } // Flags returns the accumulated available flags. -func (p *ParseContext) Flags() (flags []*Flag) { +func (p *Context) Flags() (flags []*Flag) { for _, trace := range p.Trace { flags = append(flags, trace.Flags...) } @@ -42,7 +46,7 @@ func (p *ParseContext) Flags() (flags []*Flag) { } // Command returns the full command path. -func (p *ParseContext) Command() (command []string) { +func (p *Context) Command() (command []string) { for _, trace := range p.Trace { switch { case trace.Positional != nil: @@ -56,30 +60,8 @@ func (p *ParseContext) Command() (command []string) { return } -// Trace parses the command-line, validating and collecting matching grammar nodes. -func Trace(args []string, app *Application) (*ParseContext, error) { - p := &ParseContext{ - app: app, - args: args, - } - p.Trace = append(p.Trace, &ParseTrace{ - App: app, - Flags: append([]*Flag{}, app.Flags...), - }) - err := p.reset(&p.app.Node) - if err != nil { - return nil, err - } - p.Error = p.trace(&p.app.Node) - if err = checkMissingFlags(p.Flags()); err != nil { - return nil, err - } - - return p, nil -} - // FlagValue returns the set value of a flag, if it was encountered and exists. -func (p *ParseContext) FlagValue(flag *Flag) reflect.Value { +func (p *Context) FlagValue(flag *Flag) reflect.Value { for _, trace := range p.Trace { if trace.Flag == flag { return trace.Value @@ -89,7 +71,7 @@ func (p *ParseContext) FlagValue(flag *Flag) reflect.Value { } // Recursively reset values to defaults (as specified in the grammar) or the zero value. -func (p *ParseContext) reset(node *Node) error { +func (p *Context) reset(node *Node) error { p.scan = Scan(p.args...) for _, flag := range node.Flags { err := flag.Value.Reset() @@ -124,7 +106,7 @@ func (p *ParseContext) reset(node *Node) error { return nil } -func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo +func (p *Context) trace(node *Node) (err error) { // nolint: gocyclo positional := 0 p.node = node flags := append(p.Flags(), node.Flags...) @@ -203,7 +185,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo if err != nil { return err } - p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value, Flags: node.Flags}) + p.Trace = append(p.Trace, &Trace{Positional: arg, Value: value, Flags: node.Flags}) positional++ break } @@ -214,9 +196,10 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo case branch.Command != nil: if branch.Command.Name == token.Value { p.scan.Pop() - p.Trace = append(p.Trace, &ParseTrace{ + p.Trace = append(p.Trace, &Trace{ Command: branch.Command, Flags: node.Flags, + Value: branch.Command.Target, }) return p.trace(branch.Command) } @@ -224,7 +207,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo case branch.Argument != nil: arg := branch.Argument.Argument if value, err := arg.Parse(p.scan); err == nil { - p.Trace = append(p.Trace, &ParseTrace{ + p.Trace = append(p.Trace, &Trace{ Argument: branch.Argument, Value: value, Flags: node.Flags, @@ -252,7 +235,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo } // Apply traced context to the target grammar. -func (p *ParseContext) Apply() (string, error) { +func (p *Context) Apply() (string, error) { path := []string{} for _, trace := range p.Trace { switch { @@ -324,7 +307,7 @@ func checkMissingPositionals(positional int, values []*Value) error { return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } -func (p *ParseContext) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) { +func (p *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) { defer catch(&err) token := p.scan.Peek() for _, flag := range flags { @@ -335,7 +318,7 @@ func (p *ParseContext) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (er if err != nil { return err } - p.Trace = append(p.Trace, &ParseTrace{Flag: flag, Value: value}) + p.Trace = append(p.Trace, &Trace{Flag: flag, Value: value}) return nil } } diff --git a/context_test.go b/context_test.go index 08bd685..1af827e 100644 --- a/context_test.go +++ b/context_test.go @@ -1,20 +1 @@ package kong - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestTraceErrorPartiallySucceeds(t *testing.T) { - var cli struct { - One struct { - Two struct { - } `kong:"cmd"` - } `kong:"cmd"` - } - p := mustNew(t, &cli) - trace, err := Trace([]string{"one", "bad"}, p.Model) - require.NoError(t, err) - require.Error(t, trace.Error) -} diff --git a/help.go b/help.go index 8a92146..8338ed3 100644 --- a/help.go +++ b/help.go @@ -1,7 +1,6 @@ package kong import ( - "io" "text/template" ) @@ -18,20 +17,20 @@ usage: {{.Name}} var defaultHelpTemplate = template.Must(template.New("help").Parse(defaultHelp)) -// WriteHelp to w. -// -// If w is nil, the default stdout writer will be used. -// -// If args are provided, help will be written in the context o -func (k *Kong) WriteHelp(w io.Writer, args ...interface{}) error { - if w == nil { - w = k.stdout +// Help returns a Hook that will display help and exit. +func Help(tmpl *template.Template, tmplctx map[string]interface{}) Hook { + return func(app *Kong, ctx *Context, trace *Trace) error { + merged := map[string]interface{}{ + "Application": app.Model, + } + for k, v := range tmplctx { + merged[k] = v + } + err := tmpl.Execute(app.Stdout, merged) + if err != nil { + return err + } + app.Exit(0) + return nil } - ctx := map[string]interface{}{ - "Application": k.Model, - } - for k, v := range k.helpContext { - ctx[k] = v - } - return k.help.Execute(w, ctx) } diff --git a/kong.go b/kong.go index 70855ed..d43623b 100644 --- a/kong.go +++ b/kong.go @@ -5,9 +5,12 @@ import ( "io" "os" "path/filepath" + "reflect" "text/template" ) +type Hook func(app *Kong, ctx *Context, trace *Trace) error + // Error reported by Kong. type Error struct{ msg string } @@ -17,29 +20,39 @@ func fail(format string, args ...interface{}) { panic(Error{fmt.Sprintf(format, args...)}) } +func Must(ast interface{}, options ...Option) *Kong { + k, err := New(ast, options...) + if err != nil { + panic(err) + } + return k +} + // Kong is the main parser type. type Kong struct { Model *Application // Termination function (defaults to os.Exit) - terminate func(int) + Exit func(int) - stdout io.Writer - stderr io.Writer + Stdout io.Writer + Stderr io.Writer help *template.Template helpContext map[string]interface{} helpFuncs template.FuncMap + hooks map[reflect.Value]Hook } // New creates a new Kong parser into ast. func New(ast interface{}, options ...Option) (*Kong, error) { k := &Kong{ - terminate: os.Exit, - stdout: os.Stdout, - stderr: os.Stderr, + Exit: os.Exit, + Stdout: os.Stdout, + Stderr: os.Stderr, help: defaultHelpTemplate, helpContext: map[string]interface{}{}, helpFuncs: template.FuncMap{}, + hooks: map[reflect.Value]Hook{}, } model, err := build(ast) @@ -56,26 +69,91 @@ func New(ast interface{}, options ...Option) (*Kong, error) { return k, nil } +// Trace parses the command-line, validating and collecting matching grammar nodes. +func (k *Kong) Trace(args []string) (*Context, error) { + p := &Context{ + app: k.Model, + args: args, + Trace: []*Trace{ + {App: k.Model, Flags: append([]*Flag{}, k.Model.Flags...), Value: k.Model.Target}, + }, + } + err := p.reset(&p.app.Node) + if err != nil { + return nil, err + } + p.Error = p.trace(&p.app.Node) + if err = checkMissingFlags(p.Flags()); err != nil { + return nil, err + } + return p, nil +} + +// Hook to execute when a command is encountered. +func (k *Kong) Hook(ptr interface{}, hook Hook) *Kong { + k.hooks[reflect.ValueOf(ptr)] = hook + return k +} + // Parse arguments into target. +// +// The returned "command" is a space separated path to the final selected command, if any. Commands appear as +// the command name while positional arguments are the argument name surrounded by "". func (k *Kong) Parse(args []string) (command string, err error) { defer catch(&err) - ctx, err := Trace(args, k.Model) + ctx, err := k.Trace(args) if err != nil { return "", err } + if err := k.applyHooks(ctx); err != nil { + return "", err + } if ctx.Error != nil { return "", ctx.Error } - if value := ctx.FlagValue(k.Model.HelpFlag); value.IsValid() && value.Bool() { - return "", nil - } return ctx.Apply() } -func (k *Kong) Errorf(format string, args ...interface{}) { - fmt.Fprintf(os.Stderr, k.Model.Name+": "+format, args...) +func (k *Kong) applyHooks(ctx *Context) error { + for _, trace := range ctx.Trace { + var key reflect.Value + switch { + case trace.App != nil: + key = trace.App.Target + case trace.Argument != nil: + key = trace.Argument.Target + case trace.Command != nil: + key = trace.Command.Target + case trace.Positional != nil: + key = trace.Positional.Value + case trace.Flag != nil: + key = trace.Flag.Value.Value + default: + panic("unsupported Trace") + } + if key.IsValid() { + key = key.Addr() + } + if hook := k.hooks[key]; hook != nil { + if err := hook(k, ctx, trace); err != nil { + return err + } + } + } + return nil } +// Printf writes a message to Kong.Stdout with the application name prefixed. +func (k *Kong) Printf(format string, args ...interface{}) { + fmt.Fprintf(k.Stdout, k.Model.Name+": "+format, args...) +} + +// Errorf writes a message to Kong.Stderr with the application name prefixed. +func (k *Kong) Errorf(format string, args ...interface{}) { + fmt.Fprintf(k.Stderr, k.Model.Name+": "+format, args...) +} + +// FatalIfError terminates with an error message if err != nil. func (k *Kong) FatalIfErrorf(err error, args ...interface{}) { if err == nil { return @@ -85,7 +163,7 @@ func (k *Kong) FatalIfErrorf(err error, args ...interface{}) { msg = fmt.Sprintf(args[0].(string), args...) + ": " + err.Error() } k.Errorf("%s\n", msg) - k.terminate(1) + k.Exit(1) } func catch(err *error) { diff --git a/kong_test.go b/kong_test.go index 44233d3..cdfa1a9 100644 --- a/kong_test.go +++ b/kong_test.go @@ -1,6 +1,7 @@ package kong import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -311,16 +312,6 @@ func TestInvalidDefaultErrors(t *testing.T) { require.Error(t, err) } -func TestHelp(t *testing.T) { - var cli struct { - Flag string - } - p := mustNew(t, &cli) - _, err := p.Parse([]string{"--flag=hello", "--help"}) - require.NoError(t, err) - require.NotEqual(t, "hello", cli.Flag) -} - func TestCommandMissingTagIsInvalid(t *testing.T) { var cli struct { One struct{} @@ -352,3 +343,67 @@ func TestDuplicateFlagOnPeerCommandIsOkay(t *testing.T) { _, err := New(&cli) require.NoError(t, err) } + +func TestTraceErrorPartiallySucceeds(t *testing.T) { + var cli struct { + One struct { + Two struct { + } `kong:"cmd"` + } `kong:"cmd"` + } + p := mustNew(t, &cli) + trace, err := p.Trace([]string{"one", "bad"}) + require.NoError(t, err) + require.Error(t, trace.Error) + require.Equal(t, []string{"one"}, trace.Command()) +} + +func TestHooks(t *testing.T) { + var cli struct { + One struct { + Two string `kong:"arg,optional"` + Three string + } `kong:"cmd"` + } + type values struct { + one bool + two string + three string + } + hooked := values{} + var tests = []struct { + name string + input string + values values + }{ + {"Command", "one", values{true, "", ""}}, + {"Arg", "one two", values{true, "two", ""}}, + {"Flag", "one --three=three", values{true, "", "three"}}, + {"ArgAndFlag", "one two --three=three", values{true, "two", "three"}}, + } + p := mustNew(t, &cli). + Hook(&cli.One, func(app *Kong, ctx *Context, trace *Trace) error { + hooked.one = true + return nil + }). + Hook(&cli.One.Two, func(app *Kong, ctx *Context, trace *Trace) error { + hooked.two = trace.Value.String() + return nil + }). + Hook(&cli.One.Three, func(app *Kong, ctx *Context, trace *Trace) error { + hooked.three = trace.Value.String() + return nil + }) + + for _, test := range tests { + hooked = values{} + t.Run(test.name, func(t *testing.T) { + _, err := p.Parse(strings.Split(test.input, " ")) + require.NoError(t, err) + require.Equal(t, test.values, hooked) + }) + } +} + +func TestHelp(t *testing.T) { +} diff --git a/model.go b/model.go index 32fb206..4e70f8d 100644 --- a/model.go +++ b/model.go @@ -21,6 +21,7 @@ type Node struct { Flags []*Flag Positional []*Value Children []*Branch + Target reflect.Value } // A Value is either a flag or a variable positional argument. diff --git a/options.go b/options.go index 5a3ed24..17ef158 100644 --- a/options.go +++ b/options.go @@ -9,7 +9,7 @@ type Option func(k *Kong) // ExitFunction overrides the function used to terminate. This is useful for testing or interactive use. func ExitFunction(exit func(int)) Option { - return func(k *Kong) { k.terminate = exit } + return func(k *Kong) { k.Exit = exit } } // Name overrides the application name. @@ -37,7 +37,7 @@ func HelpContext(context map[string]interface{}) Option { // Writers overrides the default writers. Useful for testing or interactive use. func Writers(stdout, stderr io.Writer) Option { return func(k *Kong) { - k.stdout = stdout - k.stderr = stderr + k.Stdout = stdout + k.Stderr = stderr } } diff --git a/options_test.go b/options_test.go index d6fd419..61c7de1 100644 --- a/options_test.go +++ b/options_test.go @@ -12,7 +12,7 @@ func TestOptions(t *testing.T) { require.NoError(t, err) require.Equal(t, "name", p.Model.Name) require.Equal(t, "description", p.Model.Help) - require.Nil(t, p.stdout) - require.Nil(t, p.stderr) - require.Nil(t, p.terminate) + require.Nil(t, p.Stdout) + require.Nil(t, p.Stderr) + require.Nil(t, p.Exit) }