Ensure values aren't nil before decoding.
This commit is contained in:
@@ -551,6 +551,14 @@ func (c *Context) getValue(value *Value) reflect.Value {
|
||||
v, ok := c.values[value]
|
||||
if !ok {
|
||||
v = reflect.New(value.Target.Type()).Elem()
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr:
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
case reflect.Slice:
|
||||
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
|
||||
case reflect.Map:
|
||||
v.Set(reflect.MakeMap(v.Type()))
|
||||
}
|
||||
c.values[value] = v
|
||||
}
|
||||
return v
|
||||
|
||||
@@ -936,3 +936,17 @@ func TestValidateArg(t *testing.T) {
|
||||
_, err := p.Parse([]string{"one"})
|
||||
require.EqualError(t, err, "<arg>: flag error")
|
||||
}
|
||||
|
||||
func TestPointers(t *testing.T) {
|
||||
cli := struct {
|
||||
Mapped *mappedValue
|
||||
JSON *jsonUnmarshalerValue
|
||||
}{}
|
||||
p := mustNew(t, &cli)
|
||||
_, err := p.Parse([]string{"--mapped=mapped", "--json=\"foo\""})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cli.Mapped)
|
||||
require.Equal(t, "mapped", cli.Mapped.decoded)
|
||||
require.NotNil(t, cli.JSON)
|
||||
require.Equal(t, "FOO", string(*cli.JSON))
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
var (
|
||||
mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem()
|
||||
boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem()
|
||||
jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
|
||||
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
|
||||
)
|
||||
@@ -89,6 +90,20 @@ func (m *binaryUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Val
|
||||
return target.Addr().Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value))
|
||||
}
|
||||
|
||||
type jsonUnmarshalerAdapter struct{}
|
||||
|
||||
func (j *jsonUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error {
|
||||
var value string
|
||||
err := ctx.Scan.PopValueInto("value", &value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if target.Type().Implements(jsonUnmarshalerType) {
|
||||
return target.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(value))
|
||||
}
|
||||
return target.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(value))
|
||||
}
|
||||
|
||||
// A Mapper represents how a field is mapped from command-line values to Go.
|
||||
//
|
||||
// Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag.
|
||||
@@ -177,10 +192,13 @@ func (r *Registry) ForType(typ reflect.Type) Mapper {
|
||||
}
|
||||
// Next try stdlib unmarshaler interfaces.
|
||||
for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} {
|
||||
if impl.Implements(textUnmarshalerType) {
|
||||
switch {
|
||||
case impl.Implements(textUnmarshalerType):
|
||||
return &textUnmarshalerAdapter{}
|
||||
} else if impl.Implements(binaryUnmarshalerType) {
|
||||
case impl.Implements(binaryUnmarshalerType):
|
||||
return &binaryUnmarshalerAdapter{}
|
||||
case impl.Implements(jsonUnmarshalerType):
|
||||
return &jsonUnmarshalerAdapter{}
|
||||
}
|
||||
}
|
||||
// Finally try registered kinds.
|
||||
|
||||
+1
-1
@@ -72,7 +72,7 @@ func TestJSONUnmarshaler(t *testing.T) {
|
||||
Value jsonUnmarshalerValue
|
||||
}
|
||||
p := mustNew(t, &cli)
|
||||
_, err := p.Parse([]string{"--value=hello"})
|
||||
_, err := p.Parse([]string{"--value=\"hello\""})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "HELLO", string(cli.Value))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user