Track trace values externally to the path.
This allows accumulating mappers to work correctly. This also means that resolvers are not even triggered if a command-line value is passed, which is more desirable behaviour.
This commit is contained in:
+46
-20
@@ -21,9 +21,6 @@ type Path struct {
|
|||||||
// Flags added by this node.
|
// Flags added by this node.
|
||||||
Flags []*Flag
|
Flags []*Flag
|
||||||
|
|
||||||
// Parsed value for non-commands.
|
|
||||||
Value reflect.Value
|
|
||||||
|
|
||||||
// True if this Path element was created as the result of a resolver.
|
// True if this Path element was created as the result of a resolver.
|
||||||
Resolved bool
|
Resolved bool
|
||||||
}
|
}
|
||||||
@@ -34,8 +31,22 @@ type Context struct {
|
|||||||
Path []*Path // A trace through parsed nodes.
|
Path []*Path // A trace through parsed nodes.
|
||||||
Error error // Error that occurred during trace, if any.
|
Error error // Error that occurred during trace, if any.
|
||||||
|
|
||||||
args []string
|
values map[*Value]reflect.Value // Temporary values during tracing.
|
||||||
scan *Scanner
|
args []string
|
||||||
|
scan *Scanner
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the value for a particular path element.
|
||||||
|
func (c *Context) Value(path *Path) reflect.Value {
|
||||||
|
switch {
|
||||||
|
case path.Positional != nil:
|
||||||
|
return c.values[path.Positional]
|
||||||
|
case path.Flag != nil:
|
||||||
|
return c.values[path.Flag.Value]
|
||||||
|
case path.Argument != nil:
|
||||||
|
return c.values[path.Argument.Argument]
|
||||||
|
}
|
||||||
|
panic("can only retrieve value for flag, argument or positional")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Selected command or argument.
|
// Selected command or argument.
|
||||||
@@ -62,9 +73,10 @@ func Trace(k *Kong, args []string) (*Context, error) {
|
|||||||
App: k,
|
App: k,
|
||||||
args: args,
|
args: args,
|
||||||
Path: []*Path{
|
Path: []*Path{
|
||||||
{App: k.Model, Flags: k.Model.Flags, Value: k.Model.Target},
|
{App: k.Model, Flags: k.Model.Flags},
|
||||||
},
|
},
|
||||||
scan: Scan(args...),
|
values: map[*Value]reflect.Value{},
|
||||||
|
scan: Scan(args...),
|
||||||
}
|
}
|
||||||
c.Error = c.trace(&c.App.Model.Node)
|
c.Error = c.trace(&c.App.Model.Node)
|
||||||
return c, c.traceResolvers()
|
return c, c.traceResolvers()
|
||||||
@@ -142,7 +154,7 @@ func (c *Context) Command() (command []string) {
|
|||||||
func (c *Context) FlagValue(flag *Flag) reflect.Value {
|
func (c *Context) FlagValue(flag *Flag) reflect.Value {
|
||||||
for _, trace := range c.Path {
|
for _, trace := range c.Path {
|
||||||
if trace.Flag == flag {
|
if trace.Flag == flag {
|
||||||
return trace.Value
|
return c.values[trace.Flag.Value]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return reflect.Value{}
|
return reflect.Value{}
|
||||||
@@ -253,14 +265,13 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
|
|||||||
// Ensure we've consumed all positional arguments.
|
// Ensure we've consumed all positional arguments.
|
||||||
if positional < len(node.Positional) {
|
if positional < len(node.Positional) {
|
||||||
arg := node.Positional[positional]
|
arg := node.Positional[positional]
|
||||||
value, err := arg.Parse(c.scan)
|
err := arg.Parse(c.scan, c.getValue(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.Path = append(c.Path, &Path{
|
c.Path = append(c.Path, &Path{
|
||||||
Parent: node,
|
Parent: node,
|
||||||
Positional: arg,
|
Positional: arg,
|
||||||
Value: value,
|
|
||||||
})
|
})
|
||||||
positional++
|
positional++
|
||||||
break
|
break
|
||||||
@@ -273,7 +284,6 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
|
|||||||
c.Path = append(c.Path, &Path{
|
c.Path = append(c.Path, &Path{
|
||||||
Parent: node,
|
Parent: node,
|
||||||
Command: branch,
|
Command: branch,
|
||||||
Value: branch.Target,
|
|
||||||
Flags: branch.Flags,
|
Flags: branch.Flags,
|
||||||
})
|
})
|
||||||
return c.trace(branch)
|
return c.trace(branch)
|
||||||
@@ -284,11 +294,10 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
|
|||||||
for _, branch := range node.Children {
|
for _, branch := range node.Children {
|
||||||
if branch.Type == ArgumentNode {
|
if branch.Type == ArgumentNode {
|
||||||
arg := branch.Argument
|
arg := branch.Argument
|
||||||
if value, err := arg.Parse(c.scan); err == nil {
|
if err := arg.Parse(c.scan, c.getValue(arg)); err == nil {
|
||||||
c.Path = append(c.Path, &Path{
|
c.Path = append(c.Path, &Path{
|
||||||
Parent: node,
|
Parent: node,
|
||||||
Argument: branch,
|
Argument: branch,
|
||||||
Value: value,
|
|
||||||
Flags: branch.Flags,
|
Flags: branch.Flags,
|
||||||
})
|
})
|
||||||
return c.trace(branch)
|
return c.trace(branch)
|
||||||
@@ -313,6 +322,10 @@ func (c *Context) traceResolvers() error {
|
|||||||
inserted := []*Path{}
|
inserted := []*Path{}
|
||||||
for _, path := range c.Path {
|
for _, path := range c.Path {
|
||||||
for _, flag := range path.Flags {
|
for _, flag := range path.Flags {
|
||||||
|
// Flag has already been set on the command-line.
|
||||||
|
if _, ok := c.values[flag.Value]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, resolver := range c.App.resolvers {
|
for _, resolver := range c.App.resolvers {
|
||||||
s, err := resolver(c, path, flag)
|
s, err := resolver(c, path, flag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -323,13 +336,13 @@ func (c *Context) traceResolvers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
scan := Scan().PushTyped(s, FlagValueToken)
|
scan := Scan().PushTyped(s, FlagValueToken)
|
||||||
value, err := flag.Parse(scan)
|
delete(c.values, flag.Value)
|
||||||
|
err = flag.Parse(scan, c.getValue(flag.Value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
inserted = append(inserted, &Path{
|
inserted = append(inserted, &Path{
|
||||||
Flag: flag,
|
Flag: flag,
|
||||||
Value: value,
|
|
||||||
Resolved: true,
|
Resolved: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -339,6 +352,15 @@ func (c *Context) traceResolvers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) getValue(value *Value) reflect.Value {
|
||||||
|
v, ok := c.values[value]
|
||||||
|
if !ok {
|
||||||
|
v = reflect.New(value.Target.Type()).Elem()
|
||||||
|
c.values[value] = v
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
// Apply traced context to the target grammar.
|
// Apply traced context to the target grammar.
|
||||||
func (c *Context) Apply() (string, error) {
|
func (c *Context) Apply() (string, error) {
|
||||||
err := c.reset(&c.App.Model.Node)
|
err := c.reset(&c.App.Model.Node)
|
||||||
@@ -349,21 +371,25 @@ func (c *Context) Apply() (string, error) {
|
|||||||
path := []string{}
|
path := []string{}
|
||||||
|
|
||||||
for _, trace := range c.Path {
|
for _, trace := range c.Path {
|
||||||
|
var value *Value
|
||||||
switch {
|
switch {
|
||||||
case trace.App != nil:
|
case trace.App != nil:
|
||||||
case trace.Argument != nil:
|
case trace.Argument != nil:
|
||||||
path = append(path, "<"+trace.Argument.Name+">")
|
path = append(path, "<"+trace.Argument.Name+">")
|
||||||
trace.Argument.Argument.Apply(trace.Value)
|
value = trace.Argument.Argument
|
||||||
case trace.Command != nil:
|
case trace.Command != nil:
|
||||||
path = append(path, trace.Command.Name)
|
path = append(path, trace.Command.Name)
|
||||||
case trace.Flag != nil:
|
case trace.Flag != nil:
|
||||||
trace.Flag.Value.Apply(trace.Value)
|
value = trace.Flag.Value
|
||||||
case trace.Positional != nil:
|
case trace.Positional != nil:
|
||||||
path = append(path, "<"+trace.Positional.Name+">")
|
path = append(path, "<"+trace.Positional.Name+">")
|
||||||
trace.Positional.Apply(trace.Value)
|
value = trace.Positional
|
||||||
default:
|
default:
|
||||||
panic("unsupported path ?!")
|
panic("unsupported path ?!")
|
||||||
}
|
}
|
||||||
|
if value != nil {
|
||||||
|
value.Apply(c.getValue(value))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(path, " "), nil
|
return strings.Join(path, " "), nil
|
||||||
@@ -378,11 +404,11 @@ func (c *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err err
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
c.scan.Pop()
|
c.scan.Pop()
|
||||||
value, err := flag.Parse(c.scan)
|
err := flag.Parse(c.scan, c.getValue(flag.Value))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.Path = append(c.Path, &Path{Flag: flag, Value: value})
|
c.Path = append(c.Path, &Path{Flag: flag})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown flag --%s", token.Value)
|
return fmt.Errorf("unknown flag --%s", token.Value)
|
||||||
|
|||||||
+2
-2
@@ -352,8 +352,8 @@ func TestHooks(t *testing.T) {
|
|||||||
{"ArgAndFlag", "one two --three=three", values{true, "two", "three"}},
|
{"ArgAndFlag", "one two --three=three", values{true, "two", "three"}},
|
||||||
}
|
}
|
||||||
setOne := func(ctx *Context, path *Path) error { hooked.one = true; return nil }
|
setOne := func(ctx *Context, path *Path) error { hooked.one = true; return nil }
|
||||||
setTwo := func(ctx *Context, path *Path) error { hooked.two = path.Value.String(); return nil }
|
setTwo := func(ctx *Context, path *Path) error { hooked.two = ctx.Value(path).String(); return nil }
|
||||||
setThree := func(ctx *Context, path *Path) error { hooked.three = path.Value.String(); return nil }
|
setThree := func(ctx *Context, path *Path) error { hooked.three = ctx.Value(path).String(); return nil }
|
||||||
p := mustNew(t, &cli,
|
p := mustNew(t, &cli,
|
||||||
Hook(&cli.One, setOne),
|
Hook(&cli.One, setOne),
|
||||||
Hook(&cli.One.Two, setTwo),
|
Hook(&cli.One.Two, setTwo),
|
||||||
|
|||||||
@@ -178,13 +178,12 @@ func (v *Value) IsBool() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse tokens into value, parse, and validate, but do not write to the field.
|
// Parse tokens into value, parse, and validate, but do not write to the field.
|
||||||
func (v *Value) Parse(scan *Scanner) (reflect.Value, error) {
|
func (v *Value) Parse(scan *Scanner, target reflect.Value) error {
|
||||||
value := reflect.New(v.Target.Type()).Elem()
|
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target)
|
||||||
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, value)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
v.Set = true
|
v.Set = true
|
||||||
}
|
}
|
||||||
return value, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply value to field.
|
// Apply value to field.
|
||||||
@@ -199,12 +198,7 @@ func (v *Value) Apply(value reflect.Value) {
|
|||||||
func (v *Value) Reset() error {
|
func (v *Value) Reset() error {
|
||||||
v.Target.Set(reflect.Zero(v.Target.Type()))
|
v.Target.Set(reflect.Zero(v.Target.Type()))
|
||||||
if v.Default != "" {
|
if v.Default != "" {
|
||||||
value, err := v.Parse(Scan(v.Default))
|
return v.Parse(Scan(v.Default), v.Target)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
v.Apply(value)
|
|
||||||
v.Set = false
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -147,7 +147,7 @@ func TestResolvedValueTriggersHooks(t *testing.T) {
|
|||||||
_, err = p.Parse([]string{"--int=2"})
|
_, err = p.Parse([]string{"--int=2"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 2, cli.Int)
|
require.Equal(t, 2, cli.Int)
|
||||||
require.Equal(t, 2, hooked)
|
require.Equal(t, 1, hooked)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testUppercaseMapper struct{}
|
type testUppercaseMapper struct{}
|
||||||
|
|||||||
Reference in New Issue
Block a user