Bind parent nodes when executing Run().

This commit is contained in:
Alec Thomas
2019-10-30 13:46:08 +11:00
parent 77a613fb8b
commit d15c8fca8d
2 changed files with 21 additions and 6 deletions
+10 -1
View File
@@ -3,10 +3,19 @@ package kong
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
type bindings map[reflect.Type]reflect.Value type bindings map[reflect.Type]reflect.Value
func (b bindings) String() string {
out := []string{}
for k := range b {
out = append(out, k.String())
}
return "bindings{" + strings.Join(out, ", ") + "}"
}
func (b bindings) add(values ...interface{}) bindings { func (b bindings) add(values ...interface{}) bindings {
for _, v := range values { for _, v := range values {
b[reflect.TypeOf(v)] = reflect.ValueOf(v) b[reflect.TypeOf(v)] = reflect.ValueOf(v)
@@ -51,7 +60,7 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error {
if arg, ok := bindings[pt]; ok { if arg, ok := bindings[pt]; ok {
in = append(in, arg) in = append(in, arg)
} else { } else {
return fmt.Errorf("couldn't find binding of type %s for parameter %d of %T.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt) return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt)
} }
} }
out := f.Call(in) out := f.Call(in)
+11 -5
View File
@@ -525,8 +525,9 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) {
// Run executes the Run() method on the selected command, which must exist. // Run executes the Run() method on the selected command, which must exist.
// //
// Any passed values will be bindable to arguments of the target Run() method. // Any passed values will be bindable to arguments of the target Run() method. Additionally,
func (c *Context) Run(bindings ...interface{}) (err error) { // all parent nodes in the command structure will be bound.
func (c *Context) Run(binds ...interface{}) (err error) {
defer catch(&err) defer catch(&err)
node := c.Selected() node := c.Selected()
if node == nil { if node == nil {
@@ -535,12 +536,18 @@ func (c *Context) Run(bindings ...interface{}) (err error) {
type targetMethod struct { type targetMethod struct {
node *Node node *Node
method reflect.Value method reflect.Value
binds bindings
} }
methodBinds := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings)
methods := []targetMethod{} methods := []targetMethod{}
for i := 0; node != nil; i, node = i+1, node.Parent { for i := 0; node != nil; i, node = i+1, node.Parent {
method := getMethod(node.Target, "Run") method := getMethod(node.Target, "Run")
methodBinds = methodBinds.clone()
for p := node; p != nil; p = p.Parent {
methodBinds = methodBinds.add(p.Target.Addr().Interface())
}
if method.IsValid() { if method.IsValid() {
methods = append(methods, targetMethod{node, method}) methods = append(methods, targetMethod{node, method, methodBinds})
} }
} }
if len(methods) == 0 { if len(methods) == 0 {
@@ -552,8 +559,7 @@ func (c *Context) Run(bindings ...interface{}) (err error) {
} }
for _, method := range methods { for _, method := range methods {
binds := c.Kong.bindings.clone().add(bindings...).add(c).merge(c.bindings) if err = callMethod("Run", method.node.Target, method.method, method.binds); err != nil {
if err = callMethod("Run", method.node.Target, method.method, binds); err != nil {
return err return err
} }
} }