diff --git a/build.go b/build.go index 2e38444..63a301f 100644 --- a/build.go +++ b/build.go @@ -36,10 +36,10 @@ func buildNode(v reflect.Value) *Node { if name == "" { name = strings.ToLower(strings.Join(camelCase(ft.Name), "-")) } - help := ft.Tag.Get("help") - decoder, err := DecoderForField(ft) - if err != nil && ft.Type.Kind() != reflect.Struct { - panic(err) + decoder := DecoderForField(ft) + help, ok := ft.Tag.Lookup("help") + if !ok { + continue } dflt := ft.Tag.Get("default") placeholder := ft.Tag.Get("placeholder") @@ -53,11 +53,13 @@ func buildNode(v reflect.Value) *Node { // group := ft.Tag.Get("group") _, required := ft.Tag.Lookup("required") _, optional := ft.Tag.Lookup("optional") + // Force field to be an argument, not a flag. _, arg := ft.Tag.Lookup("arg") env := ft.Tag.Get("env") + format := ft.Tag.Get("format") - // Nested structs are commands. - if ft.Type.Kind() == reflect.Struct { + // Nested structs are either commands or args. + if ft.Type.Kind() == reflect.Struct && decoder == nil { child := buildNode(fv) child.Help = help @@ -65,8 +67,8 @@ func buildNode(v reflect.Value) *Node { // a positional argument is provided to the child, and move it to the branching argument field. if arg { if len(child.Positional) == 0 { - panic(fmt.Errorf("positional branch %s.%s must have at least one child positional argument", - v.Type().Name(), ft.Name)) + fail("positional branch %s.%s must have at least one child positional argument", + v.Type().Name(), ft.Name) } value := child.Positional[0] child.Positional = child.Positional[1:] @@ -83,6 +85,9 @@ func buildNode(v reflect.Value) *Node { node.Children = append(node.Children, &Branch{Command: child}) } } else { + if decoder == nil { + fail("no decoder for %s.%s (of type %s)", v.Type(), ft.Name, ft.Type) + } value := Value{ Name: name, Help: help, @@ -90,6 +95,7 @@ func buildNode(v reflect.Value) *Node { Value: fv, Field: ft, Required: !optional || required, + Format: format, } if arg { node.Positional = append(node.Positional, &value) diff --git a/decoders.go b/decoders.go index 88c5dbd..d1b2c10 100644 --- a/decoders.go +++ b/decoders.go @@ -5,6 +5,7 @@ import ( "reflect" "strconv" "strings" + "time" ) type DecoderContext struct { @@ -81,29 +82,34 @@ var _ NamedDecoder = &namedDecoder{} var ( namedDecoders = map[string]NamedDecoder{} typeDecoders = map[reflect.Type]TypeDecoder{} - kindDecoders map[reflect.Kind]KindDecoder + kindDecoders = map[reflect.Kind]KindDecoder{} ) // DecoderForField finds a decoder for a struct field. -func DecoderForField(field reflect.StructField) (Decoder, error) { +// +// Will return nil if a decoder can not be determined. +func DecoderForField(field reflect.StructField) Decoder { name, ok := field.Tag.Lookup("type") if ok { if decoder, ok := namedDecoders[name]; ok { - return decoder, nil + return decoder } } return DecoderForType(field.Type) } -func DecoderForType(typ reflect.Type) (Decoder, error) { +// DecoderForType finds a decoder via a type or kind. +// +// Will return nil if a decoder can not be determined. +func DecoderForType(typ reflect.Type) Decoder { var decoder Decoder var ok bool if decoder, ok = typeDecoders[typ]; ok { - return decoder, nil + return decoder } else if decoder, ok = kindDecoders[typ.Kind()]; ok { - return decoder, nil + return decoder } - return nil, fmt.Errorf("no decoder for type %s", typ) + return nil } // RegisterDecoder registers decoders. @@ -119,35 +125,59 @@ func RegisterDecoder(decoders ...Decoder) { case NamedDecoder: namedDecoders[decoder.Name()] = decoder default: - panic("unsupported decoder type " + reflect.TypeOf(decoder).String()) + fail("unsupported decoder type " + reflect.TypeOf(decoder).String()) } } } func init() { - kindDecoders = map[reflect.Kind]KindDecoder{ - reflect.Int: NewKindDecoder(reflect.Int, intDecoder), - reflect.Int8: NewKindDecoder(reflect.Int8, intDecoder), - reflect.Int16: NewKindDecoder(reflect.Int16, intDecoder), - reflect.Int32: NewKindDecoder(reflect.Int32, intDecoder), - reflect.Int64: NewKindDecoder(reflect.Int64, intDecoder), - reflect.Uint: NewKindDecoder(reflect.Uint, uintDecoder), - reflect.Uint8: NewKindDecoder(reflect.Uint8, uintDecoder), - reflect.Uint16: NewKindDecoder(reflect.Uint16, uintDecoder), - reflect.Uint32: NewKindDecoder(reflect.Uint32, uintDecoder), - reflect.Uint64: NewKindDecoder(reflect.Uint64, uintDecoder), - reflect.Float32: NewKindDecoder(reflect.Float32, floatDecoder(32)), - reflect.Float64: NewKindDecoder(reflect.Float64, floatDecoder(64)), - reflect.String: NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + RegisterDecoder( + NewKindDecoder(reflect.Int, intDecoder), + NewKindDecoder(reflect.Int8, intDecoder), + NewKindDecoder(reflect.Int16, intDecoder), + NewKindDecoder(reflect.Int32, intDecoder), + NewKindDecoder(reflect.Int64, intDecoder), + NewKindDecoder(reflect.Uint, uintDecoder), + NewKindDecoder(reflect.Uint8, uintDecoder), + NewKindDecoder(reflect.Uint16, uintDecoder), + NewKindDecoder(reflect.Uint32, uintDecoder), + NewKindDecoder(reflect.Uint64, uintDecoder), + NewKindDecoder(reflect.Float32, floatDecoder(32)), + NewKindDecoder(reflect.Float64, floatDecoder(64)), + NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { target.SetString(scan.PopValue("string")) return nil }), - reflect.Bool: NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { target.SetBool(true) return nil }), - reflect.Slice: NewKindDecoder(reflect.Slice, sliceDecoder), + NewKindDecoder(reflect.Slice, sliceDecoder), + NewTypeDecoder(reflect.TypeOf(time.Time{}), timeDecoder), + NewTypeDecoder(reflect.TypeOf(time.Duration(0)), durationDecoder), + ) +} + +func durationDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + d, err := time.ParseDuration(scan.PopValue("duration")) + if err != nil { + return err } + target.Set(reflect.ValueOf(d)) + return nil +} + +func timeDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + fmt := time.RFC3339 + if ctx.Value.Format != "" { + fmt = ctx.Value.Format + } + t, err := time.Parse(fmt, scan.PopValue("time")) + if err != nil { + return err + } + target.Set(reflect.ValueOf(t)) + return nil } func intDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { @@ -186,9 +216,9 @@ func sliceDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) erro sep = "," } childScanner := Scan(strings.Split(scan.PopValue("list"), sep)...) - childDecoder, err := DecoderForType(el) - if err != nil { - return err + childDecoder := DecoderForType(el) + if childDecoder == nil { + return fmt.Errorf("no decoder for element type of %s", target.Type()) } for childScanner.Peek().Type != EOLToken { childValue := reflect.New(el).Elem() diff --git a/kong.go b/kong.go index bc04265..c382b40 100644 --- a/kong.go +++ b/kong.go @@ -8,6 +8,16 @@ import ( "strings" ) +type Error struct { + msg string +} + +func (e Error) Error() string { return e.msg } + +func fail(format string, args ...interface{}) { + panic(Error{fmt.Sprintf(format, args...)}) +} + type Kong struct { Model *Application // Termination function (defaults to os.Exit) @@ -35,7 +45,7 @@ func New(name, description string, ast interface{}) (*Kong, error) { func (k *Kong) Parse(args []string) (command string, err error) { defer func() { msg := recover() - if test, ok := msg.(TokenAssertionError); ok { + if test, ok := msg.(Error); ok { err = test } else if msg != nil { panic(msg) @@ -68,7 +78,8 @@ func (k *Kong) reset(node *Node) { } } -func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error) { +func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error) { // nolint: gocyclo + positional := 0 for token := scan.Pop(); token.Type != EOLToken; token = scan.Pop() { switch token.Type { case UntypedToken: @@ -98,6 +109,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error // Short flag. case strings.HasPrefix(token.Value, "-"): + // Note: tokens must be pushed in reverse order. scan.PushTyped(token.Value[2:], ShortFlagTailToken) scan.PushTyped(token.Value[1:2], ShortFlagToken) @@ -106,6 +118,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error } case ShortFlagTailToken: + // Note: tokens must be pushed in reverse order. scan.PushTyped(token.Value[1:], ShortFlagTailToken) scan.PushTyped(token.Value[0:1], ShortFlagToken) @@ -128,6 +141,19 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error case PositionalArgumentToken: scan.PushToken(token) + // Ensure we've consumed all positional arguments. + if positional < len(node.Positional) { + arg := node.Positional[positional] + err := arg.Decoder.Decode(&DecoderContext{Value: arg}, scan, arg.Value) + if err != nil { + return nil, err + } + command = append(command, "<"+arg.Name+">") + positional++ + break + } + + // After positional arguments have been consumed, handle commands and branching arguments. for _, branch := range node.Children { switch { case branch.Command != nil: @@ -165,7 +191,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { defer func() { msg := recover() - if test, ok := msg.(TokenAssertionError); ok { + if test, ok := msg.(Error); ok { err = fmt.Errorf("%s %s", token, test) } else if msg != nil { panic(msg) diff --git a/kong_test.go b/kong_test.go index 37d40de..2f3d96b 100644 --- a/kong_test.go +++ b/kong_test.go @@ -13,7 +13,23 @@ func mustNew(t *testing.T, cli interface{}) *Kong { return parser } -func TestArgument(t *testing.T) { +func TestArgumentSequence(t *testing.T) { + var cli struct { + User struct { + Create struct { + ID int `arg:"" help:""` + First string `arg:"" help:""` + Last string `arg:"" help:""` + } `help:""` + } `help:""` + } + p := mustNew(t, &cli) + cmd, err := p.Parse([]string{"user", "create", "10", "Alec", "Thomas"}) + require.NoError(t, err) + require.Equal(t, "user create ", cmd) +} + +func TestBranchingArgument(t *testing.T) { /* app user create app user delete @@ -21,33 +37,35 @@ func TestArgument(t *testing.T) { */ var cli struct { - Create struct { - Id string `arg:"true"` - First string `arg:"true"` - Last string `arg:"true"` - } + User struct { + Create struct { + ID string `arg:"" help:""` + First string `arg:"" help:""` + Last string `arg:"" help:""` + } `help:""` - // Branching argument. - Id struct { - Id int `arg:"true"` - Flag int - Delete struct{} - Rename struct { - To string - } - } `arg:"true"` + // Branching argument. + ID struct { + ID int `arg:"" help:""` + Flag int `help:""` + Delete struct{} `help:""` + Rename struct { + To string + } `help:""` + } `arg:"" help:""` + } `help:"Manage users."` } p := mustNew(t, &cli) - cmd, err := p.Parse([]string{"10", "delete"}) + cmd, err := p.Parse([]string{"user", "10", "delete"}) require.NoError(t, err) - require.Equal(t, 10, cli.Id.Id) - require.Equal(t, " delete", cmd) + require.Equal(t, 10, cli.User.ID.ID) + require.Equal(t, "user delete", cmd) } func TestResetWithDefaults(t *testing.T) { var cli struct { - Flag string - FlagWithDefault string `default:"default"` + Flag string `help:""` + FlagWithDefault string `default:"default" help:""` } cli.Flag = "BLAH" cli.FlagWithDefault = "BLAH" @@ -60,10 +78,17 @@ func TestResetWithDefaults(t *testing.T) { func TestSlice(t *testing.T) { var cli struct { - Slice []int + Slice []int `help:""` } parser := mustNew(t, &cli) _, err := parser.Parse([]string{"--slice=1,2,3"}) require.NoError(t, err) require.Equal(t, []int{1, 2, 3}, cli.Slice) } + +func TestUnsupportedfieldErrors(t *testing.T) { + var cli struct { + Keys map[string]string `help:""` + } + require.Panics(t, func() { mustNew(t, &cli) }) +} diff --git a/model.go b/model.go index 3ec5e0f..51c12d9 100644 --- a/model.go +++ b/model.go @@ -27,6 +27,7 @@ type Value struct { Field reflect.StructField Value reflect.Value Required bool + Format string // Formatting directive, if applicable. } type Positional = Value diff --git a/scanner.go b/scanner.go index 78c5b3b..e93b717 100644 --- a/scanner.go +++ b/scanner.go @@ -1,7 +1,6 @@ package kong import ( - "fmt" "strconv" ) @@ -19,12 +18,6 @@ const ( PositionalArgumentToken // ) -type TokenAssertionError struct{ err error } - -func (t TokenAssertionError) Error() string { - return t.err.Error() -} - type Token struct { Value string Type TokenType @@ -84,11 +77,11 @@ func (s *Scanner) Pop() Token { return arg } -// PopValue or panic with TokenAssertionError. +// PopValue or panic with Error. func (s *Scanner) PopValue(context string) string { t := s.Pop() if !t.IsValue() { - panic(TokenAssertionError{fmt.Errorf("expected %s value but got %s", context, t)}) + fail("expected %s value but got %s", context, t) } return t.Value }