Add Visitor function for walking the model.

This commit is contained in:
Alec Thomas
2018-09-20 21:29:57 +10:00
parent 6fa83bdc0e
commit 6406edf15f
7 changed files with 143 additions and 95 deletions
+1 -1
View File
@@ -12,7 +12,7 @@ jobs:
command: | command: |
go get -v github.com/jstemmer/go-junit-report go get -v github.com/jstemmer/go-junit-report
go get -v -t -d ./... go get -v -t -d ./...
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10 curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10.2
mkdir ~/report mkdir ~/report
when: always when: always
- run: - run:
-7
View File
@@ -191,13 +191,6 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv
Format: tag.Format, Format: tag.Format,
} }
if value.Default != "" {
err := value.Parse(Scan(tag.Default), value.DefaultValue)
if err != nil {
fail("invalid default value %q for field type %s.%s (of type %s)", value.Default, v.Type(), ft.Name, ft.Type)
}
}
if tag.Arg { if tag.Arg {
node.Positional = append(node.Positional, value) node.Positional = append(node.Positional, value)
} else { } else {
+9 -29
View File
@@ -197,32 +197,12 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
// Recursively reset values to defaults (as specified in the grammar) or the zero value. // Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (c *Context) reset(node *Node) error { func (c *Context) reset(node *Node) error {
for _, flag := range node.Flags { return Visit(node, func(node Visitable, next Next) error {
err := flag.Value.Reset() if value, ok := node.(*Value); ok {
if err != nil { return next(value.Reset())
return err
} }
} return next(nil)
for _, pos := range node.Positional { })
err := pos.Reset()
if err != nil {
return err
}
}
for _, branch := range node.Children {
if branch.Argument != nil {
arg := branch.Argument
err := arg.Reset()
if err != nil {
return err
}
}
err := c.reset(branch)
if err != nil {
return err
}
}
return nil
} }
func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
@@ -562,13 +542,13 @@ func checkMissingChildren(node *Node) error {
return nil return nil
} }
if len(missing) == 1 {
return fmt.Errorf("expected %s", missing[0])
}
if len(missing) > 5 { if len(missing) > 5 {
missing = append(missing[:5], "...") missing = append(missing[:5], "...")
} }
return fmt.Errorf("expected one of %s", strings.Join(missing, ", ")) if len(missing) == 1 {
return fmt.Errorf("expected %s", missing[0])
}
return fmt.Errorf("expected %s", strings.Join(missing, " "))
} }
// If we're missing any positionals and they're required, return an error. // If we're missing any positionals and they're required, return an error.
+10 -6
View File
@@ -61,8 +61,7 @@ func TestHelp(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
require.True(t, exited) require.True(t, exited)
t.Log(w.String()) expected := `Usage: test-app --required <command>
require.Equal(t, `Usage: test-app --required <command>
A test app. A test app.
@@ -86,7 +85,10 @@ Commands:
Sub-sub-command. Sub-sub-command.
Run "test-app <command> --help" for more information on a command. Run "test-app <command> --help" for more information on a command.
`, w.String()) `
t.Log(w.String())
t.Log(expected)
require.Equal(t, expected, w.String())
}) })
t.Run("Selected", func(t *testing.T) { t.Run("Selected", func(t *testing.T) {
@@ -97,8 +99,7 @@ Run "test-app <command> --help" for more information on a command.
require.NoError(t, err) require.NoError(t, err)
}) })
require.True(t, exited) require.True(t, exited)
t.Log(w.String()) expected := `Usage: test-app two <three> --required --required-two --required-three
require.Equal(t, `Usage: test-app two <three> --required --required-two --required-three
Sub-sub-arg. Sub-sub-arg.
@@ -117,6 +118,9 @@ Flags:
--required-two --required-two
--required-three --required-three
`, w.String()) `
t.Log(expected)
t.Log(w.String())
require.Equal(t, expected, w.String())
}) })
} }
+47 -38
View File
@@ -108,29 +108,38 @@ func New(grammar interface{}, options ...Option) (*Kong, error) {
return k, nil return k, nil
} }
type varStack []Vars
func (v *varStack) head() Vars { return (*v)[len(*v)-1] }
func (v *varStack) pop() { *v = (*v)[:len(*v)-1] }
func (v *varStack) push(vars Vars) Vars {
if len(*v) != 0 {
vars = (*v)[len(*v)-1].CloneWith(vars)
}
*v = append(*v, vars)
return vars
}
// Interpolate variables into model. // Interpolate variables into model.
func (k *Kong) interpolate(node *Node) (err error) { func (k *Kong) interpolate(node *Node) (err error) {
vars := node.Vars() stack := varStack{}
node.Help, err = interpolate(node.Help, vars) return Visit(node, func(node Visitable, next Next) error {
if err != nil { switch node := node.(type) {
return fmt.Errorf("help for %s: %s", node.Path(), err) case *Node:
} vars := stack.push(node.Vars())
for _, flag := range node.Flags { node.Help, err = interpolate(node.Help, vars)
if err = k.interpolateValue(flag.Value, vars); err != nil { if err != nil {
return fmt.Errorf("help for %s: %s", node.Path(), err)
}
err = next(nil)
stack.pop()
return err return err
case *Value:
return next(k.interpolateValue(node, stack.head()))
} }
} return next(nil)
for _, pos := range node.Positional { })
if err = k.interpolateValue(pos, vars); err != nil {
return err
}
}
for _, child := range node.Children {
if err = k.interpolate(child); err != nil {
return err
}
}
return nil
} }
func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) { func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) {
@@ -244,27 +253,27 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
if node == nil { if node == nil {
return nil return nil
} }
bindings := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars)) return Visit(node, func(n Visitable, next Next) error {
for _, flag := range node.Flags { node, ok := n.(*Node)
if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { if !ok {
continue return next(nil)
} }
method := getMethod(flag.Target, name) binds := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars))
if !method.IsValid() { for _, flag := range node.Flags {
continue if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
}
path := &Path{Flag: flag}
if err := callMethod(name, flag.Target, method, binds.clone().add(path)); err != nil {
return next(err)
}
} }
path := &Path{Flag: flag} return next(nil)
if err := callMethod(name, flag.Target, method, bindings.clone().add(path)); err != nil { })
return err
}
}
for _, branch := range node.Children {
err := k.applyHookToDefaultFlags(ctx, branch, name)
if err != nil {
return err
}
}
return nil
} }
func formatMultilineMessage(w io.Writer, leaders []string, format string, args ...interface{}) { func formatMultilineMessage(w io.Writer, leaders []string, format string, args ...interface{}) {
+20 -14
View File
@@ -7,6 +7,11 @@ import (
"strings" "strings"
) )
// A Visitable component in the model.
type Visitable interface {
node()
}
// Application is the root of the Kong model. // Application is the root of the Kong model.
type Application struct { type Application struct {
*Node *Node
@@ -48,6 +53,8 @@ type Node struct {
Argument *Value // Populated when Type is ArgumentNode. Argument *Value // Populated when Type is ArgumentNode.
} }
func (*Node) node() {}
// Leaf returns true if this Node is a leaf node. // Leaf returns true if this Node is a leaf node.
func (n *Node) Leaf() bool { func (n *Node) Leaf() bool {
return len(n.Children) == 0 return len(n.Children) == 0
@@ -99,23 +106,20 @@ func (n *Node) AllFlags(hide bool) (out [][]*Flag) {
// //
// If "hidden" is true hidden leaves will be omitted. // If "hidden" is true hidden leaves will be omitted.
func (n *Node) Leaves(hide bool) (out []*Node) { func (n *Node) Leaves(hide bool) (out []*Node) {
var walk func(n *Node) _ = Visit(n, func(nd Visitable, next Next) error {
walk = func(n *Node) { if nd == n {
if hide && n.Hidden { return next(nil)
return
} }
if len(n.Children) == 0 && n.Type != ApplicationNode { if node, ok := nd.(*Node); ok {
out = append(out, n) if hide && node.Hidden {
} return next(nil)
for _, child := range n.Children { }
if child.Type == CommandNode || child.Type == ArgumentNode { if len(node.Children) == 0 && node.Type != ApplicationNode {
walk(child) out = append(out, node)
} }
} }
} return next(nil)
for _, child := range n.Children { })
walk(child)
}
return return
} }
@@ -289,6 +293,8 @@ func (v *Value) Reset() error {
return nil return nil
} }
func (*Value) node() {}
// A Positional represents a non-branching command-line positional argument. // A Positional represents a non-branching command-line positional argument.
type Positional = Value type Positional = Value
+56
View File
@@ -0,0 +1,56 @@
package kong
import (
"fmt"
)
// Next should be called by Visitor to proceed with the walk.
type Next func(error) error
// Visitor can be used to walk all nodes in the model.
type Visitor func(node Visitable, next Next) error
// Visit all nodes.
func Visit(node Visitable, visitor Visitor) error {
return visitor(node, func(err error) error {
if err != nil {
return err
}
switch node := node.(type) {
case *Application:
return visitNodeChildren(node.Node, visitor)
case *Node:
return visitNodeChildren(node, visitor)
case *Value:
case *Flag:
return Visit(node.Value, visitor)
default:
panic(fmt.Sprintf("unsupported node type %T", node))
}
return nil
})
}
func visitNodeChildren(node *Node, visitor Visitor) error {
if node.Argument != nil {
if err := Visit(node.Argument, visitor); err != nil {
return err
}
}
for _, flag := range node.Flags {
if err := Visit(flag, visitor); err != nil {
return err
}
}
for _, pos := range node.Positional {
if err := Visit(pos, visitor); err != nil {
return err
}
}
for _, child := range node.Children {
if err := Visit(child, visitor); err != nil {
return err
}
}
return nil
}