diff --git a/context.go b/context.go index 654de50..bd6ebae 100644 --- a/context.go +++ b/context.go @@ -551,6 +551,14 @@ func (c *Context) getValue(value *Value) reflect.Value { v, ok := c.values[value] if !ok { v = reflect.New(value.Target.Type()).Elem() + switch v.Kind() { + case reflect.Ptr: + v.Set(reflect.New(v.Type().Elem())) + case reflect.Slice: + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + case reflect.Map: + v.Set(reflect.MakeMap(v.Type())) + } c.values[value] = v } return v diff --git a/kong_test.go b/kong_test.go index 3703f4d..03ff1d6 100644 --- a/kong_test.go +++ b/kong_test.go @@ -936,3 +936,17 @@ func TestValidateArg(t *testing.T) { _, err := p.Parse([]string{"one"}) require.EqualError(t, err, ": flag error") } + +func TestPointers(t *testing.T) { + cli := struct { + Mapped *mappedValue + JSON *jsonUnmarshalerValue + }{} + p := mustNew(t, &cli) + _, err := p.Parse([]string{"--mapped=mapped", "--json=\"foo\""}) + require.NoError(t, err) + require.NotNil(t, cli.Mapped) + require.Equal(t, "mapped", cli.Mapped.decoded) + require.NotNil(t, cli.JSON) + require.Equal(t, "FOO", string(*cli.JSON)) +} diff --git a/mapper.go b/mapper.go index 01343b0..183ec09 100644 --- a/mapper.go +++ b/mapper.go @@ -19,6 +19,7 @@ import ( var ( mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem() boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem() + jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() ) @@ -89,6 +90,20 @@ func (m *binaryUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Val return target.Addr().Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value)) } +type jsonUnmarshalerAdapter struct{} + +func (j *jsonUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error { + var value string + err := ctx.Scan.PopValueInto("value", &value) + if err != nil { + return err + } + if target.Type().Implements(jsonUnmarshalerType) { + return target.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(value)) + } + return target.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(value)) +} + // A Mapper represents how a field is mapped from command-line values to Go. // // Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag. @@ -177,10 +192,13 @@ func (r *Registry) ForType(typ reflect.Type) Mapper { } // Next try stdlib unmarshaler interfaces. for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} { - if impl.Implements(textUnmarshalerType) { + switch { + case impl.Implements(textUnmarshalerType): return &textUnmarshalerAdapter{} - } else if impl.Implements(binaryUnmarshalerType) { + case impl.Implements(binaryUnmarshalerType): return &binaryUnmarshalerAdapter{} + case impl.Implements(jsonUnmarshalerType): + return &jsonUnmarshalerAdapter{} } } // Finally try registered kinds. diff --git a/mapper_test.go b/mapper_test.go index 25fb543..091322c 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -72,7 +72,7 @@ func TestJSONUnmarshaler(t *testing.T) { Value jsonUnmarshalerValue } p := mustNew(t, &cli) - _, err := p.Parse([]string{"--value=hello"}) + _, err := p.Parse([]string{"--value=\"hello\""}) require.NoError(t, err) require.Equal(t, "HELLO", string(cli.Value)) }