Correctly support encoding.{TextUnmarshaler,BinaryUnmarsheler}

This commit is contained in:
Alec Thomas
2020-02-01 18:19:03 +11:00
parent 60222fe397
commit b4ae5fb86e
4 changed files with 57 additions and 7 deletions
+1 -1
View File
@@ -12,7 +12,7 @@ jobs:
command: | command: |
go get -v github.com/jstemmer/go-junit-report go get -v github.com/jstemmer/go-junit-report
go get -v -t -d ./... go get -v -t -d ./...
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.22.2 curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.23.1
mkdir ~/report mkdir ~/report
when: always when: always
- run: - run:
+1
View File
@@ -14,6 +14,7 @@ linters:
- funlen - funlen
- gocognit - gocognit
- gomnd - gomnd
- goprintffuncname
linters-settings: linters-settings:
govet: govet:
+45 -3
View File
@@ -1,6 +1,7 @@
package kong package kong
import ( import (
"encoding"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"math/bits" "math/bits"
@@ -15,8 +16,10 @@ import (
) )
var ( var (
mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem() mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem()
boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem() boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
) )
// DecodeContext is passed to a Mapper's Decode(). // DecodeContext is passed to a Mapper's Decode().
@@ -57,6 +60,34 @@ func (m *mapperValueAdapter) IsBool() bool {
return m.isBool return m.isBool
} }
type textUnmarshalerAdapter struct{}
func (m *textUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error {
var value string
err := ctx.Scan.PopValueInto("value", &value)
if err != nil {
return err
}
if target.Type().Implements(textUnmarshalerType) {
return target.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value))
}
return target.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value))
}
type binaryUnmarshalerAdapter struct{}
func (m *binaryUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error {
var value string
err := ctx.Scan.PopValueInto("value", &value)
if err != nil {
return err
}
if target.Type().Implements(textUnmarshalerType) {
return target.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value))
}
return target.Addr().Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value))
}
// A Mapper represents how a field is mapped from command-line values to Go. // A Mapper represents how a field is mapped from command-line values to Go.
// //
// Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag. // Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag.
@@ -137,11 +168,22 @@ func (r *Registry) ForType(typ reflect.Type) Mapper {
return &mapperValueAdapter{impl.Implements(boolMapperType)} return &mapperValueAdapter{impl.Implements(boolMapperType)}
} }
} }
// Next, try explicitly registered types.
var mapper Mapper var mapper Mapper
var ok bool var ok bool
if mapper, ok = r.types[typ]; ok { if mapper, ok = r.types[typ]; ok {
return mapper return mapper
} else if mapper, ok = r.kinds[typ.Kind()]; ok { }
// Next try stdlib unmarshaler interfaces.
for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} {
if impl.Implements(textUnmarshalerType) {
return &textUnmarshalerAdapter{}
} else if impl.Implements(binaryUnmarshalerType) {
return &binaryUnmarshalerAdapter{}
}
}
// Finally try registered kinds.
if mapper, ok = r.kinds[typ.Kind()]; ok {
return mapper return mapper
} }
return nil return nil
+10 -3
View File
@@ -30,10 +30,15 @@ func TestValueMapper(t *testing.T) {
require.Equal(t, "MOO", cli.Flag) require.Equal(t, "MOO", cli.Flag)
} }
type textUnmarshalerValue string type textUnmarshalerValue int
func (m *textUnmarshalerValue) UnmarshalText(text []byte) error { func (m *textUnmarshalerValue) UnmarshalText(text []byte) error {
*m = textUnmarshalerValue(strings.ToUpper(string(text))) s := string(text)
if s == "hello" {
*m = 10
} else {
return fmt.Errorf("expected \"hello\"")
}
return nil return nil
} }
@@ -44,7 +49,9 @@ func TestTextUnmarshaler(t *testing.T) {
p := mustNew(t, &cli) p := mustNew(t, &cli)
_, err := p.Parse([]string{"--value=hello"}) _, err := p.Parse([]string{"--value=hello"})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "HELLO", string(cli.Value)) require.Equal(t, 10, int(cli.Value))
_, err = p.Parse([]string{"--value=other"})
require.Error(t, err)
} }
type jsonUnmarshalerValue string type jsonUnmarshalerValue string