Tracing parser (#11)

* Add tracing to the parser.

* Synthesize a --help flag.

* Parsing now occurs in multiple phases.

1. Reset target.
2. Parse command-line into a "trace" (no values are written to target).
3. Apply traced, parsed values to the target fields.

This is another step in facilitating context-sensitive help and
completion.

* Detect duplicate flags.
This commit is contained in:
Alec Thomas
2018-05-22 00:07:43 -04:00
committed by Gerald Kaszuba
parent 3eb5e285ed
commit ab5cf7e6ef
8 changed files with 247 additions and 77 deletions
+28 -12
View File
@@ -22,21 +22,32 @@ func build(ast interface{}) (app *Application, err error) {
return nil, fmt.Errorf("expected a pointer to a struct but got %T", ast) return nil, fmt.Errorf("expected a pointer to a struct but got %T", ast)
} }
node, err := buildNode(iv, true) app = &Application{
if err != nil { // Synthesize a --help flag.
return node, err HelpFlag: &Flag{
Value: Value{
Name: "help",
Help: "Show context-sensitive help.",
Flag: true,
Value: reflect.New(reflect.TypeOf(false)).Elem(),
Decoder: kindDecoders[reflect.Bool],
}},
} }
node := buildNode(iv, map[string]bool{"help": true}, true)
if len(node.Positional) > 0 && len(node.Children) > 0 { if len(node.Positional) > 0 && len(node.Children) > 0 {
return nil, fmt.Errorf("can't mix positional arguments and branching arguments on %T", ast) return nil, fmt.Errorf("can't mix positional arguments and branching arguments on %T", ast)
} }
return node, nil // Prepend --help flag.
node.Flags = append([]*Flag{app.HelpFlag}, node.Flags...)
app.Node = *node
return app, nil
} }
func dashedString(s string) string { func dashedString(s string) string {
return strings.Join(camelCase(s), "-") return strings.Join(camelCase(s), "-")
} }
func buildNode(v reflect.Value, cmd bool) (*Node, error) { func buildNode(v reflect.Value, seenFlags map[string]bool, cmd bool) *Node {
node := &Node{} node := &Node{}
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
ft := v.Type().Field(i) ft := v.Type().Field(i)
@@ -52,7 +63,7 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
tag, err := parseTag(fv, ft.Tag.Get("kong")) tag, err := parseTag(fv, ft.Tag.Get("kong"))
if err != nil { if err != nil {
return nil, err fail("%s", err)
} }
decoder := DecoderForField(tag.Type, ft) decoder := DecoderForField(tag.Type, ft)
@@ -66,10 +77,7 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
// Nested structs are either commands or args. // Nested structs are either commands or args.
if ft.Type.Kind() == reflect.Struct && (cmd || tag.Arg) { if ft.Type.Kind() == reflect.Struct && (cmd || tag.Arg) {
child, err := buildNode(fv, false) child := buildNode(fv, seenFlags, false)
if err != nil {
return nil, err
}
child.Help = tag.Help child.Help = tag.Help
// A branching argument. This is a bit hairy, as we let buildNode() do the parsing, then check that // A branching argument. This is a bit hairy, as we let buildNode() do the parsing, then check that
@@ -115,7 +123,6 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
Default: tag.Default, Default: tag.Default,
Decoder: decoder, Decoder: decoder,
Value: fv, Value: fv,
Field: ft,
// Flags are optional by default, and args are required by default. // Flags are optional by default, and args are required by default.
Required: (flag && tag.Required) || (tag.Arg && !tag.Optional), Required: (flag && tag.Required) || (tag.Arg && !tag.Optional),
@@ -124,6 +131,10 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
if tag.Arg { if tag.Arg {
node.Positional = append(node.Positional, &value) node.Positional = append(node.Positional, &value)
} else { } else {
if seenFlags[value.Name] {
fail("duplicate flag --%s", value.Name)
}
seenFlags[value.Name] = true
node.Flags = append(node.Flags, &Flag{ node.Flags = append(node.Flags, &Flag{
Value: value, Value: value,
Short: tag.Short, Short: tag.Short,
@@ -134,6 +145,11 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
} }
} }
// "Unsee" flags.
for _, flag := range node.Flags {
delete(seenFlags, flag.Name)
}
// Scan through argument positionals to ensure optional is never before a required. // Scan through argument positionals to ensure optional is never before a required.
last := true last := true
for _, p := range node.Positional { for _, p := range node.Positional {
@@ -144,5 +160,5 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) {
last = p.Required last = p.Required
} }
return node, nil return node
} }
+136 -32
View File
@@ -2,28 +2,103 @@ package kong
import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"
) )
type ParseContext struct { // ParseTrace records the nodes and parsed values from the current command-line.
Scan *Scanner type ParseTrace struct {
Command []string // One of these will be non-nil.
Flags []*Flag Positional *Value
Flag *Flag
Argument *Argument
Command *Command
// Parsed value for non-commands.
Value reflect.Value
} }
func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo type ParseContext struct {
positional := 0 Trace []*ParseTrace // A trace through parsed nodes.
p.Flags = append(p.Flags, node.Flags...)
for token := p.Scan.Pop(); token.Type != EOLToken; token = p.Scan.Pop() { command []string // Full command path.
flags []*Flag // Accumulated available flags.
node *Node // Current node being parsed.
args []string
app *Application
scan *Scanner
}
// Trace parses the command-line, validating and collecting matching grammar nodes.
func Trace(args []string, app *Application) (*ParseContext, error) {
p := &ParseContext{
app: app,
args: args,
}
err := p.reset(&p.app.Node)
if err != nil {
return nil, err
}
return p, p.trace(&p.app.Node)
}
// FlagValue returns the set value of a flag, if it was encountered and exists.
func (p *ParseContext) FlagValue(flag *Flag) reflect.Value {
for _, trace := range p.Trace {
if trace.Flag == flag {
return trace.Value
}
}
return reflect.Value{}
}
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (p *ParseContext) reset(node *Node) error {
p.scan = Scan(p.args...)
for _, flag := range node.Flags {
err := flag.Value.Reset()
if err != nil {
return err
}
}
for _, pos := range node.Positional {
err := pos.Reset()
if err != nil {
return err
}
}
for _, branch := range node.Children {
if branch.Argument != nil {
arg := branch.Argument.Argument
err := arg.Reset()
if err != nil {
return err
}
p.reset(&branch.Argument.Node)
} else {
p.reset(branch.Command)
}
}
return nil
}
func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
positional := 0
p.node = node
p.flags = append(p.flags, node.Flags...)
for !p.scan.Peek().IsEOL() {
token := p.scan.Peek()
switch token.Type { switch token.Type {
case UntypedToken: case UntypedToken:
switch { switch {
// -- indicates end of parsing. All remaining arguments are treated as positional arguments only. // -- indicates end of parsing. All remaining arguments are treated as positional arguments only.
case token.Value == "--": case token.Value == "--":
p.scan.Pop()
args := []string{} args := []string{}
for { for {
token = p.Scan.Pop() token = p.scan.Pop()
if token.Type == EOLToken { if token.Type == EOLToken {
break break
} }
@@ -31,42 +106,46 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo
} }
// Note: tokens must be pushed in reverse order. // Note: tokens must be pushed in reverse order.
for i := range args { for i := range args {
p.Scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) p.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken)
} }
// Long flag. // Long flag.
case strings.HasPrefix(token.Value, "--"): case strings.HasPrefix(token.Value, "--"):
p.scan.Pop()
// Parse it and push the tokens. // Parse it and push the tokens.
parts := strings.SplitN(token.Value[2:], "=", 2) parts := strings.SplitN(token.Value[2:], "=", 2)
if len(parts) > 1 { if len(parts) > 1 {
p.Scan.PushTyped(parts[1], FlagValueToken) p.scan.PushTyped(parts[1], FlagValueToken)
} }
p.Scan.PushTyped(parts[0], FlagToken) p.scan.PushTyped(parts[0], FlagToken)
// Short flag. // Short flag.
case strings.HasPrefix(token.Value, "-"): case strings.HasPrefix(token.Value, "-"):
p.scan.Pop()
// Note: tokens must be pushed in reverse order. // Note: tokens must be pushed in reverse order.
p.Scan.PushTyped(token.Value[2:], ShortFlagTailToken) p.scan.PushTyped(token.Value[2:], ShortFlagTailToken)
p.Scan.PushTyped(token.Value[1:2], ShortFlagToken) p.scan.PushTyped(token.Value[1:2], ShortFlagToken)
default: default:
p.Scan.PushTyped(token.Value, PositionalArgumentToken) p.scan.Pop()
p.scan.PushTyped(token.Value, PositionalArgumentToken)
} }
case ShortFlagTailToken: case ShortFlagTailToken:
p.scan.Pop()
// Note: tokens must be pushed in reverse order. // Note: tokens must be pushed in reverse order.
p.Scan.PushTyped(token.Value[1:], ShortFlagTailToken) p.scan.PushTyped(token.Value[1:], ShortFlagTailToken)
p.Scan.PushTyped(token.Value[0:1], ShortFlagToken) p.scan.PushTyped(token.Value[0:1], ShortFlagToken)
case FlagToken: case FlagToken:
if err := matchFlags(p.Flags, token, p.Scan, func(f *Flag) bool { if err := p.matchFlags(func(f *Flag) bool {
return f.Name == token.Value return f.Name == token.Value
}); err != nil { }); err != nil {
return err return err
} }
case ShortFlagToken: case ShortFlagToken:
if err := matchFlags(p.Flags, token, p.Scan, func(f *Flag) bool { if err := p.matchFlags(func(f *Flag) bool {
return string(f.Name) == token.Value return string(f.Name) == token.Value
}); err != nil { }); err != nil {
return err return err
@@ -76,15 +155,15 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo
return fmt.Errorf("unexpected flag argument %q", token.Value) return fmt.Errorf("unexpected flag argument %q", token.Value)
case PositionalArgumentToken: case PositionalArgumentToken:
p.Scan.PushToken(token)
// Ensure we've consumed all positional arguments. // Ensure we've consumed all positional arguments.
if positional < len(node.Positional) { if positional < len(node.Positional) {
arg := node.Positional[positional] arg := node.Positional[positional]
err := arg.Decode(p.Scan) value, err := arg.Parse(p.scan)
if err != nil { if err != nil {
return err return err
} }
p.Command = append(p.Command, "<"+arg.Name+">") p.command = append(p.command, "<"+arg.Name+">")
p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value})
positional++ positional++
break break
} }
@@ -94,16 +173,18 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo
switch { switch {
case branch.Command != nil: case branch.Command != nil:
if branch.Command.Name == token.Value { if branch.Command.Name == token.Value {
p.Scan.Pop() p.scan.Pop()
p.Command = append(p.Command, branch.Command.Name) p.command = append(p.command, branch.Command.Name)
return p.applyNode(branch.Command) p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command})
return p.trace(branch.Command)
} }
case branch.Argument != nil: case branch.Argument != nil:
arg := branch.Argument.Argument arg := branch.Argument.Argument
if err := arg.Decode(p.Scan); err == nil { if value, err := arg.Parse(p.scan); err == nil {
p.Command = append(p.Command, "<"+arg.Name+">") p.command = append(p.command, "<"+arg.Name+">")
return p.applyNode(&branch.Argument.Node) p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value})
return p.trace(&branch.Argument.Node)
} }
} }
} }
@@ -122,13 +203,33 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo
return err return err
} }
if err := checkMissingFlags(node.Children, p.Flags); err != nil { if err := checkMissingFlags(node.Children, p.flags); err != nil {
return err return err
} }
return nil return nil
} }
// Apply traced context to the target grammar.
func (p *ParseContext) Apply() (string, error) {
path := []string{}
for _, trace := range p.Trace {
switch {
case trace.Argument != nil:
path = append(path, "<"+trace.Argument.Name+">")
trace.Argument.Argument.Apply(trace.Value)
case trace.Command != nil:
path = append(path, trace.Command.Name)
case trace.Flag != nil:
trace.Flag.Value.Apply(trace.Value)
case trace.Positional != nil:
path = append(path, "<"+trace.Positional.Name+">")
trace.Positional.Apply(trace.Value)
}
}
return strings.Join(path, " "), nil
}
func checkMissingFlags(children []*Branch, flags []*Flag) error { func checkMissingFlags(children []*Branch, flags []*Flag) error {
// Only check required missing fields at the last child. // Only check required missing fields at the last child.
if len(children) > 0 { if len(children) > 0 {
@@ -186,7 +287,8 @@ func checkMissingPositionals(positional int, values []*Value) error {
return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " "))
} }
func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) {
token := p.scan.Peek()
defer func() { defer func() {
msg := recover() msg := recover()
if test, ok := msg.(Error); ok { if test, ok := msg.(Error); ok {
@@ -195,13 +297,15 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag)
panic(msg) panic(msg)
} }
}() }()
for _, flag := range flags { for _, flag := range p.flags {
// Found a matching flag. // Found a matching flag.
if flag.Name == token.Value { if flag.Name == token.Value {
err := flag.Decode(scan) p.scan.Pop()
value, err := flag.Parse(p.scan)
if err != nil { if err != nil {
return err return err
} }
p.Trace = append(p.Trace, &ParseTrace{Flag: flag, Value: value})
return nil return nil
} }
} }
+1 -1
View File
@@ -222,7 +222,7 @@ func floatDecoder(bits int) DecoderFunc {
func sliceDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { func sliceDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
el := target.Type().Elem() el := target.Type().Elem()
sep, ok := ctx.Value.Field.Tag.Lookup("sep") sep, ok := ctx.Value.Tag.Lookup("sep")
if !ok { if !ok {
sep = "," sep = ","
} }
+3
View File
@@ -9,6 +9,9 @@ const defaultHelp = `{{- with .Application -}}
usage: {{.Name}} usage: {{.Name}}
{{.Help}} {{.Help}}
{{range .Flags}}
--{{.Name}}
{{end}}
{{- end -}} {{- end -}}
` `
+6 -24
View File
@@ -5,7 +5,6 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"text/template" "text/template"
) )
@@ -65,31 +64,14 @@ func (k *Kong) Parse(args []string) (command string, err error) {
panic(msg) panic(msg)
} }
}() }()
k.reset(k.Model) ctx, err := Trace(args, k.Model)
ctx := &ParseContext{ if err != nil {
Scan: Scan(args...), return "", err
}
err = ctx.applyNode(k.Model)
return strings.Join(ctx.Command, " "), err
}
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (k *Kong) reset(node *Node) {
for _, flag := range node.Flags {
flag.Value.Reset()
}
for _, pos := range node.Positional {
pos.Reset()
}
for _, branch := range node.Children {
if branch.Argument != nil {
arg := branch.Argument.Argument
arg.Reset()
k.reset(&branch.Argument.Node)
} else {
k.reset(branch.Command)
} }
if value := ctx.FlagValue(k.Model.HelpFlag); value.IsValid() && value.Bool() {
return "", nil
} }
return ctx.Apply()
} }
func (k *Kong) Errorf(format string, args ...interface{}) { func (k *Kong) Errorf(format string, args ...interface{}) {
+46 -1
View File
@@ -8,7 +8,9 @@ import (
func mustNew(t *testing.T, cli interface{}) *Kong { func mustNew(t *testing.T, cli interface{}) *Kong {
t.Helper() t.Helper()
parser, err := New(cli) parser, err := New(cli, ExitFunction(func(int) {
t.Fatalf("unexpected exit()")
}))
require.NoError(t, err) require.NoError(t, err)
return parser return parser
} }
@@ -307,3 +309,46 @@ func TestEscapedQuote(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "i don't know", cli.DoYouKnow) require.Equal(t, "i don't know", cli.DoYouKnow)
} }
func TestInvalidDefaultErrors(t *testing.T) {
var cli struct {
Flag int `kong:"default='foo'"`
}
p := mustNew(t, &cli)
_, err := p.Parse(nil)
require.Error(t, err)
}
func TestHelp(t *testing.T) {
var cli struct {
Flag string
}
p := mustNew(t, &cli)
_, err := p.Parse([]string{"--flag=hello", "--help"})
require.NoError(t, err)
require.NotEqual(t, "hello", cli.Flag)
}
func TestDuplicateFlag(t *testing.T) {
var cli struct {
Flag bool
Cmd struct {
Flag bool
}
}
_, err := New(&cli)
require.Error(t, err)
}
func TestDuplicateFlagOnPeerCommandIsOkay(t *testing.T) {
var cli struct {
Cmd1 struct {
Flag bool
}
Cmd2 struct {
Flag bool
}
}
_, err := New(&cli)
require.NoError(t, err)
}
+23 -7
View File
@@ -2,7 +2,10 @@ package kong
import "reflect" import "reflect"
type Application = Node type Application struct {
Node
HelpFlag *Flag
}
// A Branch is a command or positional argument that results in a branch in the command tree. // A Branch is a command or positional argument that results in a branch in the command tree.
type Branch struct { type Branch struct {
@@ -27,27 +30,40 @@ type Value struct {
Help string Help string
Default string Default string
Decoder Decoder Decoder Decoder
Field reflect.StructField Tag reflect.StructTag
Value reflect.Value Value reflect.Value
Required bool Required bool
Set bool // Used with Required to test if a value has been given. Set bool // Used with Required to test if a value has been given.
Format string // Formatting directive, if applicable. Format string // Formatting directive, if applicable.
} }
func (v *Value) Decode(scan *Scanner) error { // Parse tokens into value, parse, and validate, but do not write to the field.
err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value) func (v *Value) Parse(scan *Scanner) (reflect.Value, error) {
value := reflect.New(v.Value.Type()).Elem()
err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, value)
if err == nil { if err == nil {
v.Set = true v.Set = true
} }
return err return value, err
} }
func (v *Value) Reset() { // Apply value to field.
func (v *Value) Apply(value reflect.Value) {
v.Value.Set(value)
v.Set = true
}
func (v *Value) Reset() error {
v.Value.Set(reflect.Zero(v.Value.Type())) v.Value.Set(reflect.Zero(v.Value.Type()))
if v.Default != "" { if v.Default != "" {
v.Decode(Scan(v.Default)) value, err := v.Parse(Scan(v.Default))
if err != nil {
return err
}
v.Apply(value)
v.Set = false v.Set = false
} }
return nil
} }
type Positional = Value type Positional = Value
+4
View File
@@ -40,6 +40,10 @@ func (t Token) String() string {
} }
} }
func (t Token) IsEOL() bool {
return t.Type == EOLToken
}
func (t Token) IsAny(types ...TokenType) bool { func (t Token) IsAny(types ...TokenType) bool {
for _, typ := range types { for _, typ := range types {
if t.Type == typ { if t.Type == typ {