diff --git a/.circleci/config.yml b/.circleci/config.yml index 56fa38e..31d2bd3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,7 +12,7 @@ jobs: command: | go get -v github.com/jstemmer/go-junit-report go get -v -t -d ./... - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10 + curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10.2 mkdir ~/report when: always - run: diff --git a/build.go b/build.go index 2a6a21d..49cc767 100644 --- a/build.go +++ b/build.go @@ -191,13 +191,6 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv Format: tag.Format, } - if value.Default != "" { - err := value.Parse(Scan(tag.Default), value.DefaultValue) - if err != nil { - fail("invalid default value %q for field type %s.%s (of type %s)", value.Default, v.Type(), ft.Name, ft.Type) - } - } - if tag.Arg { node.Positional = append(node.Positional, value) } else { diff --git a/context.go b/context.go index 8175eef..2144dfb 100644 --- a/context.go +++ b/context.go @@ -197,32 +197,12 @@ func (c *Context) FlagValue(flag *Flag) interface{} { // Recursively reset values to defaults (as specified in the grammar) or the zero value. func (c *Context) reset(node *Node) error { - for _, flag := range node.Flags { - err := flag.Value.Reset() - if err != nil { - return err + return Visit(node, func(node Visitable, next Next) error { + if value, ok := node.(*Value); ok { + return next(value.Reset()) } - } - for _, pos := range node.Positional { - err := pos.Reset() - if err != nil { - return err - } - } - for _, branch := range node.Children { - if branch.Argument != nil { - arg := branch.Argument - err := arg.Reset() - if err != nil { - return err - } - } - err := c.reset(branch) - if err != nil { - return err - } - } - return nil + return next(nil) + }) } func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo @@ -562,13 +542,13 @@ func checkMissingChildren(node *Node) error { return nil } - if len(missing) == 1 { - return fmt.Errorf("expected %s", missing[0]) - } if len(missing) > 5 { missing = append(missing[:5], "...") } - return fmt.Errorf("expected one of %s", strings.Join(missing, ", ")) + if len(missing) == 1 { + return fmt.Errorf("expected %s", missing[0]) + } + return fmt.Errorf("expected %s", strings.Join(missing, " ")) } // If we're missing any positionals and they're required, return an error. diff --git a/help_test.go b/help_test.go index 54caae9..d8af552 100644 --- a/help_test.go +++ b/help_test.go @@ -61,8 +61,7 @@ func TestHelp(t *testing.T) { require.NoError(t, err) }) require.True(t, exited) - t.Log(w.String()) - require.Equal(t, `Usage: test-app --required + expected := `Usage: test-app --required A test app. @@ -86,7 +85,10 @@ Commands: Sub-sub-command. Run "test-app --help" for more information on a command. -`, w.String()) +` + t.Log(w.String()) + t.Log(expected) + require.Equal(t, expected, w.String()) }) t.Run("Selected", func(t *testing.T) { @@ -97,8 +99,7 @@ Run "test-app --help" for more information on a command. require.NoError(t, err) }) require.True(t, exited) - t.Log(w.String()) - require.Equal(t, `Usage: test-app two --required --required-two --required-three + expected := `Usage: test-app two --required --required-two --required-three Sub-sub-arg. @@ -117,6 +118,9 @@ Flags: --required-two --required-three -`, w.String()) +` + t.Log(expected) + t.Log(w.String()) + require.Equal(t, expected, w.String()) }) } diff --git a/kong.go b/kong.go index 2de2b12..53c345d 100644 --- a/kong.go +++ b/kong.go @@ -108,29 +108,38 @@ func New(grammar interface{}, options ...Option) (*Kong, error) { return k, nil } +type varStack []Vars + +func (v *varStack) head() Vars { return (*v)[len(*v)-1] } +func (v *varStack) pop() { *v = (*v)[:len(*v)-1] } +func (v *varStack) push(vars Vars) Vars { + if len(*v) != 0 { + vars = (*v)[len(*v)-1].CloneWith(vars) + } + *v = append(*v, vars) + return vars +} + // Interpolate variables into model. func (k *Kong) interpolate(node *Node) (err error) { - vars := node.Vars() - node.Help, err = interpolate(node.Help, vars) - if err != nil { - return fmt.Errorf("help for %s: %s", node.Path(), err) - } - for _, flag := range node.Flags { - if err = k.interpolateValue(flag.Value, vars); err != nil { + stack := varStack{} + return Visit(node, func(node Visitable, next Next) error { + switch node := node.(type) { + case *Node: + vars := stack.push(node.Vars()) + node.Help, err = interpolate(node.Help, vars) + if err != nil { + return fmt.Errorf("help for %s: %s", node.Path(), err) + } + err = next(nil) + stack.pop() return err + + case *Value: + return next(k.interpolateValue(node, stack.head())) } - } - for _, pos := range node.Positional { - if err = k.interpolateValue(pos, vars); err != nil { - return err - } - } - for _, child := range node.Children { - if err = k.interpolate(child); err != nil { - return err - } - } - return nil + return next(nil) + }) } func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) { @@ -244,27 +253,27 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er if node == nil { return nil } - bindings := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars)) - for _, flag := range node.Flags { - if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { - continue + return Visit(node, func(n Visitable, next Next) error { + node, ok := n.(*Node) + if !ok { + return next(nil) } - method := getMethod(flag.Target, name) - if !method.IsValid() { - continue + binds := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars)) + for _, flag := range node.Flags { + if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { + continue + } + method := getMethod(flag.Target, name) + if !method.IsValid() { + continue + } + path := &Path{Flag: flag} + if err := callMethod(name, flag.Target, method, binds.clone().add(path)); err != nil { + return next(err) + } } - path := &Path{Flag: flag} - if err := callMethod(name, flag.Target, method, bindings.clone().add(path)); err != nil { - return err - } - } - for _, branch := range node.Children { - err := k.applyHookToDefaultFlags(ctx, branch, name) - if err != nil { - return err - } - } - return nil + return next(nil) + }) } func formatMultilineMessage(w io.Writer, leaders []string, format string, args ...interface{}) { diff --git a/model.go b/model.go index f0180d8..50c1df9 100644 --- a/model.go +++ b/model.go @@ -7,6 +7,11 @@ import ( "strings" ) +// A Visitable component in the model. +type Visitable interface { + node() +} + // Application is the root of the Kong model. type Application struct { *Node @@ -48,6 +53,8 @@ type Node struct { Argument *Value // Populated when Type is ArgumentNode. } +func (*Node) node() {} + // Leaf returns true if this Node is a leaf node. func (n *Node) Leaf() bool { return len(n.Children) == 0 @@ -99,23 +106,20 @@ func (n *Node) AllFlags(hide bool) (out [][]*Flag) { // // If "hidden" is true hidden leaves will be omitted. func (n *Node) Leaves(hide bool) (out []*Node) { - var walk func(n *Node) - walk = func(n *Node) { - if hide && n.Hidden { - return + _ = Visit(n, func(nd Visitable, next Next) error { + if nd == n { + return next(nil) } - if len(n.Children) == 0 && n.Type != ApplicationNode { - out = append(out, n) - } - for _, child := range n.Children { - if child.Type == CommandNode || child.Type == ArgumentNode { - walk(child) + if node, ok := nd.(*Node); ok { + if hide && node.Hidden { + return next(nil) + } + if len(node.Children) == 0 && node.Type != ApplicationNode { + out = append(out, node) } } - } - for _, child := range n.Children { - walk(child) - } + return next(nil) + }) return } @@ -289,6 +293,8 @@ func (v *Value) Reset() error { return nil } +func (*Value) node() {} + // A Positional represents a non-branching command-line positional argument. type Positional = Value diff --git a/visit.go b/visit.go new file mode 100644 index 0000000..d00bc27 --- /dev/null +++ b/visit.go @@ -0,0 +1,56 @@ +package kong + +import ( + "fmt" +) + +// Next should be called by Visitor to proceed with the walk. +type Next func(error) error + +// Visitor can be used to walk all nodes in the model. +type Visitor func(node Visitable, next Next) error + +// Visit all nodes. +func Visit(node Visitable, visitor Visitor) error { + return visitor(node, func(err error) error { + if err != nil { + return err + } + switch node := node.(type) { + case *Application: + return visitNodeChildren(node.Node, visitor) + case *Node: + return visitNodeChildren(node, visitor) + case *Value: + case *Flag: + return Visit(node.Value, visitor) + default: + panic(fmt.Sprintf("unsupported node type %T", node)) + } + return nil + }) +} + +func visitNodeChildren(node *Node, visitor Visitor) error { + if node.Argument != nil { + if err := Visit(node.Argument, visitor); err != nil { + return err + } + } + for _, flag := range node.Flags { + if err := Visit(flag, visitor); err != nil { + return err + } + } + for _, pos := range node.Positional { + if err := Visit(pos, visitor); err != nil { + return err + } + } + for _, child := range node.Children { + if err := Visit(child, visitor); err != nil { + return err + } + } + return nil +}