Slice support.

This commit is contained in:
Alec Thomas
2018-05-17 10:40:46 +10:00
parent 1d00dfef7b
commit 8f26b13088
6 changed files with 69 additions and 27 deletions
+1
View File
@@ -85,6 +85,7 @@ func buildNode(v reflect.Value) *Node {
Help: help, Help: help,
Decoder: decoder, Decoder: decoder,
Value: fv, Value: fv,
Field: ft,
Required: !optional || required, Required: !optional || required,
} }
if arg { if arg {
+40 -12
View File
@@ -4,15 +4,23 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
) )
type Decoder interface { type DecoderContext struct {
Decode(scan *Scanner, target reflect.Value) error // Value being decoded into.
Value *Value
} }
type DecoderFunc func(scan *Scanner, target reflect.Value) error type Decoder interface {
Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error
}
func (d DecoderFunc) Decode(scan *Scanner, target reflect.Value) error { return d(scan, target) } type DecoderFunc func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error
func (d DecoderFunc) Decode(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
return d(ctx, scan, target)
}
var _ Decoder = DecoderFunc(nil) var _ Decoder = DecoderFunc(nil)
@@ -76,6 +84,8 @@ var (
kindDecoders map[reflect.Kind]KindDecoder kindDecoders map[reflect.Kind]KindDecoder
) )
// DecoderForField finds a decoder for a struct field.
//
func DecoderForField(field reflect.StructField) Decoder { func DecoderForField(field reflect.StructField) Decoder {
name, ok := field.Tag.Lookup("type") name, ok := field.Tag.Lookup("type")
if ok { if ok {
@@ -116,7 +126,7 @@ func RegisterDecoder(decoders ...Decoder) {
} }
func init() { func init() {
intDecoder := NewKindDecoder(reflect.Int, func(scan *Scanner, target reflect.Value) error { intDecoder := NewKindDecoder(reflect.Int, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
n, err := strconv.ParseInt(scan.PopValue("int"), 10, 64) n, err := strconv.ParseInt(scan.PopValue("int"), 10, 64)
if err != nil { if err != nil {
return err return err
@@ -124,7 +134,7 @@ func init() {
target.SetInt(n) target.SetInt(n)
return nil return nil
}) })
uintDecoder := NewKindDecoder(reflect.Uint, func(scan *Scanner, target reflect.Value) error { uintDecoder := NewKindDecoder(reflect.Uint, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
n, err := strconv.ParseUint(scan.PopValue("uint"), 10, 64) n, err := strconv.ParseUint(scan.PopValue("uint"), 10, 64)
if err != nil { if err != nil {
return err return err
@@ -143,7 +153,7 @@ func init() {
reflect.Uint16: uintDecoder, reflect.Uint16: uintDecoder,
reflect.Uint32: uintDecoder, reflect.Uint32: uintDecoder,
reflect.Uint64: uintDecoder, reflect.Uint64: uintDecoder,
reflect.Float32: NewKindDecoder(reflect.Float32, func(scan *Scanner, target reflect.Value) error { reflect.Float32: NewKindDecoder(reflect.Float32, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
n, err := strconv.ParseFloat(scan.PopValue("float"), 32) n, err := strconv.ParseFloat(scan.PopValue("float"), 32)
if err != nil { if err != nil {
return err return err
@@ -151,7 +161,7 @@ func init() {
target.SetFloat(n) target.SetFloat(n)
return nil return nil
}), }),
reflect.Float64: NewKindDecoder(reflect.Float64, func(scan *Scanner, target reflect.Value) error { reflect.Float64: NewKindDecoder(reflect.Float64, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
n, err := strconv.ParseFloat(scan.PopValue("float"), 64) n, err := strconv.ParseFloat(scan.PopValue("float"), 64)
if err != nil { if err != nil {
return err return err
@@ -159,17 +169,35 @@ func init() {
target.SetFloat(n) target.SetFloat(n)
return nil return nil
}), }),
reflect.String: NewKindDecoder(reflect.String, func(scan *Scanner, target reflect.Value) error { reflect.String: NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
target.SetString(scan.PopValue("string")) target.SetString(scan.PopValue("string"))
return nil return nil
}), }),
reflect.Bool: NewKindDecoder(reflect.Bool, func(scan *Scanner, target reflect.Value) error { reflect.Bool: NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
target.SetBool(true) target.SetBool(true)
return nil return nil
}), }),
reflect.Slice: NewKindDecoder(reflect.Slice, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
el := target.Type().Elem()
sep, ok := ctx.Value.Field.Tag.Lookup("sep")
if !ok {
sep = ","
}
childScanner := Scan(strings.Split(scan.PopValue("slice"), sep)...)
childDecoder := DecoderForType(el)
for childScanner.Peek().Type != EOLToken {
childValue := reflect.New(el).Elem()
err := childDecoder.Decode(ctx, childScanner, childValue)
if err != nil {
return err
}
target.Set(reflect.Append(target, childValue))
}
return nil
}),
} }
} }
var missingDecoder DecoderFunc = func(scan *Scanner, target reflect.Value) error { var missingDecoder DecoderFunc = func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
return fmt.Errorf("no decoder for %q (of type %T)", target.String(), target.Type()) return fmt.Errorf("no decoder for %q (of type %T) for field %q", target.String(), target.Type(), ctx.Value.Field.Name)
} }
+12 -9
View File
@@ -49,10 +49,9 @@ func (k *Kong) Parse(args []string) (command string, err error) {
// Recursively reset values to defaults (as specified in the grammar) or the zero value. // Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (k *Kong) reset(node *Node) { func (k *Kong) reset(node *Node) {
for _, flag := range node.Flags { for _, flag := range node.Flags {
flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type()))
if flag.Default != "" { if flag.Default != "" {
flag.Decoder.Decode(Scan(flag.Default), flag.Value.Value) flag.Decoder.Decode(&DecoderContext{Value: &flag.Value}, Scan(flag.Default), flag.Value.Value)
} else {
flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type()))
} }
} }
for _, pos := range node.Positional { for _, pos := range node.Positional {
@@ -76,35 +75,39 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
switch { switch {
// -- indicates end of parsing. All remaining arguments are treated as positional arguments only. // -- indicates end of parsing. All remaining arguments are treated as positional arguments only.
case token.Value == "--": case token.Value == "--":
args := []string{}
for { for {
token = scan.Pop() token = scan.Pop()
if token.Type == EOLToken { if token.Type == EOLToken {
break break
} }
scan.PushTyped(token.Value, PositionalArgumentToken) args = append(args, token.Value)
}
for i := range args {
scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken)
} }
// Long flag. // Long flag.
case strings.HasPrefix(token.Value, "--"): case strings.HasPrefix(token.Value, "--"):
// Parse it and push the tokens. // Parse it and push the tokens.
parts := strings.SplitN(token.Value[2:], "=", 2) parts := strings.SplitN(token.Value[2:], "=", 2)
scan.PushTyped(parts[0], FlagToken)
if len(parts) > 1 { if len(parts) > 1 {
scan.PushTyped(parts[1], FlagValueToken) scan.PushTyped(parts[1], FlagValueToken)
} }
scan.PushTyped(parts[0], FlagToken)
// Short flag. // Short flag.
case strings.HasPrefix(token.Value, "-"): case strings.HasPrefix(token.Value, "-"):
scan.PushTyped(token.Value[1:2], ShortFlagToken)
scan.PushTyped(token.Value[2:], ShortFlagTailToken) scan.PushTyped(token.Value[2:], ShortFlagTailToken)
scan.PushTyped(token.Value[1:2], ShortFlagToken)
default: default:
scan.PushTyped(token.Value, PositionalArgumentToken) scan.PushTyped(token.Value, PositionalArgumentToken)
} }
case ShortFlagTailToken: case ShortFlagTailToken:
scan.PushTyped(token.Value[0:1], ShortFlagToken)
scan.PushTyped(token.Value[1:], ShortFlagTailToken) scan.PushTyped(token.Value[1:], ShortFlagTailToken)
scan.PushTyped(token.Value[0:1], ShortFlagToken)
case FlagToken: case FlagToken:
if err := matchFlags(node.Flags, token, scan, func(f *Flag) bool { if err := matchFlags(node.Flags, token, scan, func(f *Flag) bool {
@@ -140,7 +143,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
case branch.Argument != nil: case branch.Argument != nil:
arg := branch.Argument.Argument arg := branch.Argument.Argument
if err := arg.Decoder.Decode(scan, arg.Value); err == nil { if err := arg.Decoder.Decode(&DecoderContext{Value: arg}, scan, arg.Value); err == nil {
command = append(command, "<"+arg.Name+">") command = append(command, "<"+arg.Name+">")
cmd, err := k.applyNode(scan, &branch.Argument.Node) cmd, err := k.applyNode(scan, &branch.Argument.Node)
if err != nil { if err != nil {
@@ -171,7 +174,7 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag)
for _, flag := range flags { for _, flag := range flags {
// Found a matching flag. // Found a matching flag.
if flag.Name == token.Value { if flag.Name == token.Value {
err := flag.Decoder.Decode(scan, flag.Value.Value) err := flag.Decoder.Decode(&DecoderContext{Value: &flag.Value}, scan, flag.Value.Value)
if err != nil { if err != nil {
return err return err
} }
+10 -4
View File
@@ -1,12 +1,9 @@
package kong package kong
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/alecthomas/repr"
) )
func mustNew(t *testing.T, cli interface{}) *Kong { func mustNew(t *testing.T, cli interface{}) *Kong {
@@ -41,7 +38,6 @@ func TestArgument(t *testing.T) {
} `arg:"true"` } `arg:"true"`
} }
p := mustNew(t, &cli) p := mustNew(t, &cli)
repr.Println(p.Model, repr.Hide(reflect.Value{}))
cmd, err := p.Parse([]string{"10", "delete"}) cmd, err := p.Parse([]string{"10", "delete"})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, cli.Id.Id) require.Equal(t, 10, cli.Id.Id)
@@ -61,3 +57,13 @@ func TestResetWithDefaults(t *testing.T) {
require.Equal(t, "", cli.Flag) require.Equal(t, "", cli.Flag)
require.Equal(t, "default", cli.FlagWithDefault) require.Equal(t, "default", cli.FlagWithDefault)
} }
func TestSlice(t *testing.T) {
var cli struct {
Slice []int
}
parser := mustNew(t, &cli)
_, err := parser.Parse([]string{"--slice=1,2,3"})
require.NoError(t, err)
require.Equal(t, []int{1, 2, 3}, cli.Slice)
}
+1
View File
@@ -24,6 +24,7 @@ type Value struct {
Name string Name string
Help string Help string
Decoder Decoder Decoder Decoder
Field reflect.StructField
Value reflect.Value Value reflect.Value
Required bool Required bool
} }
+5 -2
View File
@@ -60,18 +60,21 @@ func (t Token) IsValue() bool {
} }
type Scanner struct { type Scanner struct {
raw []string
args []Token args []Token
} }
func Scan(args ...string) *Scanner { func Scan(args ...string) *Scanner {
s := &Scanner{raw: args} s := &Scanner{}
for _, arg := range args { for _, arg := range args {
s.args = append(s.args, Token{Value: arg}) s.args = append(s.args, Token{Value: arg})
} }
return s return s
} }
func (s *Scanner) Len() int {
return len(s.args)
}
func (s *Scanner) Pop() Token { func (s *Scanner) Pop() Token {
if len(s.args) == 0 { if len(s.args) == 0 {
return Token{Type: EOLToken} return Token{Type: EOLToken}