From d15c8fca8dfd0f4dba023021118452517cf36d9a Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Wed, 30 Oct 2019 13:46:08 +1100 Subject: [PATCH] Bind parent nodes when executing Run(). --- callbacks.go | 11 ++++++++++- context.go | 16 +++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/callbacks.go b/callbacks.go index 7f4637b..e403d27 100644 --- a/callbacks.go +++ b/callbacks.go @@ -3,10 +3,19 @@ package kong import ( "fmt" "reflect" + "strings" ) type bindings map[reflect.Type]reflect.Value +func (b bindings) String() string { + out := []string{} + for k := range b { + out = append(out, k.String()) + } + return "bindings{" + strings.Join(out, ", ") + "}" +} + func (b bindings) add(values ...interface{}) bindings { for _, v := range values { b[reflect.TypeOf(v)] = reflect.ValueOf(v) @@ -51,7 +60,7 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error { if arg, ok := bindings[pt]; ok { in = append(in, arg) } else { - return fmt.Errorf("couldn't find binding of type %s for parameter %d of %T.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt) + return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt) } } out := f.Call(in) diff --git a/context.go b/context.go index 7279f63..80c262f 100644 --- a/context.go +++ b/context.go @@ -525,8 +525,9 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) { // Run executes the Run() method on the selected command, which must exist. // -// Any passed values will be bindable to arguments of the target Run() method. -func (c *Context) Run(bindings ...interface{}) (err error) { +// Any passed values will be bindable to arguments of the target Run() method. Additionally, +// all parent nodes in the command structure will be bound. +func (c *Context) Run(binds ...interface{}) (err error) { defer catch(&err) node := c.Selected() if node == nil { @@ -535,12 +536,18 @@ func (c *Context) Run(bindings ...interface{}) (err error) { type targetMethod struct { node *Node method reflect.Value + binds bindings } + methodBinds := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings) methods := []targetMethod{} for i := 0; node != nil; i, node = i+1, node.Parent { method := getMethod(node.Target, "Run") + methodBinds = methodBinds.clone() + for p := node; p != nil; p = p.Parent { + methodBinds = methodBinds.add(p.Target.Addr().Interface()) + } if method.IsValid() { - methods = append(methods, targetMethod{node, method}) + methods = append(methods, targetMethod{node, method, methodBinds}) } } if len(methods) == 0 { @@ -552,8 +559,7 @@ func (c *Context) Run(bindings ...interface{}) (err error) { } for _, method := range methods { - binds := c.Kong.bindings.clone().add(bindings...).add(c).merge(c.bindings) - if err = callMethod("Run", method.node.Target, method.method, binds); err != nil { + if err = callMethod("Run", method.node.Target, method.method, method.binds); err != nil { return err } }