diff --git a/context.go b/context.go index 1f39e74..9b26d5f 100755 --- a/context.go +++ b/context.go @@ -21,9 +21,6 @@ type Path struct { // Flags added by this node. Flags []*Flag - // Parsed value for non-commands. - Value reflect.Value - // True if this Path element was created as the result of a resolver. Resolved bool } @@ -34,8 +31,22 @@ type Context struct { Path []*Path // A trace through parsed nodes. Error error // Error that occurred during trace, if any. - args []string - scan *Scanner + values map[*Value]reflect.Value // Temporary values during tracing. + args []string + scan *Scanner +} + +// Value returns the value for a particular path element. +func (c *Context) Value(path *Path) reflect.Value { + switch { + case path.Positional != nil: + return c.values[path.Positional] + case path.Flag != nil: + return c.values[path.Flag.Value] + case path.Argument != nil: + return c.values[path.Argument.Argument] + } + panic("can only retrieve value for flag, argument or positional") } // Selected command or argument. @@ -62,9 +73,10 @@ func Trace(k *Kong, args []string) (*Context, error) { App: k, args: args, Path: []*Path{ - {App: k.Model, Flags: k.Model.Flags, Value: k.Model.Target}, + {App: k.Model, Flags: k.Model.Flags}, }, - scan: Scan(args...), + values: map[*Value]reflect.Value{}, + scan: Scan(args...), } c.Error = c.trace(&c.App.Model.Node) return c, c.traceResolvers() @@ -142,7 +154,7 @@ func (c *Context) Command() (command []string) { func (c *Context) FlagValue(flag *Flag) reflect.Value { for _, trace := range c.Path { if trace.Flag == flag { - return trace.Value + return c.values[trace.Flag.Value] } } return reflect.Value{} @@ -253,14 +265,13 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo // Ensure we've consumed all positional arguments. if positional < len(node.Positional) { arg := node.Positional[positional] - value, err := arg.Parse(c.scan) + err := arg.Parse(c.scan, c.getValue(arg)) if err != nil { return err } c.Path = append(c.Path, &Path{ Parent: node, Positional: arg, - Value: value, }) positional++ break @@ -273,7 +284,6 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo c.Path = append(c.Path, &Path{ Parent: node, Command: branch, - Value: branch.Target, Flags: branch.Flags, }) return c.trace(branch) @@ -284,11 +294,10 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo for _, branch := range node.Children { if branch.Type == ArgumentNode { arg := branch.Argument - if value, err := arg.Parse(c.scan); err == nil { + if err := arg.Parse(c.scan, c.getValue(arg)); err == nil { c.Path = append(c.Path, &Path{ Parent: node, Argument: branch, - Value: value, Flags: branch.Flags, }) return c.trace(branch) @@ -313,6 +322,10 @@ func (c *Context) traceResolvers() error { inserted := []*Path{} for _, path := range c.Path { for _, flag := range path.Flags { + // Flag has already been set on the command-line. + if _, ok := c.values[flag.Value]; ok { + continue + } for _, resolver := range c.App.resolvers { s, err := resolver(c, path, flag) if err != nil { @@ -323,13 +336,13 @@ func (c *Context) traceResolvers() error { } scan := Scan().PushTyped(s, FlagValueToken) - value, err := flag.Parse(scan) + delete(c.values, flag.Value) + err = flag.Parse(scan, c.getValue(flag.Value)) if err != nil { return err } inserted = append(inserted, &Path{ Flag: flag, - Value: value, Resolved: true, }) } @@ -339,6 +352,15 @@ func (c *Context) traceResolvers() error { return nil } +func (c *Context) getValue(value *Value) reflect.Value { + v, ok := c.values[value] + if !ok { + v = reflect.New(value.Target.Type()).Elem() + c.values[value] = v + } + return v +} + // Apply traced context to the target grammar. func (c *Context) Apply() (string, error) { err := c.reset(&c.App.Model.Node) @@ -349,21 +371,25 @@ func (c *Context) Apply() (string, error) { path := []string{} for _, trace := range c.Path { + var value *Value switch { case trace.App != nil: case trace.Argument != nil: path = append(path, "<"+trace.Argument.Name+">") - trace.Argument.Argument.Apply(trace.Value) + value = trace.Argument.Argument case trace.Command != nil: path = append(path, trace.Command.Name) case trace.Flag != nil: - trace.Flag.Value.Apply(trace.Value) + value = trace.Flag.Value case trace.Positional != nil: path = append(path, "<"+trace.Positional.Name+">") - trace.Positional.Apply(trace.Value) + value = trace.Positional default: panic("unsupported path ?!") } + if value != nil { + value.Apply(c.getValue(value)) + } } return strings.Join(path, " "), nil @@ -378,11 +404,11 @@ func (c *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err err continue } c.scan.Pop() - value, err := flag.Parse(c.scan) + err := flag.Parse(c.scan, c.getValue(flag.Value)) if err != nil { return err } - c.Path = append(c.Path, &Path{Flag: flag, Value: value}) + c.Path = append(c.Path, &Path{Flag: flag}) return nil } return fmt.Errorf("unknown flag --%s", token.Value) diff --git a/kong_test.go b/kong_test.go index a288980..bb3d322 100644 --- a/kong_test.go +++ b/kong_test.go @@ -352,8 +352,8 @@ func TestHooks(t *testing.T) { {"ArgAndFlag", "one two --three=three", values{true, "two", "three"}}, } setOne := func(ctx *Context, path *Path) error { hooked.one = true; return nil } - setTwo := func(ctx *Context, path *Path) error { hooked.two = path.Value.String(); return nil } - setThree := func(ctx *Context, path *Path) error { hooked.three = path.Value.String(); return nil } + setTwo := func(ctx *Context, path *Path) error { hooked.two = ctx.Value(path).String(); return nil } + setThree := func(ctx *Context, path *Path) error { hooked.three = ctx.Value(path).String(); return nil } p := mustNew(t, &cli, Hook(&cli.One, setOne), Hook(&cli.One.Two, setTwo), diff --git a/model.go b/model.go index 91bb891..a7834ca 100644 --- a/model.go +++ b/model.go @@ -178,13 +178,12 @@ func (v *Value) IsBool() bool { } // Parse tokens into value, parse, and validate, but do not write to the field. -func (v *Value) Parse(scan *Scanner) (reflect.Value, error) { - value := reflect.New(v.Target.Type()).Elem() - err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, value) +func (v *Value) Parse(scan *Scanner, target reflect.Value) error { + err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target) if err == nil { v.Set = true } - return value, err + return err } // Apply value to field. @@ -199,12 +198,7 @@ func (v *Value) Apply(value reflect.Value) { func (v *Value) Reset() error { v.Target.Set(reflect.Zero(v.Target.Type())) if v.Default != "" { - value, err := v.Parse(Scan(v.Default)) - if err != nil { - return err - } - v.Apply(value) - v.Set = false + return v.Parse(Scan(v.Default), v.Target) } return nil } diff --git a/resolver_test.go b/resolver_test.go index 0bd02e1..be4d1f6 100755 --- a/resolver_test.go +++ b/resolver_test.go @@ -147,7 +147,7 @@ func TestResolvedValueTriggersHooks(t *testing.T) { _, err = p.Parse([]string{"--int=2"}) require.NoError(t, err) require.Equal(t, 2, cli.Int) - require.Equal(t, 2, hooked) + require.Equal(t, 1, hooked) } type testUppercaseMapper struct{}