diff --git a/build.go b/build.go index 93c6520..2e70d9d 100644 --- a/build.go +++ b/build.go @@ -22,21 +22,32 @@ func build(ast interface{}) (app *Application, err error) { return nil, fmt.Errorf("expected a pointer to a struct but got %T", ast) } - node, err := buildNode(iv, true) - if err != nil { - return node, err + app = &Application{ + // Synthesize a --help flag. + HelpFlag: &Flag{ + Value: Value{ + Name: "help", + Help: "Show context-sensitive help.", + Flag: true, + Value: reflect.New(reflect.TypeOf(false)).Elem(), + Decoder: kindDecoders[reflect.Bool], + }}, } + node := buildNode(iv, map[string]bool{"help": true}, true) if len(node.Positional) > 0 && len(node.Children) > 0 { return nil, fmt.Errorf("can't mix positional arguments and branching arguments on %T", ast) } - return node, nil + // Prepend --help flag. + node.Flags = append([]*Flag{app.HelpFlag}, node.Flags...) + app.Node = *node + return app, nil } func dashedString(s string) string { return strings.Join(camelCase(s), "-") } -func buildNode(v reflect.Value, cmd bool) (*Node, error) { +func buildNode(v reflect.Value, seenFlags map[string]bool, cmd bool) *Node { node := &Node{} for i := 0; i < v.NumField(); i++ { ft := v.Type().Field(i) @@ -52,7 +63,7 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { tag, err := parseTag(fv, ft.Tag.Get("kong")) if err != nil { - return nil, err + fail("%s", err) } decoder := DecoderForField(tag.Type, ft) @@ -66,10 +77,7 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { // Nested structs are either commands or args. if ft.Type.Kind() == reflect.Struct && (cmd || tag.Arg) { - child, err := buildNode(fv, false) - if err != nil { - return nil, err - } + child := buildNode(fv, seenFlags, false) child.Help = tag.Help // A branching argument. This is a bit hairy, as we let buildNode() do the parsing, then check that @@ -115,7 +123,6 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { Default: tag.Default, Decoder: decoder, Value: fv, - Field: ft, // Flags are optional by default, and args are required by default. Required: (flag && tag.Required) || (tag.Arg && !tag.Optional), @@ -124,6 +131,10 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { if tag.Arg { node.Positional = append(node.Positional, &value) } else { + if seenFlags[value.Name] { + fail("duplicate flag --%s", value.Name) + } + seenFlags[value.Name] = true node.Flags = append(node.Flags, &Flag{ Value: value, Short: tag.Short, @@ -134,6 +145,11 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { } } + // "Unsee" flags. + for _, flag := range node.Flags { + delete(seenFlags, flag.Name) + } + // Scan through argument positionals to ensure optional is never before a required. last := true for _, p := range node.Positional { @@ -144,5 +160,5 @@ func buildNode(v reflect.Value, cmd bool) (*Node, error) { last = p.Required } - return node, nil + return node } diff --git a/context.go b/context.go index d712909..babbbaf 100644 --- a/context.go +++ b/context.go @@ -2,28 +2,103 @@ package kong import ( "fmt" + "reflect" "strings" ) -type ParseContext struct { - Scan *Scanner - Command []string - Flags []*Flag +// ParseTrace records the nodes and parsed values from the current command-line. +type ParseTrace struct { + // One of these will be non-nil. + Positional *Value + Flag *Flag + Argument *Argument + Command *Command + + // Parsed value for non-commands. + Value reflect.Value } -func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo - positional := 0 - p.Flags = append(p.Flags, node.Flags...) +type ParseContext struct { + Trace []*ParseTrace // A trace through parsed nodes. - for token := p.Scan.Pop(); token.Type != EOLToken; token = p.Scan.Pop() { + command []string // Full command path. + flags []*Flag // Accumulated available flags. + node *Node // Current node being parsed. + + args []string + app *Application + scan *Scanner +} + +// Trace parses the command-line, validating and collecting matching grammar nodes. +func Trace(args []string, app *Application) (*ParseContext, error) { + p := &ParseContext{ + app: app, + args: args, + } + err := p.reset(&p.app.Node) + if err != nil { + return nil, err + } + return p, p.trace(&p.app.Node) +} + +// FlagValue returns the set value of a flag, if it was encountered and exists. +func (p *ParseContext) FlagValue(flag *Flag) reflect.Value { + for _, trace := range p.Trace { + if trace.Flag == flag { + return trace.Value + } + } + return reflect.Value{} +} + +// Recursively reset values to defaults (as specified in the grammar) or the zero value. +func (p *ParseContext) reset(node *Node) error { + p.scan = Scan(p.args...) + for _, flag := range node.Flags { + err := flag.Value.Reset() + if err != nil { + return err + } + } + for _, pos := range node.Positional { + err := pos.Reset() + if err != nil { + return err + } + } + for _, branch := range node.Children { + if branch.Argument != nil { + arg := branch.Argument.Argument + err := arg.Reset() + if err != nil { + return err + } + p.reset(&branch.Argument.Node) + } else { + p.reset(branch.Command) + } + } + return nil +} + +func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo + positional := 0 + p.node = node + p.flags = append(p.flags, node.Flags...) + + for !p.scan.Peek().IsEOL() { + token := p.scan.Peek() switch token.Type { case UntypedToken: switch { // -- indicates end of parsing. All remaining arguments are treated as positional arguments only. case token.Value == "--": + p.scan.Pop() args := []string{} for { - token = p.Scan.Pop() + token = p.scan.Pop() if token.Type == EOLToken { break } @@ -31,42 +106,46 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo } // Note: tokens must be pushed in reverse order. for i := range args { - p.Scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) + p.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) } // Long flag. case strings.HasPrefix(token.Value, "--"): + p.scan.Pop() // Parse it and push the tokens. parts := strings.SplitN(token.Value[2:], "=", 2) if len(parts) > 1 { - p.Scan.PushTyped(parts[1], FlagValueToken) + p.scan.PushTyped(parts[1], FlagValueToken) } - p.Scan.PushTyped(parts[0], FlagToken) + p.scan.PushTyped(parts[0], FlagToken) // Short flag. case strings.HasPrefix(token.Value, "-"): + p.scan.Pop() // Note: tokens must be pushed in reverse order. - p.Scan.PushTyped(token.Value[2:], ShortFlagTailToken) - p.Scan.PushTyped(token.Value[1:2], ShortFlagToken) + p.scan.PushTyped(token.Value[2:], ShortFlagTailToken) + p.scan.PushTyped(token.Value[1:2], ShortFlagToken) default: - p.Scan.PushTyped(token.Value, PositionalArgumentToken) + p.scan.Pop() + p.scan.PushTyped(token.Value, PositionalArgumentToken) } case ShortFlagTailToken: + p.scan.Pop() // Note: tokens must be pushed in reverse order. - p.Scan.PushTyped(token.Value[1:], ShortFlagTailToken) - p.Scan.PushTyped(token.Value[0:1], ShortFlagToken) + p.scan.PushTyped(token.Value[1:], ShortFlagTailToken) + p.scan.PushTyped(token.Value[0:1], ShortFlagToken) case FlagToken: - if err := matchFlags(p.Flags, token, p.Scan, func(f *Flag) bool { + if err := p.matchFlags(func(f *Flag) bool { return f.Name == token.Value }); err != nil { return err } case ShortFlagToken: - if err := matchFlags(p.Flags, token, p.Scan, func(f *Flag) bool { + if err := p.matchFlags(func(f *Flag) bool { return string(f.Name) == token.Value }); err != nil { return err @@ -76,15 +155,15 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo return fmt.Errorf("unexpected flag argument %q", token.Value) case PositionalArgumentToken: - p.Scan.PushToken(token) // Ensure we've consumed all positional arguments. if positional < len(node.Positional) { arg := node.Positional[positional] - err := arg.Decode(p.Scan) + value, err := arg.Parse(p.scan) if err != nil { return err } - p.Command = append(p.Command, "<"+arg.Name+">") + p.command = append(p.command, "<"+arg.Name+">") + p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value}) positional++ break } @@ -94,16 +173,18 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo switch { case branch.Command != nil: if branch.Command.Name == token.Value { - p.Scan.Pop() - p.Command = append(p.Command, branch.Command.Name) - return p.applyNode(branch.Command) + p.scan.Pop() + p.command = append(p.command, branch.Command.Name) + p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command}) + return p.trace(branch.Command) } case branch.Argument != nil: arg := branch.Argument.Argument - if err := arg.Decode(p.Scan); err == nil { - p.Command = append(p.Command, "<"+arg.Name+">") - return p.applyNode(&branch.Argument.Node) + if value, err := arg.Parse(p.scan); err == nil { + p.command = append(p.command, "<"+arg.Name+">") + p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value}) + return p.trace(&branch.Argument.Node) } } } @@ -122,13 +203,33 @@ func (p *ParseContext) applyNode(node *Node) (err error) { // nolint: gocyclo return err } - if err := checkMissingFlags(node.Children, p.Flags); err != nil { + if err := checkMissingFlags(node.Children, p.flags); err != nil { return err } return nil } +// Apply traced context to the target grammar. +func (p *ParseContext) Apply() (string, error) { + path := []string{} + for _, trace := range p.Trace { + switch { + case trace.Argument != nil: + path = append(path, "<"+trace.Argument.Name+">") + trace.Argument.Argument.Apply(trace.Value) + case trace.Command != nil: + path = append(path, trace.Command.Name) + case trace.Flag != nil: + trace.Flag.Value.Apply(trace.Value) + case trace.Positional != nil: + path = append(path, "<"+trace.Positional.Name+">") + trace.Positional.Apply(trace.Value) + } + } + return strings.Join(path, " "), nil +} + func checkMissingFlags(children []*Branch, flags []*Flag) error { // Only check required missing fields at the last child. if len(children) > 0 { @@ -186,7 +287,8 @@ func checkMissingPositionals(positional int, values []*Value) error { return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } -func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { +func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) { + token := p.scan.Peek() defer func() { msg := recover() if test, ok := msg.(Error); ok { @@ -195,13 +297,15 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) panic(msg) } }() - for _, flag := range flags { + for _, flag := range p.flags { // Found a matching flag. if flag.Name == token.Value { - err := flag.Decode(scan) + p.scan.Pop() + value, err := flag.Parse(p.scan) if err != nil { return err } + p.Trace = append(p.Trace, &ParseTrace{Flag: flag, Value: value}) return nil } } diff --git a/decoders.go b/decoders.go index 65d61a8..f477417 100644 --- a/decoders.go +++ b/decoders.go @@ -222,7 +222,7 @@ func floatDecoder(bits int) DecoderFunc { func sliceDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { el := target.Type().Elem() - sep, ok := ctx.Value.Field.Tag.Lookup("sep") + sep, ok := ctx.Value.Tag.Lookup("sep") if !ok { sep = "," } diff --git a/help.go b/help.go index 7066aa8..b9e9625 100644 --- a/help.go +++ b/help.go @@ -9,6 +9,9 @@ const defaultHelp = `{{- with .Application -}} usage: {{.Name}} {{.Help}} +{{range .Flags}} +--{{.Name}} +{{end}} {{- end -}} ` diff --git a/kong.go b/kong.go index a40aad2..1afd862 100644 --- a/kong.go +++ b/kong.go @@ -5,7 +5,6 @@ import ( "io" "os" "path/filepath" - "strings" "text/template" ) @@ -65,31 +64,14 @@ func (k *Kong) Parse(args []string) (command string, err error) { panic(msg) } }() - k.reset(k.Model) - ctx := &ParseContext{ - Scan: Scan(args...), + ctx, err := Trace(args, k.Model) + if err != nil { + return "", err } - err = ctx.applyNode(k.Model) - return strings.Join(ctx.Command, " "), err -} - -// Recursively reset values to defaults (as specified in the grammar) or the zero value. -func (k *Kong) reset(node *Node) { - for _, flag := range node.Flags { - flag.Value.Reset() - } - for _, pos := range node.Positional { - pos.Reset() - } - for _, branch := range node.Children { - if branch.Argument != nil { - arg := branch.Argument.Argument - arg.Reset() - k.reset(&branch.Argument.Node) - } else { - k.reset(branch.Command) - } + if value := ctx.FlagValue(k.Model.HelpFlag); value.IsValid() && value.Bool() { + return "", nil } + return ctx.Apply() } func (k *Kong) Errorf(format string, args ...interface{}) { diff --git a/kong_test.go b/kong_test.go index 385e659..bb9a6e1 100644 --- a/kong_test.go +++ b/kong_test.go @@ -8,7 +8,9 @@ import ( func mustNew(t *testing.T, cli interface{}) *Kong { t.Helper() - parser, err := New(cli) + parser, err := New(cli, ExitFunction(func(int) { + t.Fatalf("unexpected exit()") + })) require.NoError(t, err) return parser } @@ -307,3 +309,46 @@ func TestEscapedQuote(t *testing.T) { require.NoError(t, err) require.Equal(t, "i don't know", cli.DoYouKnow) } + +func TestInvalidDefaultErrors(t *testing.T) { + var cli struct { + Flag int `kong:"default='foo'"` + } + p := mustNew(t, &cli) + _, err := p.Parse(nil) + require.Error(t, err) +} + +func TestHelp(t *testing.T) { + var cli struct { + Flag string + } + p := mustNew(t, &cli) + _, err := p.Parse([]string{"--flag=hello", "--help"}) + require.NoError(t, err) + require.NotEqual(t, "hello", cli.Flag) +} + +func TestDuplicateFlag(t *testing.T) { + var cli struct { + Flag bool + Cmd struct { + Flag bool + } + } + _, err := New(&cli) + require.Error(t, err) +} + +func TestDuplicateFlagOnPeerCommandIsOkay(t *testing.T) { + var cli struct { + Cmd1 struct { + Flag bool + } + Cmd2 struct { + Flag bool + } + } + _, err := New(&cli) + require.NoError(t, err) +} diff --git a/model.go b/model.go index efa27aa..388f6d7 100644 --- a/model.go +++ b/model.go @@ -2,7 +2,10 @@ package kong import "reflect" -type Application = Node +type Application struct { + Node + HelpFlag *Flag +} // A Branch is a command or positional argument that results in a branch in the command tree. type Branch struct { @@ -27,27 +30,40 @@ type Value struct { Help string Default string Decoder Decoder - Field reflect.StructField + Tag reflect.StructTag Value reflect.Value Required bool Set bool // Used with Required to test if a value has been given. Format string // Formatting directive, if applicable. } -func (v *Value) Decode(scan *Scanner) error { - err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value) +// Parse tokens into value, parse, and validate, but do not write to the field. +func (v *Value) Parse(scan *Scanner) (reflect.Value, error) { + value := reflect.New(v.Value.Type()).Elem() + err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, value) if err == nil { v.Set = true } - return err + return value, err } -func (v *Value) Reset() { +// Apply value to field. +func (v *Value) Apply(value reflect.Value) { + v.Value.Set(value) + v.Set = true +} + +func (v *Value) Reset() error { v.Value.Set(reflect.Zero(v.Value.Type())) if v.Default != "" { - v.Decode(Scan(v.Default)) + value, err := v.Parse(Scan(v.Default)) + if err != nil { + return err + } + v.Apply(value) v.Set = false } + return nil } type Positional = Value diff --git a/scanner.go b/scanner.go index 5d2c268..9564787 100644 --- a/scanner.go +++ b/scanner.go @@ -40,6 +40,10 @@ func (t Token) String() string { } } +func (t Token) IsEOL() bool { + return t.Type == EOLToken +} + func (t Token) IsAny(types ...TokenType) bool { for _, typ := range types { if t.Type == typ {