From b4ae5fb86e73545c37886286b7e49e58e8dd7fa4 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sat, 1 Feb 2020 18:19:03 +1100 Subject: [PATCH] Correctly support encoding.{TextUnmarshaler,BinaryUnmarsheler} --- .circleci/config.yml | 2 +- .golangci.yml | 1 + mapper.go | 48 +++++++++++++++++++++++++++++++++++++++++--- mapper_test.go | 13 +++++++++--- 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index d703844..076df0f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,7 +12,7 @@ jobs: command: | go get -v github.com/jstemmer/go-junit-report go get -v -t -d ./... - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.22.2 + curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | bash -s v1.23.1 mkdir ~/report when: always - run: diff --git a/.golangci.yml b/.golangci.yml index bffbfc7..a5b2806 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -14,6 +14,7 @@ linters: - funlen - gocognit - gomnd + - goprintffuncname linters-settings: govet: diff --git a/mapper.go b/mapper.go index a2a60e4..1482f1b 100644 --- a/mapper.go +++ b/mapper.go @@ -1,6 +1,7 @@ package kong import ( + "encoding" "encoding/json" "io/ioutil" "math/bits" @@ -15,8 +16,10 @@ import ( ) var ( - mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem() - boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem() + mapperValueType = reflect.TypeOf((*MapperValue)(nil)).Elem() + boolMapperType = reflect.TypeOf((*BoolMapper)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() ) // DecodeContext is passed to a Mapper's Decode(). @@ -57,6 +60,34 @@ func (m *mapperValueAdapter) IsBool() bool { return m.isBool } +type textUnmarshalerAdapter struct{} + +func (m *textUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error { + var value string + err := ctx.Scan.PopValueInto("value", &value) + if err != nil { + return err + } + if target.Type().Implements(textUnmarshalerType) { + return target.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)) + } + return target.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)) +} + +type binaryUnmarshalerAdapter struct{} + +func (m *binaryUnmarshalerAdapter) Decode(ctx *DecodeContext, target reflect.Value) error { + var value string + err := ctx.Scan.PopValueInto("value", &value) + if err != nil { + return err + } + if target.Type().Implements(textUnmarshalerType) { + return target.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value)) + } + return target.Addr().Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary([]byte(value)) +} + // A Mapper represents how a field is mapped from command-line values to Go. // // Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag. @@ -137,11 +168,22 @@ func (r *Registry) ForType(typ reflect.Type) Mapper { return &mapperValueAdapter{impl.Implements(boolMapperType)} } } + // Next, try explicitly registered types. var mapper Mapper var ok bool if mapper, ok = r.types[typ]; ok { return mapper - } else if mapper, ok = r.kinds[typ.Kind()]; ok { + } + // Next try stdlib unmarshaler interfaces. + for _, impl := range []reflect.Type{typ, reflect.PtrTo(typ)} { + if impl.Implements(textUnmarshalerType) { + return &textUnmarshalerAdapter{} + } else if impl.Implements(binaryUnmarshalerType) { + return &binaryUnmarshalerAdapter{} + } + } + // Finally try registered kinds. + if mapper, ok = r.kinds[typ.Kind()]; ok { return mapper } return nil diff --git a/mapper_test.go b/mapper_test.go index 3338cff..e67cfc2 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -30,10 +30,15 @@ func TestValueMapper(t *testing.T) { require.Equal(t, "MOO", cli.Flag) } -type textUnmarshalerValue string +type textUnmarshalerValue int func (m *textUnmarshalerValue) UnmarshalText(text []byte) error { - *m = textUnmarshalerValue(strings.ToUpper(string(text))) + s := string(text) + if s == "hello" { + *m = 10 + } else { + return fmt.Errorf("expected \"hello\"") + } return nil } @@ -44,7 +49,9 @@ func TestTextUnmarshaler(t *testing.T) { p := mustNew(t, &cli) _, err := p.Parse([]string{"--value=hello"}) require.NoError(t, err) - require.Equal(t, "HELLO", string(cli.Value)) + require.Equal(t, 10, int(cli.Value)) + _, err = p.Parse([]string{"--value=other"}) + require.Error(t, err) } type jsonUnmarshalerValue string