Implement flag "resolvers". (#24)

* Propagate errors.
* Use junit test output.
* Expand role of DecodeContext to include Scanner.
* Inject resolved flags as Path elements in the Context.
  This allows all existing logic to apply seamlessly: hooks, required
flags, etc.
* Clarify that hooks can be called multiple times.
This commit is contained in:
Alec Thomas
2018-06-12 07:20:55 +10:00
committed by Gerald Kaszuba
parent 73064a687f
commit e9d88d6528
16 changed files with 579 additions and 58 deletions
+14 -2
View File
@@ -7,5 +7,17 @@ jobs:
working_directory: /go/src/github.com/alecthomas/kong
steps:
- checkout
- run: go get -v -t -d ./...
- run: go test -v ./...
- run:
name: Prepare
command: |
go get -v github.com/jstemmer/go-junit-report
go get -v -t -d ./...
mkdir ~/report
when: always
- run:
name: Test
command: |
go test -v ./... 2>&1 | tee report.txt && go-junit-report report.txt > ~/report/junit.xml
- store_test_results:
path: ~/report
+4 -4
View File
@@ -61,7 +61,7 @@ eg.
```
$ shell --help
usage: shell [<flags>]
usage: shell <command>
A shell-like example app.
@@ -70,10 +70,10 @@ Flags:
--debug Debug mode.
Commands:
rm [<flags>] <paths> ...
rm <paths> ...
Remove files.
ls [<flags>] [<paths> ...]
ls [<paths> ...]
List paths.
```
@@ -83,7 +83,7 @@ eg.
```
$ shell --help rm
usage: shell rm [<flags>] <paths> ...
usage: shell rm <paths> ...
Remove files.
Regular → Executable
+60 -11
View File
@@ -23,8 +23,12 @@ type Path struct {
// Parsed value for non-commands.
Value reflect.Value
// True if this Path element was created as the result of a resolver.
Resolved bool
}
// Context contains the current parse context.
type Context struct {
App *Kong
Path []*Path // A trace through parsed nodes.
@@ -64,9 +68,14 @@ func Trace(k *Kong, args []string) (*Context, error) {
return nil, err
}
c.Error = c.trace(&c.App.Model.Node)
err = c.traceResolvers()
if err != nil {
return nil, err
}
return c, nil
}
// Validate the current context.
func (c *Context) Validate() error {
for _, path := range c.Path {
if err := checkMissingFlags(path.Flags); err != nil {
@@ -258,7 +267,6 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
Parent: node,
Positional: arg,
Value: value,
Flags: node.Flags,
})
positional++
break
@@ -272,7 +280,7 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
Parent: node,
Command: branch,
Value: branch.Target,
Flags: node.Flags,
Flags: branch.Flags,
})
return c.trace(branch)
}
@@ -287,7 +295,7 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
Parent: node,
Argument: branch,
Value: value,
Flags: node.Flags,
Flags: branch.Flags,
})
return c.trace(branch)
}
@@ -305,8 +313,10 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
// Apply traced context to the target grammar.
func (c *Context) Apply() (string, error) {
path := []string{}
for _, trace := range c.Path {
switch {
case trace.App != nil:
case trace.Argument != nil:
path = append(path, "<"+trace.Argument.Name+">")
trace.Argument.Argument.Apply(trace.Value)
@@ -317,25 +327,64 @@ func (c *Context) Apply() (string, error) {
case trace.Positional != nil:
path = append(path, "<"+trace.Positional.Name+">")
trace.Positional.Apply(trace.Value)
default:
panic("unsupported path ?!")
}
}
return strings.Join(path, " "), nil
}
// Walk through flags from existing nodes in the path.
func (c *Context) traceResolvers() error {
if len(c.App.resolvers) == 0 {
return nil
}
inserted := []*Path{}
for _, path := range c.Path {
for _, flag := range path.Flags {
for _, resolver := range c.App.resolvers {
s, err := resolver(c, path, flag)
if err != nil {
return err
}
if s == "" {
continue
}
scan := Scan().PushTyped(s, FlagValueToken)
value, err := flag.Parse(scan)
if err != nil {
return err
}
inserted = append(inserted, &Path{
Flag: flag,
Value: value,
Resolved: true,
})
}
}
}
c.Path = append(inserted, c.Path...)
return nil
}
func (c *Context) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) {
defer catch(&err)
token := c.scan.Peek()
for _, flag := range flags {
// Found a matching flag.
if matcher(flag) {
c.scan.Pop()
value, err := flag.Parse(c.scan)
if err != nil {
return err
}
c.Path = append(c.Path, &Path{Flag: flag, Value: value})
return nil
if !matcher(flag) {
continue
}
c.scan.Pop()
value, err := flag.Parse(c.scan)
if err != nil {
return err
}
c.Path = append(c.Path, &Path{Flag: flag, Value: value})
return nil
}
return fmt.Errorf("unknown flag --%s", token.Value)
}
-1
View File
@@ -1 +0,0 @@
package kong
+1
View File
@@ -8,6 +8,7 @@ import (
)
func TestHelp(t *testing.T) {
// nolint: govet
var cli struct {
String string `help:"A string flag."`
Bool bool `help:"A bool flag with very long help that wraps a lot and is verbose and is really verbose."`
+4 -2
View File
@@ -17,6 +17,7 @@ func fail(format string, args ...interface{}) {
panic(Error{fmt.Sprintf(format, args...)})
}
// Must creates a new Parser or panics if there is an error.
func Must(ast interface{}, options ...Option) *Kong {
k, err := New(ast, options...)
if err != nil {
@@ -37,6 +38,7 @@ type Kong struct {
Stderr io.Writer
before map[reflect.Value]HookFunc
resolvers []ResolverFunc
registry *Registry
noDefaultHelp bool
help func(*Context) error
@@ -105,7 +107,7 @@ func (k *Kong) extraFlags() []*Flag {
return []*Flag{helpFlag}
}
// Path parses the command-line, validating and collecting matching grammar nodes.
// Trace parses the command-line, validating and collecting matching grammar nodes.
func (k *Kong) Trace(args []string) (*Context, error) {
return Trace(k, args)
}
@@ -171,7 +173,7 @@ func (k *Kong) Errorf(format string, args ...interface{}) {
fmt.Fprintf(k.Stderr, k.Model.Name+": error: "+format, args...)
}
// FatalIfError terminates with an error message if err != nil.
// FatalIfErrorf terminates with an error message if err != nil.
func (k *Kong) FatalIfErrorf(err error, args ...interface{}) {
if err == nil {
return
+43 -1
View File
@@ -101,9 +101,19 @@ func TestFlagSlice(t *testing.T) {
require.Equal(t, []int{1, 2, 3}, cli.Slice)
}
func TestFlagSliceWithSeparator(t *testing.T) {
var cli struct {
Slice []string
}
parser := mustNew(t, &cli)
_, err := parser.Parse([]string{`--slice=a\,b,c`})
require.NoError(t, err)
require.Equal(t, []string{"a,b", "c"}, cli.Slice)
}
func TestArgSlice(t *testing.T) {
var cli struct {
Slice []int `kong:"arg"`
Slice []int `arg`
Flag bool
}
parser := mustNew(t, &cli)
@@ -113,6 +123,18 @@ func TestArgSlice(t *testing.T) {
require.Equal(t, true, cli.Flag)
}
func TestArgSliceWithSeparator(t *testing.T) {
var cli struct {
Slice []string `arg`
Flag bool
}
parser := mustNew(t, &cli)
_, err := parser.Parse([]string{"a,b", "c", "--flag"})
require.NoError(t, err)
require.Equal(t, []string{"a,b", "c"}, cli.Slice)
require.Equal(t, true, cli.Flag)
}
func TestUnsupportedFieldErrors(t *testing.T) {
var cli struct {
Keys map[string]string
@@ -356,3 +378,23 @@ func TestShort(t *testing.T) {
require.True(t, cli.Bool)
require.Equal(t, "hello", cli.String)
}
func TestDuplicateFlagChoosesLast(t *testing.T) {
var cli struct {
Flag int
}
_, err := mustNew(t, &cli).Parse([]string{"--flag=1", "--flag=2"})
require.NoError(t, err)
require.Equal(t, 2, cli.Flag)
}
func TestDuplicateSliceDoesNotAccumulate(t *testing.T) {
var cli struct {
Flag []int
}
_, err := mustNew(t, &cli).Parse([]string{"--flag=1,2", "--flag=3,4"})
require.NoError(t, err)
require.Equal(t, []int{3, 4}, cli.Flag)
}
+84 -25
View File
@@ -9,16 +9,30 @@ import (
"time"
)
type DecoderContext struct {
// DecodeContext is passed to a Mapper's Decode().
//
// It contains the Value being decoded into and the Scanner to parse from.
type DecodeContext struct {
// Value being decoded into.
Value *Value
// Scan contains the input to scan into Target.
Scan *Scanner
}
// WithScanner creates a clone of this context with a new Scanner.
func (d *DecodeContext) WithScanner(scan *Scanner) *DecodeContext {
return &DecodeContext{
Value: d.Value,
Scan: scan,
}
}
// A Mapper represents how a field is mapped from command-line values to Go.
//
// Mappers can be associated with concrete fields via pointer, reflect.Type, reflect.Kind, or via a "type" tag.
type Mapper interface {
Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error
// Decode ctx.Value with ctx.Scanner into target.
Decode(ctx *DecodeContext, target reflect.Value) error
}
// A BoolMapper is a Mapper to a value that is a boolean.
@@ -27,13 +41,14 @@ type BoolMapper interface {
IsBool() bool
}
type MapperFunc func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error
// A MapperFunc is a single function that complies with the Mapper interface.
type MapperFunc func(ctx *DecodeContext, target reflect.Value) error
func (d MapperFunc) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
return d(ctx, scan, target)
func (d MapperFunc) Decode(ctx *DecodeContext, target reflect.Value) error { //nolint: golint
return d(ctx, target)
}
// A Registry encapsulates a set of fields and lookups to resolve them.
// A Registry contains a set of mappers and supporting lookup methods.
type Registry struct {
names map[string]Mapper
types map[reflect.Type]Mapper
@@ -41,6 +56,7 @@ type Registry struct {
values map[reflect.Value]Mapper
}
// NewRegistry creates a new (empty) Registry.
func NewRegistry() *Registry {
return &Registry{
names: map[string]Mapper{},
@@ -60,6 +76,7 @@ func (d *Registry) ForNamedType(name string, value reflect.Value) Mapper {
return d.ForValue(value)
}
// ForValue looks up the Mapper for a reflect.Value.
func (d *Registry) ForValue(value reflect.Value) Mapper {
if mapper, ok := d.values[value]; ok {
return mapper
@@ -67,7 +84,7 @@ func (d *Registry) ForValue(value reflect.Value) Mapper {
return d.ForType(value.Type())
}
// DecoderForType finds a mapper from a type or kind.
// ForType finds a mapper from a type, by type, then kind.
//
// Will return nil if a mapper can not be determined.
func (d *Registry) ForType(typ reflect.Type) Mapper {
@@ -81,6 +98,7 @@ func (d *Registry) ForType(typ reflect.Type) Mapper {
return nil
}
// RegisterKind registers a Mapper for a reflect.Kind.
func (d *Registry) RegisterKind(kind reflect.Kind, mapper Mapper) *Registry {
d.kinds[kind] = mapper
return d
@@ -97,12 +115,13 @@ func (d *Registry) RegisterName(name string, mapper Mapper) *Registry {
return d
}
// RegisterType registers a Mapper for a reflect.Type.
func (d *Registry) RegisterType(typ reflect.Type, mapper Mapper) *Registry {
d.types[typ] = mapper
return d
}
// RegisterValue registers a mapper by a pointer to the mapper value.
// RegisterValue registers a Mapper by pointer to the field value.
func (d *Registry) RegisterValue(ptr interface{}, mapper Mapper) *Registry {
key := reflect.ValueOf(ptr)
if key.Kind() != reflect.Ptr {
@@ -113,6 +132,7 @@ func (d *Registry) RegisterValue(ptr interface{}, mapper Mapper) *Registry {
return d
}
// RegisterDefaults registers Mappers for all builtin supported Go types and some common stdlib types.
func (d *Registry) RegisterDefaults() *Registry {
return d.RegisterKind(reflect.Int, intDecoder(bits.UintSize)).
RegisterKind(reflect.Int8, intDecoder(8)).
@@ -126,8 +146,8 @@ func (d *Registry) RegisterDefaults() *Registry {
RegisterKind(reflect.Uint64, uintDecoder(64)).
RegisterKind(reflect.Float32, floatDecoder(32)).
RegisterKind(reflect.Float64, floatDecoder(64)).
RegisterKind(reflect.String, MapperFunc(func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
target.SetString(scan.PopValue("string"))
RegisterKind(reflect.String, MapperFunc(func(ctx *DecodeContext, target reflect.Value) error {
target.SetString(ctx.Scan.PopValue("string"))
return nil
})).
RegisterKind(reflect.Bool, boolMapper{}).
@@ -138,15 +158,15 @@ func (d *Registry) RegisterDefaults() *Registry {
type boolMapper struct{}
func (boolMapper) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
func (boolMapper) Decode(ctx *DecodeContext, target reflect.Value) error {
target.SetBool(true)
return nil
}
func (boolMapper) IsBool() bool { return true }
func durationDecoder() MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
d, err := time.ParseDuration(scan.PopValue("duration"))
return func(ctx *DecodeContext, target reflect.Value) error {
d, err := time.ParseDuration(ctx.Scan.PopValue("duration"))
if err != nil {
return err
}
@@ -156,12 +176,12 @@ func durationDecoder() MapperFunc {
}
func timeDecoder() MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
return func(ctx *DecodeContext, target reflect.Value) error {
fmt := time.RFC3339
if ctx.Value.Format != "" {
fmt = ctx.Value.Format
}
t, err := time.Parse(fmt, scan.PopValue("time"))
t, err := time.Parse(fmt, ctx.Scan.PopValue("time"))
if err != nil {
return err
}
@@ -171,8 +191,8 @@ func timeDecoder() MapperFunc {
}
func intDecoder(bits int) MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
value := scan.PopValue("int")
return func(ctx *DecodeContext, target reflect.Value) error {
value := ctx.Scan.PopValue("int")
n, err := strconv.ParseInt(value, 10, bits)
if err != nil {
return fmt.Errorf("invalid int %q", value)
@@ -183,8 +203,8 @@ func intDecoder(bits int) MapperFunc {
}
func uintDecoder(bits int) MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
value := scan.PopValue("uint")
return func(ctx *DecodeContext, target reflect.Value) error {
value := ctx.Scan.PopValue("uint")
n, err := strconv.ParseUint(value, 10, bits)
if err != nil {
return fmt.Errorf("invalid uint %q", value)
@@ -195,8 +215,8 @@ func uintDecoder(bits int) MapperFunc {
}
func floatDecoder(bits int) MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
value := scan.PopValue("float")
return func(ctx *DecodeContext, target reflect.Value) error {
value := ctx.Scan.PopValue("float")
n, err := strconv.ParseFloat(value, bits)
if err != nil {
return fmt.Errorf("invalid float %q", value)
@@ -207,15 +227,15 @@ func floatDecoder(bits int) MapperFunc {
}
func sliceDecoder(d *Registry) MapperFunc {
return func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
return func(ctx *DecodeContext, target reflect.Value) error {
el := target.Type().Elem()
sep := ctx.Value.Tag.Sep
var childScanner *Scanner
if ctx.Value.Flag != nil {
// If decoding a flag, we need an argument.
childScanner = Scan(strings.Split(scan.PopValue("list"), sep)...)
childScanner = Scan(SplitEscaped(ctx.Scan.PopValue("list"), sep)...)
} else {
tokens := scan.PopUntil(func(t Token) bool { return !t.IsValue() })
tokens := ctx.Scan.PopUntil(func(t Token) bool { return !t.IsValue() })
childScanner = Scan(tokens...)
}
childDecoder := d.ForType(el)
@@ -224,7 +244,7 @@ func sliceDecoder(d *Registry) MapperFunc {
}
for childScanner.Peek().Type != EOLToken {
childValue := reflect.New(el).Elem()
err := childDecoder.Decode(ctx, childScanner, childValue)
err := childDecoder.Decode(ctx.WithScanner(childScanner), childValue)
if err != nil {
return err
}
@@ -233,3 +253,42 @@ func sliceDecoder(d *Registry) MapperFunc {
return nil
}
}
// SplitEscaped splits a string on a separator.
//
// It differs from strings.Split() in that the separator can exist in a field by escaping it with a \. eg.
//
// SplitEscaped(`hello\,there,bob`, ',') == []string{"hello,there", "bob"}
func SplitEscaped(s string, sep rune) (out []string) {
escaped := false
token := ""
for _, ch := range s {
if escaped {
token += string(ch)
escaped = false
} else if ch == '\\' {
escaped = true
} else if ch == sep && !escaped {
out = append(out, token)
token = ""
escaped = false
} else {
token += string(ch)
}
}
if token != "" {
out = append(out, token)
}
return
}
// JoinEscaped joins a slice of strings on sep, but also escapes any instances of sep in the fields with \. eg.
//
// JoinEscaped([]string{"hello,there", "bob"}, ',') == `hello\,there,bob`
func JoinEscaped(s []string, sep rune) string {
escaped := []string{}
for _, e := range s {
escaped = append(escaped, strings.Replace(e, string(sep), `\`+string(sep), -1))
}
return strings.Join(escaped, string(sep))
}
+12 -1
View File
@@ -36,7 +36,7 @@ func TestNamedMapper(t *testing.T) {
type testMooMapper struct{}
func (testMooMapper) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
func (testMooMapper) Decode(ctx *DecodeContext, target reflect.Value) error {
target.SetString("MOO")
return nil
}
@@ -64,3 +64,14 @@ func TestDurationMapper(t *testing.T) {
require.NoError(t, err)
require.Equal(t, time.Second*5, cli.Flag)
}
func TestSplitEscaped(t *testing.T) {
require.Equal(t, []string{"a", "b"}, SplitEscaped("a,b", ','))
require.Equal(t, []string{"a,b", "c"}, SplitEscaped(`a\,b,c`, ','))
}
func TestJoinEscaped(t *testing.T) {
require.Equal(t, `a,b`, JoinEscaped([]string{"a", "b"}, ','))
require.Equal(t, `a\,b,c`, JoinEscaped([]string{`a,b`, `c`}, ','))
require.Equal(t, JoinEscaped(SplitEscaped(`a\,b,c`, ','), ','), `a\,b,c`)
}
+7 -1
View File
@@ -139,6 +139,7 @@ type Value struct {
Position int // Position (for positional arguments).
}
// Summary returns a human-readable summary of the value.
func (v *Value) Summary() string {
if v.Flag != nil {
if v.IsBool() {
@@ -156,10 +157,12 @@ func (v *Value) Summary() string {
return argText
}
// IsCumulative returns true of the value is a slice.
func (v *Value) IsCumulative() bool {
return v.Value.Kind() == reflect.Slice
}
// IsBool returns true if the underlying value is a boolean.
func (v *Value) IsBool() bool {
if m, ok := v.Mapper.(BoolMapper); ok && m.IsBool() {
return true
@@ -170,7 +173,7 @@ func (v *Value) IsBool() bool {
// Parse tokens into value, parse, and validate, but do not write to the field.
func (v *Value) Parse(scan *Scanner) (reflect.Value, error) {
value := reflect.New(v.Value.Type()).Elem()
err := v.Mapper.Decode(&DecoderContext{Value: v}, scan, value)
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, value)
if err == nil {
v.Set = true
}
@@ -196,8 +199,10 @@ func (v *Value) Reset() error {
return nil
}
// A Positional represents a non-branching command-line positional argument.
type Positional = Value
// A Flag represents a command-line flag.
type Flag struct {
*Value
PlaceHolder string
@@ -217,6 +222,7 @@ func (f *Flag) String() string {
return out
}
// FormatPlaceHolder formats the placeholder string for a Flag.
func (f *Flag) FormatPlaceHolder() string {
tail := ""
if f.Value.IsCumulative() {
Regular → Executable
+13 -2
View File
@@ -69,9 +69,12 @@ func Writers(stdout, stderr io.Writer) Option {
// HookFunc is a callback tied to a field of the grammar, called before a value is applied.
type HookFunc func(ctx *Context, path *Path) error
// Hook to aply before a command, flag or positional argument is encountered.
// Hook to apply before a command, flag or positional argument is encountered.
//
// "ptr" is a pointer to a field of the grammar.
//
// Note that the hook will be called once for each time the corresponding node is encountered. This means that if a flag
// is passed twice, its hook will be called twice.
func Hook(ptr interface{}, hook HookFunc) Option {
key := reflect.ValueOf(ptr)
if key.Kind() != reflect.Ptr {
@@ -82,13 +85,21 @@ func Hook(ptr interface{}, hook HookFunc) Option {
}
}
// HelpFunction is the type of a function used to display help.
type HelpFunction func(*Context) error
// Help function to use.
//
// Defaults to PrintHelp.
func Help(help func(*Context) error) Option {
func Help(help HelpFunction) Option {
return func(k *Kong) {
k.help = help
}
}
// Resolver registers flag resolvers.
func Resolver(resolvers ...ResolverFunc) Option {
return func(k *Kong) {
k.resolvers = append(k.resolvers, resolvers...)
}
}
Executable
+84
View File
@@ -0,0 +1,84 @@
package kong
import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
)
// ResolverFunc resolves a Flag value from an external source.
type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error)
// JSONResolver returns a Resolver that retrieves values from a JSON source.
//
// Hyphens in flag names are replaced with underscores.
func JSONResolver(r io.Reader) (ResolverFunc, error) {
values := map[string]interface{}{}
err := json.NewDecoder(r).Decode(&values)
if err != nil {
return nil, err
}
f := func(context *Context, parent *Path, flag *Flag) (string, error) {
name := strings.Replace(flag.Name, "-", "_", -1)
raw, ok := values[name]
if !ok {
return "", nil
}
value, err := jsonDecodeValue(flag.Tag.Sep, raw)
if err != nil {
return "", err
}
return value, nil
}
return f, nil
}
func jsonDecodeValue(sep rune, value interface{}) (string, error) {
switch v := value.(type) {
case string:
return v, nil
case float64:
return fmt.Sprintf("%v", v), nil
case []interface{}:
out := []string{}
for _, el := range v {
sel, err := jsonDecodeValue(sep, el)
if err != nil {
return "", err
}
out = append(out, sel)
}
return JoinEscaped(out, sep), nil
case bool:
if v {
return "true", nil
}
return "false", nil
}
return "", fmt.Errorf("unsupported JSON value %v (of type %T)", value, value)
}
// PerFlagEnvResolver automatically determines environment variables based on the name of each flag, transformed to
// uppercase and underscored, e.g. `my-flag` -> `MY_FLAG` The environment variable key can be overridden with the `env`
// tag.
func PerFlagEnvResolver(prefix string) ResolverFunc {
return func(context *Context, parent *Path, flag *Flag) (string, error) {
v, _ := os.LookupEnv(envString(prefix, flag))
return v, nil
}
}
func envString(prefix string, flag *Flag) string {
if env, ok := flag.Tag.Get("env"); ok {
return env
}
env := strings.ToUpper(flag.Name)
env = strings.Replace(env, "-", "_", -1)
env = prefix + env
return env
}
+237
View File
@@ -0,0 +1,237 @@
package kong
import (
"os"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type envMap map[string]string
func tempEnv(env envMap) func() {
for k, v := range env {
os.Setenv(k, v)
}
return func() {
for k := range env {
os.Unsetenv(k)
}
}
}
func newEnvParser(t *testing.T, cli interface{}, env envMap) (*Kong, func()) {
t.Helper()
restoreEnv := tempEnv(env)
parser := mustNew(t, cli, Resolver(PerFlagEnvResolver("KONG_")))
return parser, restoreEnv
}
func TestEnvResolverFlagBasic(t *testing.T) {
var cli struct {
String string
Slice []int
}
parser, unsetEnvs := newEnvParser(t, &cli, envMap{
"KONG_STRING": "bye",
"KONG_SLICE": "5,2,9",
})
defer unsetEnvs()
_, err := parser.Parse([]string{})
require.NoError(t, err)
require.Equal(t, "bye", cli.String)
require.Equal(t, []int{5, 2, 9}, cli.Slice)
}
func TestEnvResolverFlagOverride(t *testing.T) {
var cli struct {
Flag string
}
parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_FLAG": "bye"})
defer restoreEnv()
_, err := parser.Parse([]string{"--flag=hello"})
require.NoError(t, err)
require.Equal(t, "hello", cli.Flag)
}
func TestEnvResolverOnlyPopulateUsedBranches(t *testing.T) {
// nolint
var cli struct {
UnvisitedArg struct {
UnvisitedArg string `arg`
Int int
} `arg`
UnvisitedCmd struct {
Int int
} `cmd`
Visited struct {
Int int
} `cmd`
}
parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_INT": "512"})
defer restoreEnv()
_, err := parser.Parse([]string{"visited"})
require.NoError(t, err)
require.Equal(t, 512, cli.Visited.Int)
require.Equal(t, 0, cli.UnvisitedArg.Int)
require.Equal(t, 0, cli.UnvisitedCmd.Int)
}
func TestEnvResolverTag(t *testing.T) {
var cli struct {
Slice []int `env:"KONG_NUMBERS"`
}
parser, restoreEnv := newEnvParser(t, &cli, envMap{"KONG_NUMBERS": "5,2,9"})
defer restoreEnv()
_, err := parser.Parse([]string{})
require.NoError(t, err)
require.Equal(t, []int{5, 2, 9}, cli.Slice)
}
func TestJSONResolverBasic(t *testing.T) {
var cli struct {
String string
Slice []int
SliceWithCommas []string
Bool bool
}
json := `{
"string": "🍕",
"slice": [5, 8],
"bool": true,
"slice_with_commas": ["a,b", "c"]
}`
r, err := JSONResolver(strings.NewReader(json))
require.NoError(t, err)
parser := mustNew(t, &cli, Resolver(r))
_, err = parser.Parse([]string{})
require.NoError(t, err)
require.Equal(t, "🍕", cli.String)
require.Equal(t, []int{5, 8}, cli.Slice)
require.Equal(t, []string{"a,b", "c"}, cli.SliceWithCommas)
require.True(t, cli.Bool)
}
func TestResolvedValueTriggersHooks(t *testing.T) {
var cli struct {
Int int
}
resolver := func(context *Context, parent *Path, flag *Flag) (string, error) {
if flag.Name == "int" {
return "1", nil
}
return "", nil
}
hooked := 0
p := mustNew(t, &cli, Resolver(resolver), Hook(&cli.Int, func(ctx *Context, path *Path) error {
hooked++
return nil
}))
_, err := p.Parse(nil)
require.NoError(t, err)
require.Equal(t, 1, cli.Int)
require.Equal(t, 1, hooked)
hooked = 0
_, err = p.Parse([]string{"--int=2"})
require.NoError(t, err)
require.Equal(t, 2, cli.Int)
require.Equal(t, 2, hooked)
}
type testUppercaseMapper struct{}
func (testUppercaseMapper) Decode(ctx *DecodeContext, target reflect.Value) error {
value := ctx.Scan.PopValue("lowercase")
target.SetString(strings.ToUpper(value))
return nil
}
func TestResolversWithMappers(t *testing.T) {
var cli struct {
Flag string `env:"KONG_MOO" type:"upper"`
}
restoreEnv := tempEnv(envMap{"KONG_MOO": "meow"})
defer restoreEnv()
r := PerFlagEnvResolver("KONG_")
parser := mustNew(t, &cli,
NamedMapper("upper", testUppercaseMapper{}),
Resolver(r),
)
_, err := parser.Parse([]string{})
require.NoError(t, err)
require.Equal(t, "MEOW", cli.Flag)
}
func TestResolverWithBool(t *testing.T) {
var cli struct {
Bool bool
}
resolver := func(context *Context, parent *Path, flag *Flag) (string, error) {
if flag.Name == "bool" {
return "true", nil
}
return "", nil
}
p := mustNew(t, &cli, Resolver(resolver))
_, err := p.Parse(nil)
require.NoError(t, err)
require.True(t, cli.Bool)
}
func TestLastResolverWins(t *testing.T) {
var cli struct {
Int []int
}
var first ResolverFunc = func(context *Context, parent *Path, flag *Flag) (string, error) {
if flag.Name == "int" {
return "1", nil
}
return "", nil
}
var second ResolverFunc = func(context *Context, parent *Path, flag *Flag) (string, error) {
if flag.Name == "int" {
return "2", nil
}
return "", nil
}
p := mustNew(t, &cli, Resolver(first), Resolver(second))
_, err := p.Parse(nil)
require.NoError(t, err)
require.Equal(t, []int{2}, cli.Int)
}
func TestResolverSatisfiesRequired(t *testing.T) {
var cli struct {
Int int `required`
}
resolver := func(context *Context, parent *Path, flag *Flag) (string, error) {
if flag.Name == "int" {
return "1", nil
}
return "", nil
}
_, err := mustNew(t, &cli, Resolver(resolver)).Parse(nil)
require.NoError(t, err)
require.Equal(t, 1, cli.Int)
}
+8 -3
View File
@@ -7,8 +7,10 @@ import (
//go:generate stringer -type=TokenType
// TokenType is the type of a token.
type TokenType int
// Token types.
const (
UntypedToken TokenType = iota
EOLToken
@@ -128,14 +130,17 @@ func (s *Scanner) Peek() Token {
return s.args[0]
}
func (s *Scanner) Push(arg string) {
func (s *Scanner) Push(arg string) *Scanner {
s.PushToken(Token{Value: arg})
return s
}
func (s *Scanner) PushTyped(arg string, typ TokenType) {
func (s *Scanner) PushTyped(arg string, typ TokenType) *Scanner {
s.PushToken(Token{Value: arg, Type: typ})
return s
}
func (s *Scanner) PushToken(token Token) {
func (s *Scanner) PushToken(token Token) *Scanner {
s.args = append([]Token{token}, s.args...)
return s
}
+5 -5
View File
@@ -21,7 +21,7 @@ type Tag struct {
Env string
Short rune
Hidden bool
Sep string
Sep rune
// Storage for all tag keys for arbitrary lookups.
items map[string]string
@@ -128,12 +128,12 @@ func parseTag(fv reflect.Value, ft reflect.StructField) *Tag {
t.Short, _ = t.GetRune("short")
t.Hidden = t.Has("hidden")
t.Format, _ = t.Get("format")
t.Sep, _ = t.Get("sep")
if t.Sep == "" {
t.Sep, _ = t.GetRune("sep")
if t.Sep == 0 {
if t.Cmd || t.Arg {
t.Sep = " "
t.Sep = ' '
} else {
t.Sep = ","
t.Sep = ','
}
}
+3
View File
@@ -65,6 +65,7 @@ func TestEscapedQuote(t *testing.T) {
}
func TestBareTags(t *testing.T) {
// nolint: govet
var cli struct {
Cmd struct {
Arg string `arg`
@@ -80,6 +81,7 @@ func TestBareTags(t *testing.T) {
}
func TestBareTagsWithJsonTag(t *testing.T) {
// nolint: govet
var cli struct {
Cmd struct {
Arg string `json:"-" optional arg`
@@ -95,6 +97,7 @@ func TestBareTagsWithJsonTag(t *testing.T) {
}
func TestManySeps(t *testing.T) {
// nolint: govet
var cli struct {
Arg string `arg optional default:"hi"`
}