diff --git a/README.md b/README.md index 4022a27..3f04f22 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ 1. [Command handling](#command-handling) 1. [Switch on the command string](#switch-on-the-command-string) 1. [Attach a `Run(...) error` method to each command](#attach-a-run-error-method-to-each-command) +1. [BeforeHook\(\), AfterHook\(\) and the Bind\(\) option](#beforehook-afterhook-and-the-bind-option) 1. [Flags](#flags) 1. [Commands and sub-commands](#commands-and-sub-commands) 1. [Branching positional arguments](#branching-positional-arguments) @@ -26,7 +27,7 @@ 1. [`Resolver(...)` - support for default values from external sources](#resolver---support-for-default-values-from-external-sources) 1. [`*Mapper(...)` - customising how the command-line is mapped to Go values](#mapper---customising-how-the-command-line-is-mapped-to-go-values) 1. [`ConfigureHelp(HelpOptions)` and `Help(HelpFunc)` - customising help](#configurehelphelpoptions-and-helphelpfunc---customising-help) - 1. [`Hook(&field, HookFunc)` - callback hooks to execute when the command-line is parsed](#hookfield-hookfunc---callback-hooks-to-execute-when-the-command-line-is-parsed) + 1. [`Bindings(...)` - bind values for callback hooks anr Run\(\) methods](#bindings---bind-values-for-callback-hooks-anr-run-methods) 1. [Other options](#other-options) @@ -165,6 +166,8 @@ A more robust approach is to break each command out into their own structs: 3. Call `kong.Kong.Parse()` to obtain a `kong.Context`. 4. Call `kong.Context.Run(params...)` to call the selected parsed command. +Note that `Run()` method arguments may also be provided by the `Bind(...)` option (see below). + There's a full example emulating part of the Docker CLI [here](https://github.com/alecthomas/kong/tree/master/_examples/docker). eg. @@ -207,6 +210,40 @@ func main() { ``` +## BeforeHook(), AfterHook() and the Bind() option + +If a node in the grammar has a `BeforeHook(...) error` and/or `AfterHook(...) error` method, those methods will +be called before validation/assignment and after validation/assignment, respectively. + +The `--help` flag is implemented with a `BeforeHook`. + +Arguments to hooks are provided via the `Bind(...)` option. `*Kong`, `*Context` and `*Path` are also bound. + +eg. + +```go +// A flag with a hook that, if triggered, will set the debug loggers output to stdout. +var debugFlag bool + +func (d debugFlag) BeforeHook(logger *log.Logger) error { + logger.SetOutput(os.Stdout) + return nil +} + +var cli struct { + Debug debugFlag `help:"Enable debug logging."` +} + +func main() { + // Debug logger going to discard. + logger := log.New(ioutil.Discard, "", log.LstdFlags) + + ctx := kong.Parse(&cli, kong.Bind(logger)) + + // ... +} +``` + ## Flags Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure *not* tagged with `cmd` or `arg` will be a flag. Flags are optional by default. @@ -450,26 +487,8 @@ The default help output is usually sufficient, but if not there are two solution 1. Use `ConfigureHelp(HelpOptions)` to configure how help is formatted (see [HelpOptions](https://godoc.org/github.com/alecthomas/kong#HelpOptions) for details). 2. Custom help can be wired into Kong via the `Help(HelpFunc)` option. The `HelpFunc` is passed a `Context`, which contains the parsed context for the current command-line. See the implementation of `PrintHelp` for an example. -### `Hook(&field, HookFunc)` - callback hooks to execute when the command-line is parsed +### `Bindings(...)` - bind values for callback hooks anr Run() methods -Hooks are callback functions that are bound to a node in the command-line and executed at parse time, before structural validation and assignment. - -eg. - -```go -app := kong.Must(&CLI, kong.Hook(&CLI.Debug, func(ctx *Context, path *Path) error { - log.SetLevel(DEBUG) - return nil -})) -``` - -Note: it is generally less verbose to use an imperative approach to building command-lines, eg. - -```go -if CLI.Debug { - log.SetLevel(DEBUG) -} -``` But under some circumstances, hooks can be useful. diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 0000000..351ee48 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,55 @@ +package kong + +import ( + "fmt" + "reflect" +) + +type bindings map[reflect.Type]reflect.Value + +func (b bindings) add(values ...interface{}) bindings { + for _, v := range values { + b[reflect.TypeOf(v)] = reflect.ValueOf(v) + } + return b +} + +// Clone and add values. +func (b bindings) clone() bindings { + out := make(bindings, len(b)) + for k, v := range b { + out[k] = v + } + return out +} + +func getMethod(value reflect.Value, name string) reflect.Value { + method := value.MethodByName(name) + if !method.IsValid() { + if value.CanAddr() { + method = value.Addr().MethodByName(name) + } + } + return method +} + +func callMethod(name string, v, f reflect.Value, bindings bindings) error { + in := []reflect.Value{} + t := f.Type() + if t.NumOut() != 1 || t.Out(0) != callbackReturnSignature { + return fmt.Errorf("return value of %T.%s() must be exactly \"error\"", v.Type(), name) + } + for i := 0; i < t.NumIn(); i++ { + pt := t.In(i) + if arg, ok := bindings[pt]; ok { + in = append(in, arg) + } else { + return fmt.Errorf("couldn't find binding of type %s for parameter %d of %T.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt) + } + } + out := f.Call(in) + if out[0].IsNil() { + return nil + } + return out[0].Interface().(error) +} diff --git a/context.go b/context.go index 8d60d98..76e2c64 100644 --- a/context.go +++ b/context.go @@ -475,41 +475,20 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) { // The target Run() method must exist and have the type signature "Run(params...) error". func (c *Context) Run(params ...interface{}) (err error) { defer catch(&err) - expectedRunSignature, err := c.validateRun(c.Model.Node, nil) - if err != nil { - return err - } - if expectedRunSignature.NumIn() != len(params) { - return fmt.Errorf("expected %d params but received %d; does not match target Run() signature of %s", - expectedRunSignature.NumIn(), len(params), expectedRunSignature) - } - for i, param := range params { - if reflect.TypeOf(param) != expectedRunSignature.In(i) { - return fmt.Errorf("param %d is of type %s but should be of type %s to match target Run() signature of %s", - i, reflect.TypeOf(param), expectedRunSignature.In(i), expectedRunSignature) - } - } node := c.Selected() if node == nil { return fmt.Errorf("no command selected") } - method, err := getRunMethod(node.Target) - if err != nil { - return err + method := getMethod(node.Target, "Run") + if !method.IsValid() { + return fmt.Errorf("no Run() method on %s", node.Target) } _, err = c.Apply() if err != nil { return err } - reflectedParams := []reflect.Value{} - for _, param := range params { - reflectedParams = append(reflectedParams, reflect.ValueOf(param)) - } - result := method.Call(reflectedParams) - if result[0].IsNil() { - return nil - } - return result[0].Interface().(error) + binds := c.Kong.bindings.clone().add(params...).add(c) + return callMethod("Run", node.Target, method, binds) } // PrintUsage to Kong's stdout. @@ -522,45 +501,6 @@ func (c *Context) PrintUsage(summary bool) error { return nil } -// Validate that all commands have Run() methods and that their signatures are the same. -func (c *Context) validateRun(node *Node, signature reflect.Type) (reflect.Type, error) { - if node.Leaf() { - method, err := getRunMethod(node.Target) - if err != nil { - return nil, err - } - if signature == nil { - signature = method.Type() - } else if signature != method.Type() { - return nil, fmt.Errorf("Run() methods are not consistent on %s, expected %s but got %s", node.Target.Type(), signature, method.Type()) - } - if signature.NumOut() != 1 || signature.Out(0) != expectedRunReturnSignature { - return nil, fmt.Errorf("Run() method on %s should return (error)", node.Target.Type()) - } - } - for _, child := range node.Children { - if childSignature, err := c.validateRun(child, signature); err != nil { - return nil, err - } else if signature == nil { - signature = childSignature - } - } - return signature, nil -} - -func getRunMethod(value reflect.Value) (reflect.Value, error) { - method := value.MethodByName("Run") - if !method.IsValid() { - if value.CanAddr() { - method = value.Addr().MethodByName("Run") - } - if !method.IsValid() { - return method, fmt.Errorf("no Run() method on %s", value.Type()) - } - } - return method, nil -} - func checkMissingFlags(flags []*Flag) error { missing := []string{} for _, flag := range flags { diff --git a/kong.go b/kong.go index 8ff4bcb..8e3fdce 100644 --- a/kong.go +++ b/kong.go @@ -10,7 +10,7 @@ import ( ) var ( - expectedRunReturnSignature = reflect.TypeOf((*error)(nil)).Elem() + callbackReturnSignature = reflect.TypeOf((*error)(nil)).Elem() ) // Error reported by Kong. @@ -42,7 +42,7 @@ type Kong struct { Stdout io.Writer Stderr io.Writer - before map[reflect.Value]HookFunc + bindings bindings resolvers []ResolverFunc registry *Registry @@ -65,12 +65,14 @@ func New(grammar interface{}, options ...Option) (*Kong, error) { Exit: os.Exit, Stdout: os.Stdout, Stderr: os.Stderr, - before: map[reflect.Value]HookFunc{}, registry: NewRegistry().RegisterDefaults(), resolvers: []ResolverFunc{Envars()}, vars: map[string]string{}, + bindings: bindings{}, } + options = append(options, Bind(k)) + for _, option := range options { if err := option.Apply(k); err != nil { return nil, err @@ -155,13 +157,26 @@ func mergeVars(base, extra map[string]string) map[string]string { return out } +type helpValue bool + +func (h helpValue) BeforeHook(ctx *Context) error { + options := ctx.Kong.helpOptions + options.Summary = false + err := ctx.Kong.help(options, ctx) + if err != nil { + return err + } + ctx.Kong.Exit(1) + return nil +} + // Provide additional builtin flags, if any. func (k *Kong) extraFlags() []*Flag { if k.noDefaultHelp { return nil } - helpValue := false - value := reflect.ValueOf(&helpValue).Elem() + var helpTarget helpValue + value := reflect.ValueOf(&helpTarget).Elem() helpFlag := &Flag{ Value: &Value{ Name: "help", @@ -172,18 +187,7 @@ func (k *Kong) extraFlags() []*Flag { }, } helpFlag.Flag = helpFlag - hook := Hook(&helpValue, func(ctx *Context, path *Path) error { - options := k.helpOptions - options.Summary = false - err := k.help(options, ctx) - if err != nil { - return err - } - k.Exit(1) - return nil - }) k.helpFlag = helpFlag - _ = hook(k) return []*Flag{helpFlag} } @@ -200,45 +204,48 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) { if err != nil { return nil, err } - if err = k.applyHooks(ctx); err != nil { - return nil, &ParseError{error: err, Context: ctx} - } if ctx.Error != nil { return nil, &ParseError{error: ctx.Error, Context: ctx} } + if err = k.applyHook(ctx, "BeforeHook"); err != nil { + return nil, &ParseError{error: err, Context: ctx} + } if err = ctx.Validate(); err != nil { return nil, &ParseError{error: err, Context: ctx} } if _, err = ctx.Apply(); err != nil { return nil, &ParseError{error: err, Context: ctx} } + if err = k.applyHook(ctx, "AfterHook"); err != nil { + return nil, &ParseError{error: err, Context: ctx} + } return ctx, nil } -func (k *Kong) applyHooks(ctx *Context) error { +func (k *Kong) applyHook(ctx *Context, name string) error { for _, trace := range ctx.Path { - var key reflect.Value + var value reflect.Value switch { case trace.App != nil: - key = trace.App.Target + value = trace.App.Target case trace.Argument != nil: - key = trace.Argument.Target + value = trace.Argument.Target case trace.Command != nil: - key = trace.Command.Target + value = trace.Command.Target case trace.Positional != nil: - key = trace.Positional.Target + value = trace.Positional.Target case trace.Flag != nil: - key = trace.Flag.Value.Target + value = trace.Flag.Value.Target default: panic("unsupported Path") } - if key.IsValid() { - key = key.Addr() + method := getMethod(value, name) + if !method.IsValid() { + continue } - if hook := k.before[key]; hook != nil { - if err := hook(ctx, trace); err != nil { - return err - } + binds := k.bindings.clone().add(ctx, trace) + if err := callMethod(name, value, method, binds); err != nil { + return err } } return nil diff --git a/kong_test.go b/kong_test.go index 2f11c40..f67892b 100644 --- a/kong_test.go +++ b/kong_test.go @@ -355,43 +355,64 @@ func TestTraceErrorPartiallySucceeds(t *testing.T) { require.Equal(t, "one", ctx.Command()) } +type hookContext struct { + cmd bool + values []string +} + +type hookValue string + +func (h *hookValue) BeforeHook(ctx *hookContext) error { + ctx.values = append(ctx.values, "before:"+string(*h)) + return nil +} + +func (h *hookValue) AfterHook(ctx *hookContext) error { + ctx.values = append(ctx.values, "after:"+string(*h)) + return nil +} + +type hookCmd struct { + Two hookValue `kong:"arg,optional"` + Three hookValue +} + +func (h *hookCmd) BeforeHook(ctx *hookContext) error { + ctx.cmd = true + return nil +} + +func (h *hookCmd) AfterHook(ctx *hookContext) error { + ctx.cmd = true + return nil +} + func TestHooks(t *testing.T) { - var cli struct { - One struct { - Two string `kong:"arg,optional"` - Three string - } `kong:"cmd"` - } - type values struct { - one bool - two string - three string - } - hooked := values{} var tests = []struct { name string input string - values values + values hookContext }{ - {"Command", "one", values{true, "", ""}}, - {"Arg", "one two", values{true, "two", ""}}, - {"Flag", "one --three=three", values{true, "", "three"}}, - {"ArgAndFlag", "one two --three=three", values{true, "two", "three"}}, + {"Command", "one", hookContext{true, nil}}, + {"Arg", "one two", hookContext{true, []string{"before:", "after:two"}}}, + {"Flag", "one --three=THREE", hookContext{true, []string{"before:", "after:THREE"}}}, + {"ArgAndFlag", "one two --three=THREE", hookContext{true, []string{"before:", "before:", "after:two", "after:THREE"}}}, } - setOne := func(ctx *kong.Context, path *kong.Path) error { hooked.one = true; return nil } - setTwo := func(ctx *kong.Context, path *kong.Path) error { hooked.two = ctx.Value(path).String(); return nil } - setThree := func(ctx *kong.Context, path *kong.Path) error { hooked.three = ctx.Value(path).String(); return nil } - p := mustNew(t, &cli, - kong.Hook(&cli.One, setOne), - kong.Hook(&cli.One.Two, setTwo), - kong.Hook(&cli.One.Three, setThree)) + + var cli struct { + One hookCmd `cmd:""` + } + + ctx := &hookContext{} + p := mustNew(t, &cli, kong.Bind(ctx)) for _, test := range tests { - hooked = values{} + *ctx = hookContext{} + cli.One = hookCmd{} t.Run(test.name, func(t *testing.T) { _, err := p.Parse(strings.Split(test.input, " ")) require.NoError(t, err) - require.Equal(t, test.values, hooked) + require.Equal(t, &test.values, ctx) }) } } diff --git a/options.go b/options.go index 0db6d23..1420f7e 100644 --- a/options.go +++ b/options.go @@ -109,24 +109,20 @@ func Writers(stdout, stderr io.Writer) OptionFunc { } } -// HookFunc is a callback tied to a field of the grammar, called before a value is applied. +// Bind binds values for hooks and Run() function arguments. // -// "ctx" is the current parse Context, "path" is the Path entry corresponding to the hooked value. -type HookFunc func(ctx *Context, path *Path) error - -// Hook to apply before a command, flag or positional argument is encountered. +// Any arguments passed will be available to the receiving hook functions, but may be omitted. Additionally, *Kong and +// the current *Context will also be made available. // -// "ptr" is a pointer to a field of the grammar. +// There are two hook points: // -// Note that the hook will be called once for each time the corresponding node is encountered. This means that if a flag -// is passed twice, its hook will be called twice. -func Hook(ptr interface{}, hook HookFunc) OptionFunc { - key := reflect.ValueOf(ptr) - if key.Kind() != reflect.Ptr { - panic("expected a pointer") - } +// BeforeHook(...) error +// AfterHook(...) error +// +// Called before validation/assignment, and immediately after validation/assignment, respectively. +func Bind(args ...interface{}) OptionFunc { return func(k *Kong) error { - k.before[key] = hook + k.bindings.add(args...) return nil } } diff --git a/resolver_test.go b/resolver_test.go index 7e16a8d..1caa942 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -125,33 +125,6 @@ func TestJSONBasic(t *testing.T) { require.True(t, cli.Bool) } -func TestResolvedValueTriggersHooks(t *testing.T) { - var cli struct { - Int int - } - resolver := func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { - if flag.Name == "int" { - return "1", nil - } - return "", nil - } - hooked := 0 - p := mustNew(t, &cli, kong.Resolver(resolver), kong.Hook(&cli.Int, func(ctx *kong.Context, path *kong.Path) error { - hooked++ - return nil - })) - _, err := p.Parse(nil) - require.NoError(t, err) - require.Equal(t, 1, cli.Int) - require.Equal(t, 1, hooked) - - hooked = 0 - _, err = p.Parse([]string{"--int=2"}) - require.NoError(t, err) - require.Equal(t, 2, cli.Int) - require.Equal(t, 1, hooked) -} - type testUppercaseMapper struct{} func (testUppercaseMapper) Decode(ctx *kong.DecodeContext, target reflect.Value) error {