From e9d88d6528a48382ffe1e1a32dba29ad9a818552 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Tue, 12 Jun 2018 07:20:55 +1000 Subject: [PATCH] Implement flag "resolvers". (#24) * Propagate errors. * Use junit test output. * Expand role of DecodeContext to include Scanner. * Inject resolved flags as Path elements in the Context. This allows all existing logic to apply seamlessly: hooks, required flags, etc. * Clarify that hooks can be called multiple times. --- .circleci/config.yml | 16 ++- README.md | 8 +- context.go | 71 +++++++++++-- context_test.go | 1 - help_test.go | 1 + kong.go | 6 +- kong_test.go | 44 +++++++- mapper.go | 109 +++++++++++++++----- mapper_test.go | 13 ++- model.go | 8 +- options.go | 15 ++- resolver.go | 84 +++++++++++++++ resolver_test.go | 237 +++++++++++++++++++++++++++++++++++++++++++ scanner.go | 11 +- tag.go | 10 +- tag_test.go | 3 + 16 files changed, 579 insertions(+), 58 deletions(-) mode change 100644 => 100755 context.go delete mode 100644 context_test.go mode change 100644 => 100755 options.go create mode 100755 resolver.go create mode 100755 resolver_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 0813982..f8257b3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,5 +7,17 @@ jobs: working_directory: /go/src/github.com/alecthomas/kong steps: - checkout - - run: go get -v -t -d ./... - - run: go test -v ./... + - run: + name: Prepare + command: | + go get -v github.com/jstemmer/go-junit-report + go get -v -t -d ./... + mkdir ~/report + when: always + - run: + name: Test + command: | + go test -v ./... 2>&1 | tee report.txt && go-junit-report report.txt > ~/report/junit.xml + - store_test_results: + path: ~/report + diff --git a/README.md b/README.md index 7fac08f..82cc74d 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ eg. ``` $ shell --help -usage: shell [] +usage: shell A shell-like example app. @@ -70,10 +70,10 @@ Flags: --debug Debug mode. Commands: - rm [] ... + rm ... Remove files. - ls [] [ ...] + ls [ ...] List paths. ``` @@ -83,7 +83,7 @@ eg. ``` $ shell --help rm -usage: shell rm [] ... +usage: shell rm ... Remove files. diff --git a/context.go b/context.go old mode 100644 new mode 100755 index 0cd8cac..20ef8c5 --- a/context.go +++ b/context.go @@ -23,8 +23,12 @@ type Path struct { // Parsed value for non-commands. Value reflect.Value + + // True if this Path element was created as the result of a resolver. + Resolved bool } +// Context contains the current parse context. type Context struct { App *Kong Path []*Path // A trace through parsed nodes. @@ -64,9 +68,14 @@ func Trace(k *Kong, args []string) (*Context, error) { return nil, err } c.Error = c.trace(&c.App.Model.Node) + err = c.traceResolvers() + if err != nil { + return nil, err + } return c, nil } +// Validate the current context. func (c *Context) Validate() error { for _, path := range c.Path { if err := checkMissingFlags(path.Flags); err != nil { @@ -258,7 +267,6 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo Parent: node, Positional: arg, Value: value, - Flags: node.Flags, }) positional++ break @@ -272,7 +280,7 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo Parent: node, Command: branch, Value: branch.Target, - Flags: node.Flags, + Flags: branch.Flags, }) return c.trace(branch) } @@ -287,7 +295,7 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo Parent: node, Argument: branch, Value: value, - Flags: node.Flags, + Flags: branch.Flags, }) return c.trace(branch) } @@ -305,8 +313,10 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo // Apply traced context to the target grammar. func (c *Context) Apply() (string, error) { path := []string{} + for _, trace := range c.Path { switch { + case trace.App != nil: case trace.Argument != nil: path = append(path, "<"+trace.Argument.Name+">") trace.Argument.Argument.Apply(trace.Value) @@ -317,25 +327,64 @@ func (c *Context) Apply() (string, error) { case trace.Positional != nil: path = append(path, "<"+trace.Positional.Name+">") trace.Positional.Apply(trace.Value) + default: + panic("unsupported path ?!") } } + return strings.Join(path, " "), nil } +// Walk through flags from existing nodes in the path. +func (c *Context) traceResolvers() error { + if len(c.App.resolvers) == 0 { + return nil + } + + inserted := []*Path{} + for _, path := range c.Path { + for _, flag := range path.Flags { + for _, resolver := range c.App.resolvers { + s, err := resolver(c, path, flag) + if err != nil { + return err + } + if s == "" { + continue + } + + scan := Scan().PushTyped(s, FlagValueToken) + value, err := flag.Parse(scan) + if err != nil { + return err + } + inserted = append(inserted, &Path{ + Flag: flag, + Value: value, + Resolved: true, + }) + } + } + } + c.Path = append(inserted, c.Path...) + return nil +} + func (c *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) { defer catch(&err) token := c.scan.Peek() for _, flag := range flags { // Found a matching flag. - if matcher(flag) { - c.scan.Pop() - value, err := flag.Parse(c.scan) - if err != nil { - return err - } - c.Path = append(c.Path, &Path{Flag: flag, Value: value}) - return nil + if !matcher(flag) { + continue } + c.scan.Pop() + value, err := flag.Parse(c.scan) + if err != nil { + return err + } + c.Path = append(c.Path, &Path{Flag: flag, Value: value}) + return nil } return fmt.Errorf("unknown flag --%s", token.Value) } diff --git a/context_test.go b/context_test.go deleted file mode 100644 index 1af827e..0000000 --- a/context_test.go +++ /dev/null @@ -1 +0,0 @@ -package kong diff --git a/help_test.go b/help_test.go index 731bf98..b1dfffe 100644 --- a/help_test.go +++ b/help_test.go @@ -8,6 +8,7 @@ import ( ) func TestHelp(t *testing.T) { + // nolint: govet var cli struct { String string `help:"A string flag."` Bool bool `help:"A bool flag with very long help that wraps a lot and is verbose and is really verbose."` diff --git a/kong.go b/kong.go index f94ddcb..8db9636 100644 --- a/kong.go +++ b/kong.go @@ -17,6 +17,7 @@ func fail(format string, args ...interface{}) { panic(Error{fmt.Sprintf(format, args...)}) } +// Must creates a new Parser or panics if there is an error. func Must(ast interface{}, options ...Option) *Kong { k, err := New(ast, options...) if err != nil { @@ -37,6 +38,7 @@ type Kong struct { Stderr io.Writer before map[reflect.Value]HookFunc + resolvers []ResolverFunc registry *Registry noDefaultHelp bool help func(*Context) error @@ -105,7 +107,7 @@ func (k *Kong) extraFlags() []*Flag { return []*Flag{helpFlag} } -// Path parses the command-line, validating and collecting matching grammar nodes. +// Trace parses the command-line, validating and collecting matching grammar nodes. func (k *Kong) Trace(args []string) (*Context, error) { return Trace(k, args) } @@ -171,7 +173,7 @@ func (k *Kong) Errorf(format string, args ...interface{}) { fmt.Fprintf(k.Stderr, k.Model.Name+": error: "+format, args...) } -// FatalIfError terminates with an error message if err != nil. +// FatalIfErrorf terminates with an error message if err != nil. func (k *Kong) FatalIfErrorf(err error, args ...interface{}) { if err == nil { return diff --git a/kong_test.go b/kong_test.go index 460bbd0..62fbd2d 100644 --- a/kong_test.go +++ b/kong_test.go @@ -101,9 +101,19 @@ func TestFlagSlice(t *testing.T) { require.Equal(t, []int{1, 2, 3}, cli.Slice) } +func TestFlagSliceWithSeparator(t *testing.T) { + var cli struct { + Slice []string + } + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{`--slice=a\,b,c`}) + require.NoError(t, err) + require.Equal(t, []string{"a,b", "c"}, cli.Slice) +} + func TestArgSlice(t *testing.T) { var cli struct { - Slice []int `kong:"arg"` + Slice []int `arg` Flag bool } parser := mustNew(t, &cli) @@ -113,6 +123,18 @@ func TestArgSlice(t *testing.T) { require.Equal(t, true, cli.Flag) } +func TestArgSliceWithSeparator(t *testing.T) { + var cli struct { + Slice []string `arg` + Flag bool + } + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{"a,b", "c", "--flag"}) + require.NoError(t, err) + require.Equal(t, []string{"a,b", "c"}, cli.Slice) + require.Equal(t, true, cli.Flag) +} + func TestUnsupportedFieldErrors(t *testing.T) { var cli struct { Keys map[string]string @@ -356,3 +378,23 @@ func TestShort(t *testing.T) { require.True(t, cli.Bool) require.Equal(t, "hello", cli.String) } + +func TestDuplicateFlagChoosesLast(t *testing.T) { + var cli struct { + Flag int + } + + _, err := mustNew(t, &cli).Parse([]string{"--flag=1", "--flag=2"}) + require.NoError(t, err) + require.Equal(t, 2, cli.Flag) +} + +func TestDuplicateSliceDoesNotAccumulate(t *testing.T) { + var cli struct { + Flag []int + } + + _, err := mustNew(t, &cli).Parse([]string{"--flag=1,2", "--flag=3,4"}) + require.NoError(t, err) + require.Equal(t, []int{3, 4}, cli.Flag) +} diff --git a/mapper.go b/mapper.go index 1702030..9b7c261 100644 --- a/mapper.go +++ b/mapper.go @@ -9,16 +9,30 @@ import ( "time" ) -type DecoderContext struct { +// DecodeContext is passed to a Mapper's Decode(). +// +// It contains the Value being decoded into and the Scanner to parse from. +type DecodeContext struct { // Value being decoded into. Value *Value + // Scan contains the input to scan into Target. + Scan *Scanner +} + +// WithScanner creates a clone of this context with a new Scanner. +func (d *DecodeContext) WithScanner(scan *Scanner) *DecodeContext { + return &DecodeContext{ + Value: d.Value, + Scan: scan, + } } // A Mapper represents how a field is mapped from command-line values to Go. // // Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag. type Mapper interface { - Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error + // Decode ctx.Value with ctx.Scanner into target. + Decode(ctx *DecodeContext, target reflect.Value) error } // A BoolMapper is a Mapper to a value that is a boolean. @@ -27,13 +41,14 @@ type BoolMapper interface { IsBool() bool } -type MapperFunc func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error +// A MapperFunc is a single function that complies with the Mapper interface. +type MapperFunc func(ctx *DecodeContext, target reflect.Value) error -func (d MapperFunc) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - return d(ctx, scan, target) +func (d MapperFunc) Decode(ctx *DecodeContext, target reflect.Value) error { //nolint: golint + return d(ctx, target) } -// A Registry encapsulates a set of fields and lookups to resolve them. +// A Registry contains a set of mappers and supporting lookup methods. type Registry struct { names map[string]Mapper types map[reflect.Type]Mapper @@ -41,6 +56,7 @@ type Registry struct { values map[reflect.Value]Mapper } +// NewRegistry creates a new (empty) Registry. func NewRegistry() *Registry { return &Registry{ names: map[string]Mapper{}, @@ -60,6 +76,7 @@ func (d *Registry) ForNamedType(name string, value reflect.Value) Mapper { return d.ForValue(value) } +// ForValue looks up the Mapper for a reflect.Value. func (d *Registry) ForValue(value reflect.Value) Mapper { if mapper, ok := d.values[value]; ok { return mapper @@ -67,7 +84,7 @@ func (d *Registry) ForValue(value reflect.Value) Mapper { return d.ForType(value.Type()) } -// DecoderForType finds a mapper from a type or kind. +// ForType finds a mapper from a type, by type, then kind. // // Will return nil if a mapper can not be determined. func (d *Registry) ForType(typ reflect.Type) Mapper { @@ -81,6 +98,7 @@ func (d *Registry) ForType(typ reflect.Type) Mapper { return nil } +// RegisterKind registers a Mapper for a reflect.Kind. func (d *Registry) RegisterKind(kind reflect.Kind, mapper Mapper) *Registry { d.kinds[kind] = mapper return d @@ -97,12 +115,13 @@ func (d *Registry) RegisterName(name string, mapper Mapper) *Registry { return d } +// RegisterType registers a Mapper for a reflect.Type. func (d *Registry) RegisterType(typ reflect.Type, mapper Mapper) *Registry { d.types[typ] = mapper return d } -// RegisterValue registers a mapper by a pointer to the mapper value. +// RegisterValue registers a Mapper by pointer to the field value. func (d *Registry) RegisterValue(ptr interface{}, mapper Mapper) *Registry { key := reflect.ValueOf(ptr) if key.Kind() != reflect.Ptr { @@ -113,6 +132,7 @@ func (d *Registry) RegisterValue(ptr interface{}, mapper Mapper) *Registry { return d } +// RegisterDefaults registers Mappers for all builtin supported Go types and some common stdlib types. func (d *Registry) RegisterDefaults() *Registry { return d.RegisterKind(reflect.Int, intDecoder(bits.UintSize)). RegisterKind(reflect.Int8, intDecoder(8)). @@ -126,8 +146,8 @@ func (d *Registry) RegisterDefaults() *Registry { RegisterKind(reflect.Uint64, uintDecoder(64)). RegisterKind(reflect.Float32, floatDecoder(32)). RegisterKind(reflect.Float64, floatDecoder(64)). - RegisterKind(reflect.String, MapperFunc(func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - target.SetString(scan.PopValue("string")) + RegisterKind(reflect.String, MapperFunc(func(ctx *DecodeContext, target reflect.Value) error { + target.SetString(ctx.Scan.PopValue("string")) return nil })). RegisterKind(reflect.Bool, boolMapper{}). @@ -138,15 +158,15 @@ func (d *Registry) RegisterDefaults() *Registry { type boolMapper struct{} -func (boolMapper) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { +func (boolMapper) Decode(ctx *DecodeContext, target reflect.Value) error { target.SetBool(true) return nil } func (boolMapper) IsBool() bool { return true } func durationDecoder() MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - d, err := time.ParseDuration(scan.PopValue("duration")) + return func(ctx *DecodeContext, target reflect.Value) error { + d, err := time.ParseDuration(ctx.Scan.PopValue("duration")) if err != nil { return err } @@ -156,12 +176,12 @@ func durationDecoder() MapperFunc { } func timeDecoder() MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + return func(ctx *DecodeContext, target reflect.Value) error { fmt := time.RFC3339 if ctx.Value.Format != "" { fmt = ctx.Value.Format } - t, err := time.Parse(fmt, scan.PopValue("time")) + t, err := time.Parse(fmt, ctx.Scan.PopValue("time")) if err != nil { return err } @@ -171,8 +191,8 @@ func timeDecoder() MapperFunc { } func intDecoder(bits int) MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - value := scan.PopValue("int") + return func(ctx *DecodeContext, target reflect.Value) error { + value := ctx.Scan.PopValue("int") n, err := strconv.ParseInt(value, 10, bits) if err != nil { return fmt.Errorf("invalid int %q", value) @@ -183,8 +203,8 @@ func intDecoder(bits int) MapperFunc { } func uintDecoder(bits int) MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - value := scan.PopValue("uint") + return func(ctx *DecodeContext, target reflect.Value) error { + value := ctx.Scan.PopValue("uint") n, err := strconv.ParseUint(value, 10, bits) if err != nil { return fmt.Errorf("invalid uint %q", value) @@ -195,8 +215,8 @@ func uintDecoder(bits int) MapperFunc { } func floatDecoder(bits int) MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { - value := scan.PopValue("float") + return func(ctx *DecodeContext, target reflect.Value) error { + value := ctx.Scan.PopValue("float") n, err := strconv.ParseFloat(value, bits) if err != nil { return fmt.Errorf("invalid float %q", value) @@ -207,15 +227,15 @@ func floatDecoder(bits int) MapperFunc { } func sliceDecoder(d *Registry) MapperFunc { - return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { + return func(ctx *DecodeContext, target reflect.Value) error { el := target.Type().Elem() sep := ctx.Value.Tag.Sep var childScanner *Scanner if ctx.Value.Flag != nil { // If decoding a flag, we need an argument. - childScanner = Scan(strings.Split(scan.PopValue("list"), sep)...) + childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("list"), sep)...) } else { - tokens := scan.PopUntil(func(t Token) bool { return !t.IsValue() }) + tokens := ctx.Scan.PopUntil(func(t Token) bool { return !t.IsValue() }) childScanner = Scan(tokens...) } childDecoder := d.ForType(el) @@ -224,7 +244,7 @@ func sliceDecoder(d *Registry) MapperFunc { } for childScanner.Peek().Type != EOLToken { childValue := reflect.New(el).Elem() - err := childDecoder.Decode(ctx, childScanner, childValue) + err := childDecoder.Decode(ctx.WithScanner(childScanner), childValue) if err != nil { return err } @@ -233,3 +253,42 @@ func sliceDecoder(d *Registry) MapperFunc { return nil } } + +// SplitEscaped splits a string on a separator. +// +// It differs from strings.Split() in that the separator can exist in a field by escaping it with a \. eg. +// +// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"} +func SplitEscaped(s string, sep rune) (out []string) { + escaped := false + token := "" + for _, ch := range s { + if escaped { + token += string(ch) + escaped = false + } else if ch == '\\' { + escaped = true + } else if ch == sep && !escaped { + out = append(out, token) + token = "" + escaped = false + } else { + token += string(ch) + } + } + if token != "" { + out = append(out, token) + } + return +} + +// JoinEscaped joins a slice of strings on sep, but also escapes any instances of sep in the fields with \. eg. +// +// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob` +func JoinEscaped(s []string, sep rune) string { + escaped := []string{} + for _, e := range s { + escaped = append(escaped, strings.Replace(e, string(sep), `\`+string(sep), -1)) + } + return strings.Join(escaped, string(sep)) +} diff --git a/mapper_test.go b/mapper_test.go index e635167..44008f5 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -36,7 +36,7 @@ func TestNamedMapper(t *testing.T) { type testMooMapper struct{} -func (testMooMapper) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { +func (testMooMapper) Decode(ctx *DecodeContext, target reflect.Value) error { target.SetString("MOO") return nil } @@ -64,3 +64,14 @@ func TestDurationMapper(t *testing.T) { require.NoError(t, err) require.Equal(t, time.Second*5, cli.Flag) } + +func TestSplitEscaped(t *testing.T) { + require.Equal(t, []string{"a", "b"}, SplitEscaped("a,b", ',')) + require.Equal(t, []string{"a,b", "c"}, SplitEscaped(`a\,b,c`, ',')) +} + +func TestJoinEscaped(t *testing.T) { + require.Equal(t, `a,b`, JoinEscaped([]string{"a", "b"}, ',')) + require.Equal(t, `a\,b,c`, JoinEscaped([]string{`a,b`, `c`}, ',')) + require.Equal(t, JoinEscaped(SplitEscaped(`a\,b,c`, ','), ','), `a\,b,c`) +} diff --git a/model.go b/model.go index 70f446b..809d1d5 100644 --- a/model.go +++ b/model.go @@ -139,6 +139,7 @@ type Value struct { Position int // Position (for positional arguments). } +// Summary returns a human-readable summary of the value. func (v *Value) Summary() string { if v.Flag != nil { if v.IsBool() { @@ -156,10 +157,12 @@ func (v *Value) Summary() string { return argText } +// IsCumulative returns true of the value is a slice. func (v *Value) IsCumulative() bool { return v.Value.Kind() == reflect.Slice } +// IsBool returns true if the underlying value is a boolean. func (v *Value) IsBool() bool { if m, ok := v.Mapper.(BoolMapper); ok && m.IsBool() { return true @@ -170,7 +173,7 @@ func (v *Value) IsBool() bool { // Parse tokens into value, parse, and validate, but do not write to the field. func (v *Value) Parse(scan *Scanner) (reflect.Value, error) { value := reflect.New(v.Value.Type()).Elem() - err := v.Mapper.Decode(&DecoderContext{Value: v}, scan, value) + err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, value) if err == nil { v.Set = true } @@ -196,8 +199,10 @@ func (v *Value) Reset() error { return nil } +// A Positional represents a non-branching command-line positional argument. type Positional = Value +// A Flag represents a command-line flag. type Flag struct { *Value PlaceHolder string @@ -217,6 +222,7 @@ func (f *Flag) String() string { return out } +// FormatPlaceHolder formats the placeholder string for a Flag. func (f *Flag) FormatPlaceHolder() string { tail := "" if f.Value.IsCumulative() { diff --git a/options.go b/options.go old mode 100644 new mode 100755 index c685981..7af0c54 --- a/options.go +++ b/options.go @@ -69,9 +69,12 @@ func Writers(stdout, stderr io.Writer) Option { // HookFunc is a callback tied to a field of the grammar, called before a value is applied. type HookFunc func(ctx *Context, path *Path) error -// Hook to aply before a command, flag or positional argument is encountered. +// Hook to apply before a command, flag or positional argument is encountered. // // "ptr" is a pointer to a field of the grammar. +// +// Note that the hook will be called once for each time the corresponding node is encountered. This means that if a flag +// is passed twice, its hook will be called twice. func Hook(ptr interface{}, hook HookFunc) Option { key := reflect.ValueOf(ptr) if key.Kind() != reflect.Ptr { @@ -82,13 +85,21 @@ func Hook(ptr interface{}, hook HookFunc) Option { } } +// HelpFunction is the type of a function used to display help. type HelpFunction func(*Context) error // Help function to use. // // Defaults to PrintHelp. -func Help(help func(*Context) error) Option { +func Help(help HelpFunction) Option { return func(k *Kong) { k.help = help } } + +// Resolver registers flag resolvers. +func Resolver(resolvers ...ResolverFunc) Option { + return func(k *Kong) { + k.resolvers = append(k.resolvers, resolvers...) + } +} diff --git a/resolver.go b/resolver.go new file mode 100755 index 0000000..833b94a --- /dev/null +++ b/resolver.go @@ -0,0 +1,84 @@ +package kong + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" +) + +// ResolverFunc resolves a Flag value from an external source. +type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error) + +// JSONResolver returns a Resolver that retrieves values from a JSON source. +// +// Hyphens in flag names are replaced with underscores. +func JSONResolver(r io.Reader) (ResolverFunc, error) { + values := map[string]interface{}{} + err := json.NewDecoder(r).Decode(&values) + if err != nil { + return nil, err + } + f := func(context *Context, parent *Path, flag *Flag) (string, error) { + name := strings.Replace(flag.Name, "-", "_", -1) + raw, ok := values[name] + if !ok { + return "", nil + } + value, err := jsonDecodeValue(flag.Tag.Sep, raw) + if err != nil { + return "", err + } + return value, nil + } + + return f, nil +} + +func jsonDecodeValue(sep rune, value interface{}) (string, error) { + switch v := value.(type) { + case string: + return v, nil + case float64: + return fmt.Sprintf("%v", v), nil + case []interface{}: + out := []string{} + for _, el := range v { + sel, err := jsonDecodeValue(sep, el) + if err != nil { + return "", err + } + out = append(out, sel) + } + return JoinEscaped(out, sep), nil + case bool: + if v { + return "true", nil + } + return "false", nil + } + return "", fmt.Errorf("unsupported JSON value %v (of type %T)", value, value) +} + +// PerFlagEnvResolver automatically determines environment variables based on the name of each flag, transformed to +// uppercase and underscored, e.g. `my-flag` -> `MY_FLAG` The environment variable key can be overridden with the `env` +// tag. +func PerFlagEnvResolver(prefix string) ResolverFunc { + return func(context *Context, parent *Path, flag *Flag) (string, error) { + v, _ := os.LookupEnv(envString(prefix, flag)) + return v, nil + } +} + +func envString(prefix string, flag *Flag) string { + if env, ok := flag.Tag.Get("env"); ok { + return env + } + + env := strings.ToUpper(flag.Name) + env = strings.Replace(env, "-", "_", -1) + env = prefix + env + + return env +} diff --git a/resolver_test.go b/resolver_test.go new file mode 100755 index 0000000..721cd4e --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,237 @@ +package kong + +import ( + "os" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type envMap map[string]string + +func tempEnv(env envMap) func() { + for k, v := range env { + os.Setenv(k, v) + } + + return func() { + for k := range env { + os.Unsetenv(k) + } + } +} + +func newEnvParser(t *testing.T, cli interface{}, env envMap) (*Kong, func()) { + t.Helper() + restoreEnv := tempEnv(env) + parser := mustNew(t, cli, Resolver(PerFlagEnvResolver("KONG_"))) + return parser, restoreEnv +} + +func TestEnvResolverFlagBasic(t *testing.T) { + var cli struct { + String string + Slice []int + } + parser, unsetEnvs := newEnvParser(t, &cli, envMap{ + "KONG_STRING": "bye", + "KONG_SLICE": "5,2,9", + }) + defer unsetEnvs() + + _, err := parser.Parse([]string{}) + require.NoError(t, err) + require.Equal(t, "bye", cli.String) + require.Equal(t, []int{5, 2, 9}, cli.Slice) +} + +func TestEnvResolverFlagOverride(t *testing.T) { + var cli struct { + Flag string + } + parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_FLAG": "bye"}) + defer restoreEnv() + + _, err := parser.Parse([]string{"--flag=hello"}) + require.NoError(t, err) + require.Equal(t, "hello", cli.Flag) +} + +func TestEnvResolverOnlyPopulateUsedBranches(t *testing.T) { + // nolint + var cli struct { + UnvisitedArg struct { + UnvisitedArg string `arg` + Int int + } `arg` + UnvisitedCmd struct { + Int int + } `cmd` + Visited struct { + Int int + } `cmd` + } + parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_INT": "512"}) + defer restoreEnv() + + _, err := parser.Parse([]string{"visited"}) + require.NoError(t, err) + + require.Equal(t, 512, cli.Visited.Int) + require.Equal(t, 0, cli.UnvisitedArg.Int) + require.Equal(t, 0, cli.UnvisitedCmd.Int) +} + +func TestEnvResolverTag(t *testing.T) { + var cli struct { + Slice []int `env:"KONG_NUMBERS"` + } + parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_NUMBERS": "5,2,9"}) + defer restoreEnv() + + _, err := parser.Parse([]string{}) + require.NoError(t, err) + require.Equal(t, []int{5, 2, 9}, cli.Slice) +} + +func TestJSONResolverBasic(t *testing.T) { + var cli struct { + String string + Slice []int + SliceWithCommas []string + Bool bool + } + + json := `{ + "string": "🍕", + "slice": [5, 8], + "bool": true, + "slice_with_commas": ["a,b", "c"] + }` + + r, err := JSONResolver(strings.NewReader(json)) + require.NoError(t, err) + + parser := mustNew(t, &cli, Resolver(r)) + _, err = parser.Parse([]string{}) + require.NoError(t, err) + require.Equal(t, "🍕", cli.String) + require.Equal(t, []int{5, 8}, cli.Slice) + require.Equal(t, []string{"a,b", "c"}, cli.SliceWithCommas) + require.True(t, cli.Bool) +} + +func TestResolvedValueTriggersHooks(t *testing.T) { + var cli struct { + Int int + } + resolver := func(context *Context, parent *Path, flag *Flag) (string, error) { + if flag.Name == "int" { + return "1", nil + } + return "", nil + } + hooked := 0 + p := mustNew(t, &cli, Resolver(resolver), Hook(&cli.Int, func(ctx *Context, path *Path) error { + hooked++ + return nil + })) + _, err := p.Parse(nil) + require.NoError(t, err) + require.Equal(t, 1, cli.Int) + require.Equal(t, 1, hooked) + + hooked = 0 + _, err = p.Parse([]string{"--int=2"}) + require.NoError(t, err) + require.Equal(t, 2, cli.Int) + require.Equal(t, 2, hooked) +} + +type testUppercaseMapper struct{} + +func (testUppercaseMapper) Decode(ctx *DecodeContext, target reflect.Value) error { + value := ctx.Scan.PopValue("lowercase") + target.SetString(strings.ToUpper(value)) + return nil +} + +func TestResolversWithMappers(t *testing.T) { + var cli struct { + Flag string `env:"KONG_MOO" type:"upper"` + } + + restoreEnv := tempEnv(envMap{"KONG_MOO": "meow"}) + defer restoreEnv() + + r := PerFlagEnvResolver("KONG_") + + parser := mustNew(t, &cli, + NamedMapper("upper", testUppercaseMapper{}), + Resolver(r), + ) + _, err := parser.Parse([]string{}) + require.NoError(t, err) + require.Equal(t, "MEOW", cli.Flag) +} + +func TestResolverWithBool(t *testing.T) { + var cli struct { + Bool bool + } + + resolver := func(context *Context, parent *Path, flag *Flag) (string, error) { + if flag.Name == "bool" { + return "true", nil + } + return "", nil + } + + p := mustNew(t, &cli, Resolver(resolver)) + + _, err := p.Parse(nil) + require.NoError(t, err) + require.True(t, cli.Bool) +} + +func TestLastResolverWins(t *testing.T) { + var cli struct { + Int []int + } + + var first ResolverFunc = func(context *Context, parent *Path, flag *Flag) (string, error) { + if flag.Name == "int" { + return "1", nil + } + return "", nil + } + + var second ResolverFunc = func(context *Context, parent *Path, flag *Flag) (string, error) { + if flag.Name == "int" { + return "2", nil + } + return "", nil + } + + p := mustNew(t, &cli, Resolver(first), Resolver(second)) + _, err := p.Parse(nil) + require.NoError(t, err) + require.Equal(t, []int{2}, cli.Int) +} + +func TestResolverSatisfiesRequired(t *testing.T) { + var cli struct { + Int int `required` + } + resolver := func(context *Context, parent *Path, flag *Flag) (string, error) { + if flag.Name == "int" { + return "1", nil + } + return "", nil + } + _, err := mustNew(t, &cli, Resolver(resolver)).Parse(nil) + require.NoError(t, err) + require.Equal(t, 1, cli.Int) +} diff --git a/scanner.go b/scanner.go index 9564787..583531b 100644 --- a/scanner.go +++ b/scanner.go @@ -7,8 +7,10 @@ import ( //go:generate stringer -type=TokenType +// TokenType is the type of a token. type TokenType int +// Token types. const ( UntypedToken TokenType = iota EOLToken @@ -128,14 +130,17 @@ func (s *Scanner) Peek() Token { return s.args[0] } -func (s *Scanner) Push(arg string) { +func (s *Scanner) Push(arg string) *Scanner { s.PushToken(Token{Value: arg}) + return s } -func (s *Scanner) PushTyped(arg string, typ TokenType) { +func (s *Scanner) PushTyped(arg string, typ TokenType) *Scanner { s.PushToken(Token{Value: arg, Type: typ}) + return s } -func (s *Scanner) PushToken(token Token) { +func (s *Scanner) PushToken(token Token) *Scanner { s.args = append([]Token{token}, s.args...) + return s } diff --git a/tag.go b/tag.go index 407f924..fd445c0 100644 --- a/tag.go +++ b/tag.go @@ -21,7 +21,7 @@ type Tag struct { Env string Short rune Hidden bool - Sep string + Sep rune // Storage for all tag keys for arbitrary lookups. items map[string]string @@ -128,12 +128,12 @@ func parseTag(fv reflect.Value, ft reflect.StructField) *Tag { t.Short, _ = t.GetRune("short") t.Hidden = t.Has("hidden") t.Format, _ = t.Get("format") - t.Sep, _ = t.Get("sep") - if t.Sep == "" { + t.Sep, _ = t.GetRune("sep") + if t.Sep == 0 { if t.Cmd || t.Arg { - t.Sep = " " + t.Sep = ' ' } else { - t.Sep = "," + t.Sep = ',' } } diff --git a/tag_test.go b/tag_test.go index 167aa71..49ce4da 100644 --- a/tag_test.go +++ b/tag_test.go @@ -65,6 +65,7 @@ func TestEscapedQuote(t *testing.T) { } func TestBareTags(t *testing.T) { + // nolint: govet var cli struct { Cmd struct { Arg string `arg` @@ -80,6 +81,7 @@ func TestBareTags(t *testing.T) { } func TestBareTagsWithJsonTag(t *testing.T) { + // nolint: govet var cli struct { Cmd struct { Arg string `json:"-" optional arg` @@ -95,6 +97,7 @@ func TestBareTagsWithJsonTag(t *testing.T) { } func TestManySeps(t *testing.T) { + // nolint: govet var cli struct { Arg string `arg optional default:"hi"` }