Track trace values externally to the path.

This allows accumulating mappers to work correctly. This also means
that resolvers are not even triggered if a command-line value is passed,
which is more desirable behaviour.
This commit is contained in:
Alec Thomas
2018-06-13 21:41:55 +10:00
parent c7dca86dad
commit 54386f7fa5
4 changed files with 53 additions and 33 deletions
+46 -20
View File
@@ -21,9 +21,6 @@ type Path struct {
// Flags added by this node.
Flags []*Flag
// Parsed value for non-commands.
Value reflect.Value
// True if this Path element was created as the result of a resolver.
Resolved bool
}
@@ -34,8 +31,22 @@ type Context struct {
Path []*Path // A trace through parsed nodes.
Error error // Error that occurred during trace, if any.
args []string
scan *Scanner
values map[*Value]reflect.Value // Temporary values during tracing.
args []string
scan *Scanner
}
// Value returns the value for a particular path element.
func (c *Context) Value(path *Path) reflect.Value {
switch {
case path.Positional != nil:
return c.values[path.Positional]
case path.Flag != nil:
return c.values[path.Flag.Value]
case path.Argument != nil:
return c.values[path.Argument.Argument]
}
panic("can only retrieve value for flag, argument or positional")
}
// Selected command or argument.
@@ -62,9 +73,10 @@ func Trace(k *Kong, args []string) (*Context, error) {
App: k,
args: args,
Path: []*Path{
{App: k.Model, Flags: k.Model.Flags, Value: k.Model.Target},
{App: k.Model, Flags: k.Model.Flags},
},
scan: Scan(args...),
values: map[*Value]reflect.Value{},
scan: Scan(args...),
}
c.Error = c.trace(&c.App.Model.Node)
return c, c.traceResolvers()
@@ -142,7 +154,7 @@ func (c *Context) Command() (command []string) {
func (c *Context) FlagValue(flag *Flag) reflect.Value {
for _, trace := range c.Path {
if trace.Flag == flag {
return trace.Value
return c.values[trace.Flag.Value]
}
}
return reflect.Value{}
@@ -253,14 +265,13 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
// Ensure we've consumed all positional arguments.
if positional < len(node.Positional) {
arg := node.Positional[positional]
value, err := arg.Parse(c.scan)
err := arg.Parse(c.scan, c.getValue(arg))
if err != nil {
return err
}
c.Path = append(c.Path, &Path{
Parent: node,
Positional: arg,
Value: value,
})
positional++
break
@@ -273,7 +284,6 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
c.Path = append(c.Path, &Path{
Parent: node,
Command: branch,
Value: branch.Target,
Flags: branch.Flags,
})
return c.trace(branch)
@@ -284,11 +294,10 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
for _, branch := range node.Children {
if branch.Type == ArgumentNode {
arg := branch.Argument
if value, err := arg.Parse(c.scan); err == nil {
if err := arg.Parse(c.scan, c.getValue(arg)); err == nil {
c.Path = append(c.Path, &Path{
Parent: node,
Argument: branch,
Value: value,
Flags: branch.Flags,
})
return c.trace(branch)
@@ -313,6 +322,10 @@ func (c *Context) traceResolvers() error {
inserted := []*Path{}
for _, path := range c.Path {
for _, flag := range path.Flags {
// Flag has already been set on the command-line.
if _, ok := c.values[flag.Value]; ok {
continue
}
for _, resolver := range c.App.resolvers {
s, err := resolver(c, path, flag)
if err != nil {
@@ -323,13 +336,13 @@ func (c *Context) traceResolvers() error {
}
scan := Scan().PushTyped(s, FlagValueToken)
value, err := flag.Parse(scan)
delete(c.values, flag.Value)
err = flag.Parse(scan, c.getValue(flag.Value))
if err != nil {
return err
}
inserted = append(inserted, &Path{
Flag: flag,
Value: value,
Resolved: true,
})
}
@@ -339,6 +352,15 @@ func (c *Context) traceResolvers() error {
return nil
}
func (c *Context) getValue(value *Value) reflect.Value {
v, ok := c.values[value]
if !ok {
v = reflect.New(value.Target.Type()).Elem()
c.values[value] = v
}
return v
}
// Apply traced context to the target grammar.
func (c *Context) Apply() (string, error) {
err := c.reset(&c.App.Model.Node)
@@ -349,21 +371,25 @@ func (c *Context) Apply() (string, error) {
path := []string{}
for _, trace := range c.Path {
var value *Value
switch {
case trace.App != nil:
case trace.Argument != nil:
path = append(path, "<"+trace.Argument.Name+">")
trace.Argument.Argument.Apply(trace.Value)
value = trace.Argument.Argument
case trace.Command != nil:
path = append(path, trace.Command.Name)
case trace.Flag != nil:
trace.Flag.Value.Apply(trace.Value)
value = trace.Flag.Value
case trace.Positional != nil:
path = append(path, "<"+trace.Positional.Name+">")
trace.Positional.Apply(trace.Value)
value = trace.Positional
default:
panic("unsupported path ?!")
}
if value != nil {
value.Apply(c.getValue(value))
}
}
return strings.Join(path, " "), nil
@@ -378,11 +404,11 @@ func (c *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err err
continue
}
c.scan.Pop()
value, err := flag.Parse(c.scan)
err := flag.Parse(c.scan, c.getValue(flag.Value))
if err != nil {
return err
}
c.Path = append(c.Path, &Path{Flag: flag, Value: value})
c.Path = append(c.Path, &Path{Flag: flag})
return nil
}
return fmt.Errorf("unknown flag --%s", token.Value)
+2 -2
View File
@@ -352,8 +352,8 @@ func TestHooks(t *testing.T) {
{"ArgAndFlag", "one two --three=three", values{true, "two", "three"}},
}
setOne := func(ctx *Context, path *Path) error { hooked.one = true; return nil }
setTwo := func(ctx *Context, path *Path) error { hooked.two = path.Value.String(); return nil }
setThree := func(ctx *Context, path *Path) error { hooked.three = path.Value.String(); return nil }
setTwo := func(ctx *Context, path *Path) error { hooked.two = ctx.Value(path).String(); return nil }
setThree := func(ctx *Context, path *Path) error { hooked.three = ctx.Value(path).String(); return nil }
p := mustNew(t, &cli,
Hook(&cli.One, setOne),
Hook(&cli.One.Two, setTwo),
+4 -10
View File
@@ -178,13 +178,12 @@ func (v *Value) IsBool() bool {
}
// Parse tokens into value, parse, and validate, but do not write to the field.
func (v *Value) Parse(scan *Scanner) (reflect.Value, error) {
value := reflect.New(v.Target.Type()).Elem()
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, value)
func (v *Value) Parse(scan *Scanner, target reflect.Value) error {
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target)
if err == nil {
v.Set = true
}
return value, err
return err
}
// Apply value to field.
@@ -199,12 +198,7 @@ func (v *Value) Apply(value reflect.Value) {
func (v *Value) Reset() error {
v.Target.Set(reflect.Zero(v.Target.Type()))
if v.Default != "" {
value, err := v.Parse(Scan(v.Default))
if err != nil {
return err
}
v.Apply(value)
v.Set = false
return v.Parse(Scan(v.Default), v.Target)
}
return nil
}
+1 -1
View File
@@ -147,7 +147,7 @@ func TestResolvedValueTriggersHooks(t *testing.T) {
_, err = p.Parse([]string{"--int=2"})
require.NoError(t, err)
require.Equal(t, 2, cli.Int)
require.Equal(t, 2, hooked)
require.Equal(t, 1, hooked)
}
type testUppercaseMapper struct{}