diff --git a/help.go b/help.go index 4f92437..8fe9cc7 100644 --- a/help.go +++ b/help.go @@ -5,7 +5,6 @@ import ( "fmt" "go/doc" "io" - "reflect" "strings" ) @@ -197,8 +196,5 @@ func formatFlag(haveShort bool, flag *Flag) string { if !isBool { flagString += fmt.Sprintf("=%s", flag.FormatPlaceHolder()) } - if flag.Value.Target.Kind() == reflect.Slice { - flagString += " ..." - } return flagString } diff --git a/help_test.go b/help_test.go index 5d5e76d..9aa6004 100644 --- a/help_test.go +++ b/help_test.go @@ -10,9 +10,11 @@ import ( func TestHelp(t *testing.T) { // nolint: govet var cli struct { - String string `help:"A string flag."` - Bool bool `help:"A bool flag with very long help that wraps a lot and is verbose and is really verbose."` - Required bool `required help:"A required flag."` + String string `help:"A string flag."` + Bool bool `help:"A bool flag with very long help that wraps a lot and is verbose and is really verbose."` + Slice []string `help:"A slice of strings." placeholder:"STR"` + Map map[string]int `help:"A map of strings to ints."` + Required bool `required help:"A required flag."` One struct { Flag string `help:"Nested flag."` @@ -60,6 +62,8 @@ Flags: --string=STRING A string flag. --bool A bool flag with very long help that wraps a lot and is verbose and is really verbose. + --slice=STR,... A slice of strings. + --map=KEY=VALUE A map of strings to ints. --required A required flag. Commands: @@ -91,6 +95,8 @@ Flags: --string=STRING A string flag. --bool A bool flag with very long help that wraps a lot and is verbose and is really verbose. + --slice=STR,... A slice of strings. + --map=KEY=VALUE A map of strings to ints. --required A required flag. --flag=STRING Nested flag under two. diff --git a/kong_test.go b/kong_test.go index bb3d322..2eaaf72 100644 --- a/kong_test.go +++ b/kong_test.go @@ -139,7 +139,7 @@ func TestArgSliceWithSeparator(t *testing.T) { func TestUnsupportedFieldErrors(t *testing.T) { var cli struct { - Keys map[string]string + Keys struct{} } _, err := New(&cli) require.Error(t, err) @@ -401,3 +401,12 @@ func TestDuplicateSliceAccumulates(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{1, 2, 3, 4}, cli.Flag) } + +func TestMapFlag(t *testing.T) { + var cli struct { + Set map[string]int + } + _, err := mustNew(t, &cli).Parse([]string{"--set", "a=10", "--set", "b=20"}) + require.NoError(t, err) + require.Equal(t, map[string]int{"a": 10, "b": 20}, cli.Set) +} diff --git a/mapper.go b/mapper.go index 6b4d05a..dcbacd0 100644 --- a/mapper.go +++ b/mapper.go @@ -153,7 +153,8 @@ func (d *Registry) RegisterDefaults() *Registry { RegisterKind(reflect.Bool, boolMapper{}). RegisterType(reflect.TypeOf(time.Time{}), timeDecoder()). RegisterType(reflect.TypeOf(time.Duration(0)), durationDecoder()). - RegisterKind(reflect.Slice, sliceDecoder(d)) + RegisterKind(reflect.Slice, sliceDecoder(d)). + RegisterKind(reflect.Map, mapDecoder(d)) } type boolMapper struct{} @@ -228,6 +229,36 @@ func floatDecoder(bits int) MapperFunc { func mapDecoder(d *Registry) MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { + if target.IsNil() { + target.Set(reflect.MakeMap(target.Type())) + } + el := target.Type() + sep := ctx.Value.Tag.Sep + if sep == 0 { + sep = '=' + } + token := ctx.Scan.PopValue("map") + parts := SplitEscaped(token, sep) + if len(parts) != 2 { + return fmt.Errorf("expected \"%c\" but got %q", sep, token) + } + key, value := parts[0], parts[1] + + keyScanner := Scan(key) + keyDecoder := d.ForType(el.Key()) + keyValue := reflect.New(el.Key()).Elem() + if err := keyDecoder.Decode(ctx.WithScanner(keyScanner), keyValue); err != nil { + return fmt.Errorf("invalid map key %q", key) + } + + valueScanner := Scan(value) + valueDecoder := d.ForType(el.Elem()) + valueValue := reflect.New(el.Elem()).Elem() + if err := valueDecoder.Decode(ctx.WithScanner(valueScanner), valueValue); err != nil { + return fmt.Errorf("invalid map value %q", value) + } + + target.SetMapIndex(keyValue, valueValue) return nil } } @@ -236,12 +267,15 @@ func sliceDecoder(d *Registry) MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { el := target.Type().Elem() sep := ctx.Value.Tag.Sep + if sep == 0 { + sep = ',' + } var childScanner *Scanner if ctx.Value.Flag != nil { // If decoding a flag, we need an argument. childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("list"), sep)...) } else { - tokens := ctx.Scan.PopUntil(func(t Token) bool { return !t.IsValue() }) + tokens := ctx.Scan.PopWhile(func(t Token) bool { return t.IsValue() }) childScanner = Scan(tokens...) } childDecoder := d.ForType(el) diff --git a/model.go b/model.go index a7834ca..cdf5e50 100644 --- a/model.go +++ b/model.go @@ -164,11 +164,21 @@ func (v *Value) Summary() string { return argText } -// IsCumulative returns true of the value is a slice. +// IsCumulative returns true if the type can be accumulated into. func (v *Value) IsCumulative() bool { + return v.IsSlice() || v.IsMap() +} + +// IsSlice returns true if the value is a slice. +func (v *Value) IsSlice() bool { return v.Target.Kind() == reflect.Slice } +// IsMap returns true if the value is a map. +func (v *Value) IsMap() bool { + return v.Target.Kind() == reflect.Map +} + // IsBool returns true if the underlying value is a boolean. func (v *Value) IsBool() bool { if m, ok := v.Mapper.(BoolMapper); ok && m.IsBool() { @@ -229,8 +239,8 @@ func (f *Flag) String() string { // FormatPlaceHolder formats the placeholder string for a Flag. func (f *Flag) FormatPlaceHolder() string { tail := "" - if f.Value.IsCumulative() { - tail += ", ..." + if f.Value.IsSlice() { + tail += ",..." } if f.Default != "" { if f.Value.Target.Kind() == reflect.String { @@ -241,5 +251,8 @@ func (f *Flag) FormatPlaceHolder() string { if f.PlaceHolder != "" { return f.PlaceHolder + tail } + if f.Value.IsMap() { + return "KEY=VALUE" + tail + } return strings.ToUpper(f.Name) + tail } diff --git a/resolver.go b/resolver.go index 08520ae..fd5a02d 100755 --- a/resolver.go +++ b/resolver.go @@ -26,7 +26,11 @@ func JSON(r io.Reader) (ResolverFunc, error) { if !ok { return "", nil } - value, err := jsonDecodeValue(flag.Tag.Sep, raw) + sep := flag.Tag.Sep + if sep == 0 { + sep = ',' + } + value, err := jsonDecodeValue(sep, raw) if err != nil { return "", err } diff --git a/tag.go b/tag.go index 17a71e6..1573806 100644 --- a/tag.go +++ b/tag.go @@ -130,14 +130,6 @@ func parseTag(fv reflect.Value, ft reflect.StructField) *Tag { t.Hidden = t.Has("hidden") t.Format, _ = t.Get("format") t.Sep, _ = t.GetRune("sep") - if t.Sep == 0 { - if t.Cmd || t.Arg { - t.Sep = ' ' - } else { - t.Sep = ',' - } - } - t.PlaceHolder, _ = t.Get("placeholder") if t.PlaceHolder == "" { t.PlaceHolder = strings.ToUpper(dashedString(fv.Type().Name()))