From 439c674f7ae06d6568a94c52ef9fd9402f8608bd Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Wed, 24 Apr 2019 23:18:57 +1000 Subject: [PATCH] Use interface{} instead of string in tokens. This allows the scanner and resolvers to pass Go types around rather than having to serialise/deserialise to/from strings. --- build.go | 6 +- context.go | 83 +++++++++--------- go.mod | 2 + go.sum | 4 + mapper.go | 218 ++++++++++++++++++++++++++++++++++++++--------- mapper_test.go | 4 +- resolver.go | 53 ++---------- resolver_test.go | 44 +++++----- scanner.go | 54 +++++++----- 9 files changed, 296 insertions(+), 172 deletions(-) diff --git a/build.go b/build.go index 829e16e..e87393f 100644 --- a/build.go +++ b/build.go @@ -6,8 +6,6 @@ import ( "strings" ) -var helpProviderType = reflect.TypeOf((*HelpProvider)(nil)).Elem() - func build(k *Kong, ast interface{}) (app *Application, err error) { defer catch(&err) v := reflect.ValueOf(ast) @@ -135,8 +133,8 @@ func buildChild(k *Kong, node *Node, typ NodeType, v reflect.Value, ft reflect.S child.Hidden = tag.Hidden child.Group = tag.Group - if fv.Type().Implements(helpProviderType) { - child.Detail = fv.Interface().(HelpProvider).Help() + if provider, ok := fv.Addr().Interface().(HelpProvider); ok { + child.Detail = provider.Help() } // A branching argument. This is a bit hairy, as we let buildNode() do the parsing, then check that diff --git a/context.go b/context.go index ab7b397..9ca1f31 100644 --- a/context.go +++ b/context.go @@ -258,42 +258,49 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo token := c.scan.Peek() switch token.Type { case UntypedToken: - switch { - // Indicates end of parsing. All remaining arguments are treated as positional arguments only. - case token.Value == "--": - c.scan.Pop() - args := []string{} - for { - token = c.scan.Pop() - if token.Type == EOLToken { - break + switch v := token.Value.(type) { + case string: + + switch { + // Indicates end of parsing. All remaining arguments are treated as positional arguments only. + case v == "--": + c.scan.Pop() + args := []string{} + for { + token = c.scan.Pop() + if token.Type == EOLToken { + break + } + args = append(args, token.String()) + } + // Note: tokens must be pushed in reverse order. + for i := range args { + c.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) } - args = append(args, token.Value) - } - // Note: tokens must be pushed in reverse order. - for i := range args { - c.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken) - } - // Long flag. - case strings.HasPrefix(token.Value, "--"): - c.scan.Pop() - // Parse it and push the tokens. - parts := strings.SplitN(token.Value[2:], "=", 2) - if len(parts) > 1 { - c.scan.PushTyped(parts[1], FlagValueToken) - } - c.scan.PushTyped(parts[0], FlagToken) + // Long flag. + case strings.HasPrefix(v, "--"): + c.scan.Pop() + // Parse it and push the tokens. + parts := strings.SplitN(v[2:], "=", 2) + if len(parts) > 1 { + c.scan.PushTyped(parts[1], FlagValueToken) + } + c.scan.PushTyped(parts[0], FlagToken) - // Short flag. - case strings.HasPrefix(token.Value, "-"): - c.scan.Pop() - // Note: tokens must be pushed in reverse order. - if tail := token.Value[2:]; tail != "" { - c.scan.PushTyped(tail, ShortFlagTailToken) - } - c.scan.PushTyped(token.Value[1:2], ShortFlagToken) + // Short flag. + case strings.HasPrefix(v, "-"): + c.scan.Pop() + // Note: tokens must be pushed in reverse order. + if tail := v[2:]; tail != "" { + c.scan.PushTyped(tail, ShortFlagTailToken) + } + c.scan.PushTyped(v[1:2], ShortFlagToken) + default: + c.scan.Pop() + c.scan.PushTyped(token.Value, PositionalArgumentToken) + } default: c.scan.Pop() c.scan.PushTyped(token.Value, PositionalArgumentToken) @@ -302,18 +309,18 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo case ShortFlagTailToken: c.scan.Pop() // Note: tokens must be pushed in reverse order. - if tail := token.Value[1:]; tail != "" { + if tail := token.String()[1:]; tail != "" { c.scan.PushTyped(tail, ShortFlagTailToken) } - c.scan.PushTyped(token.Value[0:1], ShortFlagToken) + c.scan.PushTyped(token.String()[0:1], ShortFlagToken) case FlagToken: - if err := c.parseFlag(flags, "--"+token.Value); err != nil { + if err := c.parseFlag(flags, token.String()); err != nil { return err } case ShortFlagToken: - if err := c.parseFlag(flags, "-"+token.Value); err != nil { + if err := c.parseFlag(flags, token.String()); err != nil { return err } @@ -369,7 +376,7 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo } } - return findPotentialCandidates(token.Value, candidates, "unexpected argument %s", token) + return findPotentialCandidates(token.String(), candidates, "unexpected argument %s", token) default: return fmt.Errorf("unexpected token %s", token) } @@ -396,7 +403,7 @@ func (c *Context) Resolve() error { if err != nil { return err } - if s == "" { + if s == nil { continue } diff --git a/go.mod b/go.mod index da4aff7..757bb09 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/alecthomas/kong require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mitchellh/mapstructure v1.1.2 + github.com/pkg/errors v0.8.1 github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.2.2 ) diff --git a/go.sum b/go.sum index e03ee77..22a1f3c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= diff --git a/mapper.go b/mapper.go index 2fdd440..46f13a7 100644 --- a/mapper.go +++ b/mapper.go @@ -1,7 +1,7 @@ package kong import ( - "fmt" + "encoding/json" "io/ioutil" "math/bits" "net/url" @@ -10,6 +10,8 @@ import ( "strconv" "strings" "time" + + "github.com/pkg/errors" ) var ( @@ -194,9 +196,8 @@ func (r *Registry) RegisterDefaults() *Registry { RegisterKind(reflect.Float32, floatDecoder(32)). RegisterKind(reflect.Float64, floatDecoder(64)). RegisterKind(reflect.String, MapperFunc(func(ctx *DecodeContext, target reflect.Value) error { - token := ctx.Scan.PopValue("string") - target.SetString(token) - return nil + _, err := ctx.Scan.PopValueInto("string", target.Addr().Interface()) + return err })). RegisterKind(reflect.Bool, boolMapper{}). RegisterKind(reflect.Slice, sliceDecoder(r)). @@ -214,7 +215,16 @@ type boolMapper struct{} func (boolMapper) Decode(ctx *DecodeContext, target reflect.Value) error { if ctx.Scan.Peek().Type == FlagValueToken { token := ctx.Scan.Pop() - target.SetBool(token.Value == "true") + switch v := token.Value.(type) { + case string: + target.SetBool(strings.ToLower(v) == "true") + + case bool: + target.SetBool(v) + + default: + return errors.Errorf("expected bool but got %q (%T)", token.Value, token.Value) + } } else { target.SetBool(true) } @@ -224,10 +234,13 @@ func (boolMapper) IsBool() bool { return true } func durationDecoder() MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { - value := ctx.Scan.PopValue("duration") + var value string + if _, err := ctx.Scan.PopValueInto("duration", &value); err != nil { + return err + } r, err := time.ParseDuration(value) if err != nil { - return fmt.Errorf("expected duration but got %q: %s", value, err) + return errors.Errorf("expected duration but got %q: %s", value, err) } target.Set(reflect.ValueOf(r)) return nil @@ -240,7 +253,10 @@ func timeDecoder() MapperFunc { if ctx.Value.Format != "" { format = ctx.Value.Format } - value := ctx.Scan.PopValue("time") + var value string + if _, err := ctx.Scan.PopValueInto("time", &value); err != nil { + return err + } t, err := time.Parse(format, value) if err != nil { return err @@ -250,38 +266,86 @@ func timeDecoder() MapperFunc { } } -func intDecoder(bits int) MapperFunc { +func intDecoder(bits int) MapperFunc { // nolint: dupl return func(ctx *DecodeContext, target reflect.Value) error { - value := ctx.Scan.PopValue("int") - n, err := strconv.ParseInt(value, 10, bits) + t, err := ctx.Scan.PopValue("int") if err != nil { - return fmt.Errorf("expected int but got %q", value) + return err + } + switch v := t.Value.(type) { + case string: + n, err := strconv.ParseInt(v, 10, bits) + if err != nil { + return errors.Errorf("expected an int but got %q (%T)", t, t.Value) + } + target.SetInt(n) + + case float64: + target.SetInt(int64(v)) + + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + target.Set(reflect.ValueOf(v)) + + default: + return errors.Errorf("expected an int but got %q (%T)", t, t.Value) } - target.SetInt(n) return nil } } -func uintDecoder(bits int) MapperFunc { +func uintDecoder(bits int) MapperFunc { // nolint: dupl return func(ctx *DecodeContext, target reflect.Value) error { - value := ctx.Scan.PopValue("uint") - n, err := strconv.ParseUint(value, 10, bits) + t, err := ctx.Scan.PopValue("uint") if err != nil { - return fmt.Errorf("expected unsigned int but got %q", value) + return err + } + switch v := t.Value.(type) { + case string: + n, err := strconv.ParseUint(v, 10, bits) + if err != nil { + return errors.Errorf("expected a uint but got %q (%T)", t, t.Value) + } + target.SetUint(n) + + case float64: + target.SetUint(uint64(v)) + + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + target.Set(reflect.ValueOf(v)) + + default: + return errors.Errorf("expected an int but got %q (%T)", t, t.Value) } - target.SetUint(n) return nil } } func floatDecoder(bits int) MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { - value := ctx.Scan.PopValue("float") - n, err := strconv.ParseFloat(value, bits) + t, err := ctx.Scan.PopValue("float") if err != nil { - return fmt.Errorf("expected float but got %q", value) + return err + } + switch v := t.Value.(type) { + case string: + n, err := strconv.ParseFloat(v, bits) + if err != nil { + return errors.Errorf("expected a float but got %q (%T)", t, t.Value) + } + target.SetFloat(n) + + case float32: + target.SetFloat(float64(v)) + + case float64: + target.SetFloat(v) + + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + target.Set(reflect.ValueOf(v)) + + default: + return errors.Errorf("expected an int but got %q (%T)", t, t.Value) } - target.SetFloat(n) return nil } } @@ -294,17 +358,43 @@ func mapDecoder(r *Registry) MapperFunc { el := target.Type() var childScanner *Scanner if ctx.Value.Flag != nil { + t := ctx.Scan.Pop() // If decoding a flag, we need an argument. - childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("map"), ';')...) + if t.IsEOL() { + return errors.Errorf("unexpected EOL") + } + switch v := t.Value.(type) { + case string: + childScanner = Scan(SplitEscaped(v, ';')...) + + case []map[string]interface{}: + for _, m := range v { + err := jsonTranscode(m, target.Addr().Interface()) + if err != nil { + return errors.WithStack(err) + } + } + return nil + + case map[string]interface{}: + return jsonTranscode(v, target.Addr().Interface()) + + default: + return errors.Errorf("invalid map value %q (of type %T)", t, t.Value) + } } else { tokens := ctx.Scan.PopWhile(func(t Token) bool { return t.IsValue() }) childScanner = ScanFromTokens(tokens...) } for !childScanner.Peek().IsEOL() { - token := childScanner.PopValue("map") + var token string + _, err := childScanner.PopValueInto("map", &token) + if err != nil { + return err + } parts := strings.SplitN(token, "=", 2) if len(parts) != 2 { - return fmt.Errorf("expected \"=\" but got %q", token) + return errors.Errorf("expected \"=\" but got %q", token) } key, value := parts[0], parts[1] @@ -312,7 +402,7 @@ func mapDecoder(r *Registry) MapperFunc { if typ := ctx.Value.Tag.Type; typ != "" { parts := strings.Split(typ, ":") if len(parts) != 2 { - return fmt.Errorf("type:\"\" on map field must be in the form \"[]:[]\"") + return errors.Errorf("type:\"\" on map field must be in the form \"[]:[]\"") } keyTypeName, valueTypeName = parts[0], parts[1] } @@ -321,14 +411,14 @@ func mapDecoder(r *Registry) MapperFunc { keyDecoder := r.ForNamedType(keyTypeName, el.Key()) keyValue := reflect.New(el.Key()).Elem() if err := keyDecoder.Decode(ctx.WithScanner(keyScanner), keyValue); err != nil { - return fmt.Errorf("invalid map key %q", key) + return errors.Errorf("invalid map key %q", key) } valueScanner := Scan(value) valueDecoder := r.ForNamedType(valueTypeName, el.Elem()) valueValue := reflect.New(el.Elem()).Elem() if err := valueDecoder.Decode(ctx.WithScanner(valueScanner), valueValue); err != nil { - return fmt.Errorf("invalid map value %q", value) + return errors.Errorf("invalid map value %q", value) } target.SetMapIndex(keyValue, valueValue) @@ -343,21 +433,35 @@ func sliceDecoder(r *Registry) MapperFunc { sep := ctx.Value.Tag.Sep var childScanner *Scanner if ctx.Value.Flag != nil { + t := ctx.Scan.Pop() // If decoding a flag, we need an argument. - childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("list"), sep)...) + if t.IsEOL() { + return errors.Errorf("unexpected EOL") + } + switch v := t.Value.(type) { + case string: + childScanner = Scan(SplitEscaped(v, sep)...) + + case []interface{}: + return jsonTranscode(v, target.Addr().Interface()) + + default: + v = []interface{}{v} + return jsonTranscode(v, target.Addr().Interface()) + } } else { tokens := ctx.Scan.PopWhile(func(t Token) bool { return t.IsValue() }) childScanner = ScanFromTokens(tokens...) } childDecoder := r.ForNamedType(ctx.Value.Tag.Type, el) if childDecoder == nil { - return fmt.Errorf("no mapper for element type of %s", target.Type()) + return errors.Errorf("no mapper for element type of %s", target.Type()) } for !childScanner.Peek().IsEOL() { childValue := reflect.New(el).Elem() err := childDecoder.Decode(ctx.WithScanner(childScanner), childValue) if err != nil { - return err + return errors.WithStack(err) } target.Set(reflect.Append(target, childValue)) } @@ -371,9 +475,13 @@ func pathMapper(r *Registry) MapperFunc { return sliceDecoder(r)(ctx, target) } if target.Kind() != reflect.String { - return fmt.Errorf("\"path\" type must be applied to a string not %s", target.Type()) + return errors.Errorf("\"path\" type must be applied to a string not %s", target.Type()) + } + var path string + _, err := ctx.Scan.PopValueInto("file", &path) + if err != nil { + return err } - path := ctx.Scan.PopValue("file") path = ExpandPath(path) target.SetString(path) return nil @@ -386,16 +494,20 @@ func existingFileMapper(r *Registry) MapperFunc { return sliceDecoder(r)(ctx, target) } if target.Kind() != reflect.String { - return fmt.Errorf("\"existingfile\" type must be applied to a string not %s", target.Type()) + return errors.Errorf("\"existingfile\" type must be applied to a string not %s", target.Type()) + } + var path string + _, err := ctx.Scan.PopValueInto("file", &path) + if err != nil { + return err } - path := ctx.Scan.PopValue("file") path = ExpandPath(path) stat, err := os.Stat(path) if err != nil { return err } if stat.IsDir() { - return fmt.Errorf("%q exists but is a directory", path) + return errors.Errorf("%q exists but is a directory", path) } target.SetString(path) return nil @@ -408,16 +520,20 @@ func existingDirMapper(r *Registry) MapperFunc { return sliceDecoder(r)(ctx, target) } if target.Kind() != reflect.String { - return fmt.Errorf("\"existingdir\" must be applied to a string not %s", target.Type()) + return errors.Errorf("\"existingdir\" must be applied to a string not %s", target.Type()) + } + var path string + _, err := ctx.Scan.PopValueInto("file", &path) + if err != nil { + return err } - path := ctx.Scan.PopValue("file") path = ExpandPath(path) stat, err := os.Stat(path) if err != nil { return err } if !stat.IsDir() { - return fmt.Errorf("%q exists but is not a directory", path) + return errors.Errorf("%q exists but is not a directory", path) } target.SetString(path) return nil @@ -426,10 +542,15 @@ func existingDirMapper(r *Registry) MapperFunc { func urlMapper() MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { - url, err := url.Parse(ctx.Scan.PopValue("url")) + var urlStr string + _, err := ctx.Scan.PopValueInto("url", &urlStr) if err != nil { return err } + url, err := url.Parse(urlStr) + if err != nil { + return errors.WithStack(err) + } target.Set(reflect.ValueOf(url)) return nil } @@ -479,11 +600,24 @@ func JoinEscaped(s []string, sep rune) string { type FileContentFlag []byte func (f *FileContentFlag) Decode(ctx *DecodeContext) error { // nolint: golint - filename := ExpandPath(ctx.Scan.PopValue("filename")) + var filename string + _, err := ctx.Scan.PopValueInto("filename", &filename) + if err != nil { + return err + } + filename = ExpandPath(filename) data, err := ioutil.ReadFile(filename) // nolint: gosec if err != nil { - return fmt.Errorf("failed to open %q: %s", filename, err) + return errors.Errorf("failed to open %q: %s", filename, err) } *f = data return nil } + +func jsonTranscode(in, out interface{}) error { + data, err := json.Marshal(in) + if err != nil { + return errors.WithStack(err) + } + return errors.Wrapf(json.Unmarshal(data, out), "%#v -> %T", in, out) +} diff --git a/mapper_test.go b/mapper_test.go index 296a133..a69f36b 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -140,8 +140,8 @@ type mappedValue struct { } func (m *mappedValue) Decode(ctx *kong.DecodeContext) error { - m.decoded = ctx.Scan.PopValue("mapped") - return nil + _, err := ctx.Scan.PopValueInto("mapped", &m.decoded) + return err } func TestMapperValue(t *testing.T) { diff --git a/resolver.go b/resolver.go index 353f63a..8ef764d 100644 --- a/resolver.go +++ b/resolver.go @@ -2,7 +2,6 @@ package kong import ( "encoding/json" - "fmt" "io" "strings" ) @@ -15,15 +14,15 @@ type Resolver interface { Validate(app *Application) error // Resolve the value for a Flag. - Resolve(context *Context, parent *Path, flag *Flag) (string, error) + Resolve(context *Context, parent *Path, flag *Flag) (interface{}, error) } // ResolverFunc is a convenience type for non-validating Resolvers. -type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error) +type ResolverFunc func(context *Context, parent *Path, flag *Flag) (interface{}, error) var _ Resolver = ResolverFunc(nil) -func (r ResolverFunc) Resolve(context *Context, parent *Path, flag *Flag) (string, error) { // nolint: golint +func (r ResolverFunc) Resolve(context *Context, parent *Path, flag *Flag) (interface{}, error) { // nolint: golint return r(context, parent, flag) } func (r ResolverFunc) Validate(app *Application) error { return nil } // nolint: golint @@ -37,54 +36,14 @@ func JSON(r io.Reader) (Resolver, error) { if err != nil { return nil, err } - var f ResolverFunc = func(context *Context, parent *Path, flag *Flag) (string, error) { + var f ResolverFunc = func(context *Context, parent *Path, flag *Flag) (interface{}, error) { name := strings.Replace(flag.Name, "-", "_", -1) raw, ok := values[name] if !ok { - return "", nil + return nil, nil } - sep := flag.Tag.Sep - value, err := jsonDecodeValue(sep, raw) - if err != nil { - return "", err - } - return value, nil + return raw, nil } return f, nil } - -func jsonDecodeValue(sep rune, value interface{}) (string, error) { - switch v := value.(type) { - case string: - return v, nil - case float64: - return fmt.Sprintf("%v", v), nil - case []interface{}: - out := []string{} - for _, el := range v { - sel, err := jsonDecodeValue(sep, el) - if err != nil { - return "", err - } - out = append(out, sel) - } - return JoinEscaped(out, sep), nil - case map[string]interface{}: - out := []string{} - for key, el := range v { - sel, err := jsonDecodeValue(sep, el) - if err != nil { - return "", err - } - out = append(out, fmt.Sprintf("%s=%s", key, sel)) - } - return JoinEscaped(out, ';'), nil - case bool: - if v { - return "true", nil - } - return "false", nil - } - return "", fmt.Errorf("unsupported JSON value %v (of type %T)", value, value) -} diff --git a/resolver_test.go b/resolver_test.go index 1a3f327..8e37aee 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -122,7 +122,11 @@ func TestJSONBasic(t *testing.T) { type testUppercaseMapper struct{} func (testUppercaseMapper) Decode(ctx *kong.DecodeContext, target reflect.Value) error { - value := ctx.Scan.PopValue("lowercase") + var value string + _, err := ctx.Scan.PopValueInto("lowercase", &value) + if err != nil { + return err + } target.SetString(strings.ToUpper(value)) return nil } @@ -148,11 +152,11 @@ func TestResolverWithBool(t *testing.T) { Bool bool } - var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { if flag.Name == "bool" { - return "true", nil + return true, nil } - return "", nil + return nil, nil } p := mustNew(t, &cli, kong.Resolvers(resolver)) @@ -167,18 +171,18 @@ func TestLastResolverWins(t *testing.T) { Int []int } - var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { if flag.Name == "int" { - return "1", nil + return 1, nil } - return "", nil + return nil, nil } - var second kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var second kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { if flag.Name == "int" { - return "2", nil + return 2, nil } - return "", nil + return nil, nil } p := mustNew(t, &cli, kong.Resolvers(first, second)) @@ -192,11 +196,11 @@ func TestResolverSatisfiesRequired(t *testing.T) { var cli struct { Int int `required` } - var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { if flag.Name == "int" { - return "1", nil + return 1, nil } - return "", nil + return nil, nil } _, err := mustNew(t, &cli, kong.Resolvers(resolver)).Parse(nil) require.NoError(t, err) @@ -210,18 +214,18 @@ func TestResolverTriggersHooks(t *testing.T) { Flag hookValue } - var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { if flag.Name == "flag" { - return "1", nil + return "one", nil } - return "", nil + return nil, nil } _, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolvers(first)).Parse(nil) require.NoError(t, err) - require.Equal(t, "1", string(cli.Flag)) - require.Equal(t, []string{"before:", "after:1"}, ctx.values) + require.Equal(t, "one", string(cli.Flag)) + require.Equal(t, []string{"before:", "after:one"}, ctx.values) } type validatingResolver struct { @@ -229,8 +233,8 @@ type validatingResolver struct { } func (v *validatingResolver) Validate(app *kong.Application) error { return v.err } -func (v *validatingResolver) Resolve(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { - return "", nil +func (v *validatingResolver) Resolve(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { + return nil, nil } func TestValidatingResolverErrors(t *testing.T) { diff --git a/scanner.go b/scanner.go index a503b41..b805736 100644 --- a/scanner.go +++ b/scanner.go @@ -1,8 +1,10 @@ package kong import ( - "strconv" + "fmt" "strings" + + "github.com/pkg/errors" ) // TokenType is the type of a token. @@ -41,23 +43,23 @@ func (t TokenType) String() string { // Token created by Scanner. type Token struct { - Value string + Value interface{} Type TokenType } func (t Token) String() string { switch t.Type { case FlagToken: - return "--" + t.Value + return fmt.Sprintf("--%v", t.Value) case ShortFlagToken: - return "-" + t.Value + return fmt.Sprintf("-%v", t.Value) case EOLToken: return "EOL" default: - return strconv.Quote(t.Value) + return fmt.Sprintf("%v", t.Value) } } @@ -67,9 +69,9 @@ func (t Token) IsEOL() bool { } // IsAny returns true if the token's type is any of those provided. -func (t Token) IsAny(types ...TokenType) bool { +func (t TokenType) IsAny(types ...TokenType) bool { for _, typ := range types { - if t.Type == typ { + if t == typ { return true } } @@ -79,10 +81,12 @@ func (t Token) IsAny(types ...TokenType) bool { // InferredType tries to infer the type of a token. func (t Token) InferredType() TokenType { if t.Type == UntypedToken { - if strings.HasPrefix(t.Value, "--") { - return FlagToken - } else if strings.HasPrefix(t.Value, "-") { - return ShortFlagToken + if v, ok := t.Value.(string); ok { + if strings.HasPrefix(v, "--") { + return FlagToken + } else if strings.HasPrefix(v, "-") { + return ShortFlagToken + } } } return t.Type @@ -92,8 +96,9 @@ func (t Token) InferredType() TokenType { // // A parseable value is either a value typed token, or an untyped token NOT starting with a hyphen. func (t Token) IsValue() bool { - return t.IsAny(FlagValueToken, ShortFlagTailToken, PositionalArgumentToken) || - (t.Type == UntypedToken && !strings.HasPrefix(t.Value, "-")) + tt := t.InferredType() + return tt.IsAny(FlagValueToken, ShortFlagTailToken, PositionalArgumentToken) || + (tt == UntypedToken && !strings.HasPrefix(t.String(), "-")) } // Scanner is a stack-based scanner over command-line tokens. @@ -137,15 +142,26 @@ func (s *Scanner) Pop() Token { return arg } -// PopValue token, or panic with Error. +// PopValue pops a value token, or returns an error. // // "context" is used to assist the user if the value can not be popped, eg. "expected value but got " -func (s *Scanner) PopValue(context string) string { +func (s *Scanner) PopValue(context string) (Token, error) { t := s.Pop() if !t.IsValue() { - fail("expected %s value but got %s (%s)", context, t, t.InferredType()) + return t, errors.Errorf("expected %s value but got %q (%s)", context, t, t.InferredType()) } - return t.Value + return t, nil +} + +// PopValueInto pops a value token into target or returns an error. +// +// "context" is used to assist the user if the value can not be popped, eg. "expected value but got " +func (s *Scanner) PopValueInto(context string, target interface{}) (Token, error) { + t := s.Pop() + if !t.IsValue() { + return t, errors.Errorf("expected %s value but got %q (%s)", context, t, t.InferredType()) + } + return t, jsonTranscode(t.Value, target) } // PopWhile predicate returns true. @@ -173,13 +189,13 @@ func (s *Scanner) Peek() Token { } // Push an untyped Token onto the front of the Scanner. -func (s *Scanner) Push(arg string) *Scanner { +func (s *Scanner) Push(arg interface{}) *Scanner { s.PushToken(Token{Value: arg}) return s } // PushTyped pushes a typed token onto the front of the Scanner. -func (s *Scanner) PushTyped(arg string, typ TokenType) *Scanner { +func (s *Scanner) PushTyped(arg interface{}, typ TokenType) *Scanner { s.PushToken(Token{Value: arg, Type: typ}) return s }