Add Visitor function for walking the model.

This commit is contained in:
Alec Thomas
2018-09-20 21:29:57 +10:00
parent 6fa83bdc0e
commit 6406edf15f
7 changed files with 143 additions and 95 deletions
+1 -1
View File
@@ -12,7 +12,7 @@ jobs:
command: |
go get -v github.com/jstemmer/go-junit-report
go get -v -t -d ./...
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.10.2
mkdir ~/report
when: always
- run:
-7
View File
@@ -191,13 +191,6 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv
Format: tag.Format,
}
if value.Default != "" {
err := value.Parse(Scan(tag.Default), value.DefaultValue)
if err != nil {
fail("invalid default value %q for field type %s.%s (of type %s)", value.Default, v.Type(), ft.Name, ft.Type)
}
}
if tag.Arg {
node.Positional = append(node.Positional, value)
} else {
+9 -29
View File
@@ -197,32 +197,12 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (c *Context) reset(node *Node) error {
for _, flag := range node.Flags {
err := flag.Value.Reset()
if err != nil {
return err
return Visit(node, func(node Visitable, next Next) error {
if value, ok := node.(*Value); ok {
return next(value.Reset())
}
}
for _, pos := range node.Positional {
err := pos.Reset()
if err != nil {
return err
}
}
for _, branch := range node.Children {
if branch.Argument != nil {
arg := branch.Argument
err := arg.Reset()
if err != nil {
return err
}
}
err := c.reset(branch)
if err != nil {
return err
}
}
return nil
return next(nil)
})
}
func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
@@ -562,13 +542,13 @@ func checkMissingChildren(node *Node) error {
return nil
}
if len(missing) == 1 {
return fmt.Errorf("expected %s", missing[0])
}
if len(missing) > 5 {
missing = append(missing[:5], "...")
}
return fmt.Errorf("expected one of %s", strings.Join(missing, ", "))
if len(missing) == 1 {
return fmt.Errorf("expected %s", missing[0])
}
return fmt.Errorf("expected %s", strings.Join(missing, " "))
}
// If we're missing any positionals and they're required, return an error.
+10 -6
View File
@@ -61,8 +61,7 @@ func TestHelp(t *testing.T) {
require.NoError(t, err)
})
require.True(t, exited)
t.Log(w.String())
require.Equal(t, `Usage: test-app --required <command>
expected := `Usage: test-app --required <command>
A test app.
@@ -86,7 +85,10 @@ Commands:
Sub-sub-command.
Run "test-app <command> --help" for more information on a command.
`, w.String())
`
t.Log(w.String())
t.Log(expected)
require.Equal(t, expected, w.String())
})
t.Run("Selected", func(t *testing.T) {
@@ -97,8 +99,7 @@ Run "test-app <command> --help" for more information on a command.
require.NoError(t, err)
})
require.True(t, exited)
t.Log(w.String())
require.Equal(t, `Usage: test-app two <three> --required --required-two --required-three
expected := `Usage: test-app two <three> --required --required-two --required-three
Sub-sub-arg.
@@ -117,6 +118,9 @@ Flags:
--required-two
--required-three
`, w.String())
`
t.Log(expected)
t.Log(w.String())
require.Equal(t, expected, w.String())
})
}
+47 -38
View File
@@ -108,29 +108,38 @@ func New(grammar interface{}, options ...Option) (*Kong, error) {
return k, nil
}
type varStack []Vars
func (v *varStack) head() Vars { return (*v)[len(*v)-1] }
func (v *varStack) pop() { *v = (*v)[:len(*v)-1] }
func (v *varStack) push(vars Vars) Vars {
if len(*v) != 0 {
vars = (*v)[len(*v)-1].CloneWith(vars)
}
*v = append(*v, vars)
return vars
}
// Interpolate variables into model.
func (k *Kong) interpolate(node *Node) (err error) {
vars := node.Vars()
node.Help, err = interpolate(node.Help, vars)
if err != nil {
return fmt.Errorf("help for %s: %s", node.Path(), err)
}
for _, flag := range node.Flags {
if err = k.interpolateValue(flag.Value, vars); err != nil {
stack := varStack{}
return Visit(node, func(node Visitable, next Next) error {
switch node := node.(type) {
case *Node:
vars := stack.push(node.Vars())
node.Help, err = interpolate(node.Help, vars)
if err != nil {
return fmt.Errorf("help for %s: %s", node.Path(), err)
}
err = next(nil)
stack.pop()
return err
case *Value:
return next(k.interpolateValue(node, stack.head()))
}
}
for _, pos := range node.Positional {
if err = k.interpolateValue(pos, vars); err != nil {
return err
}
}
for _, child := range node.Children {
if err = k.interpolate(child); err != nil {
return err
}
}
return nil
return next(nil)
})
}
func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) {
@@ -244,27 +253,27 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
if node == nil {
return nil
}
bindings := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars))
for _, flag := range node.Flags {
if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
return Visit(node, func(n Visitable, next Next) error {
node, ok := n.(*Node)
if !ok {
return next(nil)
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
binds := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars))
for _, flag := range node.Flags {
if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
}
path := &Path{Flag: flag}
if err := callMethod(name, flag.Target, method, binds.clone().add(path)); err != nil {
return next(err)
}
}
path := &Path{Flag: flag}
if err := callMethod(name, flag.Target, method, bindings.clone().add(path)); err != nil {
return err
}
}
for _, branch := range node.Children {
err := k.applyHookToDefaultFlags(ctx, branch, name)
if err != nil {
return err
}
}
return nil
return next(nil)
})
}
func formatMultilineMessage(w io.Writer, leaders []string, format string, args ...interface{}) {
+20 -14
View File
@@ -7,6 +7,11 @@ import (
"strings"
)
// A Visitable component in the model.
type Visitable interface {
node()
}
// Application is the root of the Kong model.
type Application struct {
*Node
@@ -48,6 +53,8 @@ type Node struct {
Argument *Value // Populated when Type is ArgumentNode.
}
func (*Node) node() {}
// Leaf returns true if this Node is a leaf node.
func (n *Node) Leaf() bool {
return len(n.Children) == 0
@@ -99,23 +106,20 @@ func (n *Node) AllFlags(hide bool) (out [][]*Flag) {
//
// If "hidden" is true hidden leaves will be omitted.
func (n *Node) Leaves(hide bool) (out []*Node) {
var walk func(n *Node)
walk = func(n *Node) {
if hide && n.Hidden {
return
_ = Visit(n, func(nd Visitable, next Next) error {
if nd == n {
return next(nil)
}
if len(n.Children) == 0 && n.Type != ApplicationNode {
out = append(out, n)
}
for _, child := range n.Children {
if child.Type == CommandNode || child.Type == ArgumentNode {
walk(child)
if node, ok := nd.(*Node); ok {
if hide && node.Hidden {
return next(nil)
}
if len(node.Children) == 0 && node.Type != ApplicationNode {
out = append(out, node)
}
}
}
for _, child := range n.Children {
walk(child)
}
return next(nil)
})
return
}
@@ -289,6 +293,8 @@ func (v *Value) Reset() error {
return nil
}
func (*Value) node() {}
// A Positional represents a non-branching command-line positional argument.
type Positional = Value
+56
View File
@@ -0,0 +1,56 @@
package kong
import (
"fmt"
)
// Next should be called by Visitor to proceed with the walk.
type Next func(error) error
// Visitor can be used to walk all nodes in the model.
type Visitor func(node Visitable, next Next) error
// Visit all nodes.
func Visit(node Visitable, visitor Visitor) error {
return visitor(node, func(err error) error {
if err != nil {
return err
}
switch node := node.(type) {
case *Application:
return visitNodeChildren(node.Node, visitor)
case *Node:
return visitNodeChildren(node, visitor)
case *Value:
case *Flag:
return Visit(node.Value, visitor)
default:
panic(fmt.Sprintf("unsupported node type %T", node))
}
return nil
})
}
func visitNodeChildren(node *Node, visitor Visitor) error {
if node.Argument != nil {
if err := Visit(node.Argument, visitor); err != nil {
return err
}
}
for _, flag := range node.Flags {
if err := Visit(flag, visitor); err != nil {
return err
}
}
for _, pos := range node.Positional {
if err := Visit(pos, visitor); err != nil {
return err
}
}
for _, child := range node.Children {
if err := Visit(child, visitor); err != nil {
return err
}
}
return nil
}