diff --git a/build.go b/build.go index 5433358..7a8dccd 100644 --- a/build.go +++ b/build.go @@ -85,6 +85,7 @@ func buildNode(v reflect.Value) *Node { Help: help, Decoder: decoder, Value: fv, + Field: ft, Required: !optional || required, } if arg { diff --git a/decoders.go b/decoders.go index 1f12605..26ddfca 100644 --- a/decoders.go +++ b/decoders.go @@ -4,15 +4,23 @@ import ( "fmt" "reflect" "strconv" + "strings" ) -type Decoder interface { - Decode(scan *Scanner, target reflect.Value) error +type DecoderContext struct { + // Value being decoded into. + Value *Value } -type DecoderFunc func(scan *Scanner, target reflect.Value) error +type Decoder interface { + Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error +} -func (d DecoderFunc) Decode(scan *Scanner, target reflect.Value) error { return d(scan, target) } +type DecoderFunc func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error + +func (d DecoderFunc) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + return d(ctx, scan, target) +} var _ Decoder = DecoderFunc(nil) @@ -76,6 +84,8 @@ var ( kindDecoders map[reflect.Kind]KindDecoder ) +// DecoderForField finds a decoder for a struct field. +// func DecoderForField(field reflect.StructField) Decoder { name, ok := field.Tag.Lookup("type") if ok { @@ -116,7 +126,7 @@ func RegisterDecoder(decoders ...Decoder) { } func init() { - intDecoder := NewKindDecoder(reflect.Int, func(scan *Scanner, target reflect.Value) error { + intDecoder := NewKindDecoder(reflect.Int, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { n, err := strconv.ParseInt(scan.PopValue("int"), 10, 64) if err != nil { return err @@ -124,7 +134,7 @@ func init() { target.SetInt(n) return nil }) - uintDecoder := NewKindDecoder(reflect.Uint, func(scan *Scanner, target reflect.Value) error { + uintDecoder := NewKindDecoder(reflect.Uint, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { n, err := strconv.ParseUint(scan.PopValue("uint"), 10, 64) if err != nil { return err @@ -143,7 +153,7 @@ func init() { reflect.Uint16: uintDecoder, reflect.Uint32: uintDecoder, reflect.Uint64: uintDecoder, - reflect.Float32: NewKindDecoder(reflect.Float32, func(scan *Scanner, target reflect.Value) error { + reflect.Float32: NewKindDecoder(reflect.Float32, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { n, err := strconv.ParseFloat(scan.PopValue("float"), 32) if err != nil { return err @@ -151,7 +161,7 @@ func init() { target.SetFloat(n) return nil }), - reflect.Float64: NewKindDecoder(reflect.Float64, func(scan *Scanner, target reflect.Value) error { + reflect.Float64: NewKindDecoder(reflect.Float64, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { n, err := strconv.ParseFloat(scan.PopValue("float"), 64) if err != nil { return err @@ -159,17 +169,35 @@ func init() { target.SetFloat(n) return nil }), - reflect.String: NewKindDecoder(reflect.String, func(scan *Scanner, target reflect.Value) error { + reflect.String: NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { target.SetString(scan.PopValue("string")) return nil }), - reflect.Bool: NewKindDecoder(reflect.Bool, func(scan *Scanner, target reflect.Value) error { + reflect.Bool: NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { target.SetBool(true) return nil }), + reflect.Slice: NewKindDecoder(reflect.Slice, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + el := target.Type().Elem() + sep, ok := ctx.Value.Field.Tag.Lookup("sep") + if !ok { + sep = "," + } + childScanner := Scan(strings.Split(scan.PopValue("slice"), sep)...) + childDecoder := DecoderForType(el) + for childScanner.Peek().Type != EOLToken { + childValue := reflect.New(el).Elem() + err := childDecoder.Decode(ctx, childScanner, childValue) + if err != nil { + return err + } + target.Set(reflect.Append(target, childValue)) + } + return nil + }), } } -var missingDecoder DecoderFunc = func(scan *Scanner, target reflect.Value) error { - return fmt.Errorf("no decoder for %q (of type %T)", target.String(), target.Type()) +var missingDecoder DecoderFunc = func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + return fmt.Errorf("no decoder for %q (of type %T) for field %q", target.String(), target.Type(), ctx.Value.Field.Name) } diff --git a/kong.go b/kong.go index d8ece95..bc04265 100644 --- a/kong.go +++ b/kong.go @@ -49,10 +49,9 @@ func (k *Kong) Parse(args []string) (command string, err error) { // Recursively reset values to defaults (as specified in the grammar) or the zero value. func (k *Kong) reset(node *Node) { for _, flag := range node.Flags { + flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type())) if flag.Default != "" { - flag.Decoder.Decode(Scan(flag.Default), flag.Value.Value) - } else { - flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type())) + flag.Decoder.Decode(&DecoderContext{Value: &flag.Value}, Scan(flag.Default), flag.Value.Value) } } for _, pos := range node.Positional { @@ -76,35 +75,39 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error switch { // -- indicates end of parsing. All remaining arguments are treated as positional arguments only. case token.Value == "--": + args := []string{} for { token = scan.Pop() if token.Type == EOLToken { break } - scan.PushTyped(token.Value, PositionalArgumentToken) + args = append(args, token.Value) + } + for i := range args { + scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) } // Long flag. case strings.HasPrefix(token.Value, "--"): // Parse it and push the tokens. parts := strings.SplitN(token.Value[2:], "=", 2) - scan.PushTyped(parts[0], FlagToken) if len(parts) > 1 { scan.PushTyped(parts[1], FlagValueToken) } + scan.PushTyped(parts[0], FlagToken) // Short flag. case strings.HasPrefix(token.Value, "-"): - scan.PushTyped(token.Value[1:2], ShortFlagToken) scan.PushTyped(token.Value[2:], ShortFlagTailToken) + scan.PushTyped(token.Value[1:2], ShortFlagToken) default: scan.PushTyped(token.Value, PositionalArgumentToken) } case ShortFlagTailToken: - scan.PushTyped(token.Value[0:1], ShortFlagToken) scan.PushTyped(token.Value[1:], ShortFlagTailToken) + scan.PushTyped(token.Value[0:1], ShortFlagToken) case FlagToken: if err := matchFlags(node.Flags, token, scan, func(f *Flag) bool { @@ -140,7 +143,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error case branch.Argument != nil: arg := branch.Argument.Argument - if err := arg.Decoder.Decode(scan, arg.Value); err == nil { + if err := arg.Decoder.Decode(&DecoderContext{Value: arg}, scan, arg.Value); err == nil { command = append(command, "<"+arg.Name+">") cmd, err := k.applyNode(scan, &branch.Argument.Node) if err != nil { @@ -171,7 +174,7 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) for _, flag := range flags { // Found a matching flag. if flag.Name == token.Value { - err := flag.Decoder.Decode(scan, flag.Value.Value) + err := flag.Decoder.Decode(&DecoderContext{Value: &flag.Value}, scan, flag.Value.Value) if err != nil { return err } diff --git a/kong_test.go b/kong_test.go index d6a5a80..37d40de 100644 --- a/kong_test.go +++ b/kong_test.go @@ -1,12 +1,9 @@ package kong import ( - "reflect" "testing" "github.com/stretchr/testify/require" - - "github.com/alecthomas/repr" ) func mustNew(t *testing.T, cli interface{}) *Kong { @@ -41,7 +38,6 @@ func TestArgument(t *testing.T) { } `arg:"true"` } p := mustNew(t, &cli) - repr.Println(p.Model, repr.Hide(reflect.Value{})) cmd, err := p.Parse([]string{"10", "delete"}) require.NoError(t, err) require.Equal(t, 10, cli.Id.Id) @@ -61,3 +57,13 @@ func TestResetWithDefaults(t *testing.T) { require.Equal(t, "", cli.Flag) require.Equal(t, "default", cli.FlagWithDefault) } + +func TestSlice(t *testing.T) { + var cli struct { + Slice []int + } + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{"--slice=1,2,3"}) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3}, cli.Slice) +} diff --git a/model.go b/model.go index 285f3f0..3ec5e0f 100644 --- a/model.go +++ b/model.go @@ -24,6 +24,7 @@ type Value struct { Name string Help string Decoder Decoder + Field reflect.StructField Value reflect.Value Required bool } diff --git a/scanner.go b/scanner.go index 541016b..78c5b3b 100644 --- a/scanner.go +++ b/scanner.go @@ -60,18 +60,21 @@ func (t Token) IsValue() bool { } type Scanner struct { - raw []string args []Token } func Scan(args ...string) *Scanner { - s := &Scanner{raw: args} + s := &Scanner{} for _, arg := range args { s.args = append(s.args, Token{Value: arg}) } return s } +func (s *Scanner) Len() int { + return len(s.args) +} + func (s *Scanner) Pop() Token { if len(s.args) == 0 { return Token{Type: EOLToken}