Slice support.

This commit is contained in:
Alec Thomas
2018-05-17 19:39:48 +10:00
parent cb88963909
commit b9d002b746
6 changed files with 149 additions and 68 deletions
+14 -8
View File
@@ -36,10 +36,10 @@ func buildNode(v reflect.Value) *Node {
if name == "" { if name == "" {
name = strings.ToLower(strings.Join(camelCase(ft.Name), "-")) name = strings.ToLower(strings.Join(camelCase(ft.Name), "-"))
} }
help := ft.Tag.Get("help") decoder := DecoderForField(ft)
decoder, err := DecoderForField(ft) help, ok := ft.Tag.Lookup("help")
if err != nil && ft.Type.Kind() != reflect.Struct { if !ok {
panic(err) continue
} }
dflt := ft.Tag.Get("default") dflt := ft.Tag.Get("default")
placeholder := ft.Tag.Get("placeholder") placeholder := ft.Tag.Get("placeholder")
@@ -53,11 +53,13 @@ func buildNode(v reflect.Value) *Node {
// group := ft.Tag.Get("group") // group := ft.Tag.Get("group")
_, required := ft.Tag.Lookup("required") _, required := ft.Tag.Lookup("required")
_, optional := ft.Tag.Lookup("optional") _, optional := ft.Tag.Lookup("optional")
// Force field to be an argument, not a flag.
_, arg := ft.Tag.Lookup("arg") _, arg := ft.Tag.Lookup("arg")
env := ft.Tag.Get("env") env := ft.Tag.Get("env")
format := ft.Tag.Get("format")
// Nested structs are commands. // Nested structs are either commands or args.
if ft.Type.Kind() == reflect.Struct { if ft.Type.Kind() == reflect.Struct && decoder == nil {
child := buildNode(fv) child := buildNode(fv)
child.Help = help child.Help = help
@@ -65,8 +67,8 @@ func buildNode(v reflect.Value) *Node {
// a positional argument is provided to the child, and move it to the branching argument field. // a positional argument is provided to the child, and move it to the branching argument field.
if arg { if arg {
if len(child.Positional) == 0 { if len(child.Positional) == 0 {
panic(fmt.Errorf("positional branch %s.%s must have at least one child positional argument", fail("positional branch %s.%s must have at least one child positional argument",
v.Type().Name(), ft.Name)) v.Type().Name(), ft.Name)
} }
value := child.Positional[0] value := child.Positional[0]
child.Positional = child.Positional[1:] child.Positional = child.Positional[1:]
@@ -83,6 +85,9 @@ func buildNode(v reflect.Value) *Node {
node.Children = append(node.Children, &Branch{Command: child}) node.Children = append(node.Children, &Branch{Command: child})
} }
} else { } else {
if decoder == nil {
fail("no decoder for %s.%s (of type %s)", v.Type(), ft.Name, ft.Type)
}
value := Value{ value := Value{
Name: name, Name: name,
Help: help, Help: help,
@@ -90,6 +95,7 @@ func buildNode(v reflect.Value) *Node {
Value: fv, Value: fv,
Field: ft, Field: ft,
Required: !optional || required, Required: !optional || required,
Format: format,
} }
if arg { if arg {
node.Positional = append(node.Positional, &value) node.Positional = append(node.Positional, &value)
+57 -27
View File
@@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
type DecoderContext struct { type DecoderContext struct {
@@ -81,29 +82,34 @@ var _ NamedDecoder = &namedDecoder{}
var ( var (
namedDecoders = map[string]NamedDecoder{} namedDecoders = map[string]NamedDecoder{}
typeDecoders = map[reflect.Type]TypeDecoder{} typeDecoders = map[reflect.Type]TypeDecoder{}
kindDecoders map[reflect.Kind]KindDecoder kindDecoders = map[reflect.Kind]KindDecoder{}
) )
// DecoderForField finds a decoder for a struct field. // DecoderForField finds a decoder for a struct field.
func DecoderForField(field reflect.StructField) (Decoder, error) { //
// Will return nil if a decoder can not be determined.
func DecoderForField(field reflect.StructField) Decoder {
name, ok := field.Tag.Lookup("type") name, ok := field.Tag.Lookup("type")
if ok { if ok {
if decoder, ok := namedDecoders[name]; ok { if decoder, ok := namedDecoders[name]; ok {
return decoder, nil return decoder
} }
} }
return DecoderForType(field.Type) return DecoderForType(field.Type)
} }
func DecoderForType(typ reflect.Type) (Decoder, error) { // DecoderForType finds a decoder via a type or kind.
//
// Will return nil if a decoder can not be determined.
func DecoderForType(typ reflect.Type) Decoder {
var decoder Decoder var decoder Decoder
var ok bool var ok bool
if decoder, ok = typeDecoders[typ]; ok { if decoder, ok = typeDecoders[typ]; ok {
return decoder, nil return decoder
} else if decoder, ok = kindDecoders[typ.Kind()]; ok { } else if decoder, ok = kindDecoders[typ.Kind()]; ok {
return decoder, nil return decoder
} }
return nil, fmt.Errorf("no decoder for type %s", typ) return nil
} }
// RegisterDecoder registers decoders. // RegisterDecoder registers decoders.
@@ -119,35 +125,59 @@ func RegisterDecoder(decoders ...Decoder) {
case NamedDecoder: case NamedDecoder:
namedDecoders[decoder.Name()] = decoder namedDecoders[decoder.Name()] = decoder
default: default:
panic("unsupported decoder type " + reflect.TypeOf(decoder).String()) fail("unsupported decoder type " + reflect.TypeOf(decoder).String())
} }
} }
} }
func init() { func init() {
kindDecoders = map[reflect.Kind]KindDecoder{ RegisterDecoder(
reflect.Int: NewKindDecoder(reflect.Int, intDecoder), NewKindDecoder(reflect.Int, intDecoder),
reflect.Int8: NewKindDecoder(reflect.Int8, intDecoder), NewKindDecoder(reflect.Int8, intDecoder),
reflect.Int16: NewKindDecoder(reflect.Int16, intDecoder), NewKindDecoder(reflect.Int16, intDecoder),
reflect.Int32: NewKindDecoder(reflect.Int32, intDecoder), NewKindDecoder(reflect.Int32, intDecoder),
reflect.Int64: NewKindDecoder(reflect.Int64, intDecoder), NewKindDecoder(reflect.Int64, intDecoder),
reflect.Uint: NewKindDecoder(reflect.Uint, uintDecoder), NewKindDecoder(reflect.Uint, uintDecoder),
reflect.Uint8: NewKindDecoder(reflect.Uint8, uintDecoder), NewKindDecoder(reflect.Uint8, uintDecoder),
reflect.Uint16: NewKindDecoder(reflect.Uint16, uintDecoder), NewKindDecoder(reflect.Uint16, uintDecoder),
reflect.Uint32: NewKindDecoder(reflect.Uint32, uintDecoder), NewKindDecoder(reflect.Uint32, uintDecoder),
reflect.Uint64: NewKindDecoder(reflect.Uint64, uintDecoder), NewKindDecoder(reflect.Uint64, uintDecoder),
reflect.Float32: NewKindDecoder(reflect.Float32, floatDecoder(32)), NewKindDecoder(reflect.Float32, floatDecoder(32)),
reflect.Float64: NewKindDecoder(reflect.Float64, floatDecoder(64)), NewKindDecoder(reflect.Float64, floatDecoder(64)),
reflect.String: NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { NewKindDecoder(reflect.String, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
target.SetString(scan.PopValue("string")) target.SetString(scan.PopValue("string"))
return nil return nil
}), }),
reflect.Bool: NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { NewKindDecoder(reflect.Bool, func(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
target.SetBool(true) target.SetBool(true)
return nil return nil
}), }),
reflect.Slice: NewKindDecoder(reflect.Slice, sliceDecoder), NewKindDecoder(reflect.Slice, sliceDecoder),
NewTypeDecoder(reflect.TypeOf(time.Time{}), timeDecoder),
NewTypeDecoder(reflect.TypeOf(time.Duration(0)), durationDecoder),
)
}
func durationDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
d, err := time.ParseDuration(scan.PopValue("duration"))
if err != nil {
return err
} }
target.Set(reflect.ValueOf(d))
return nil
}
func timeDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
fmt := time.RFC3339
if ctx.Value.Format != "" {
fmt = ctx.Value.Format
}
t, err := time.Parse(fmt, scan.PopValue("time"))
if err != nil {
return err
}
target.Set(reflect.ValueOf(t))
return nil
} }
func intDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error { func intDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) error {
@@ -186,9 +216,9 @@ func sliceDecoder(ctx *DecoderContext, scan *Scanner, target reflect.Value) erro
sep = "," sep = ","
} }
childScanner := Scan(strings.Split(scan.PopValue("list"), sep)...) childScanner := Scan(strings.Split(scan.PopValue("list"), sep)...)
childDecoder, err := DecoderForType(el) childDecoder := DecoderForType(el)
if err != nil { if childDecoder == nil {
return err return fmt.Errorf("no decoder for element type of %s", target.Type())
} }
for childScanner.Peek().Type != EOLToken { for childScanner.Peek().Type != EOLToken {
childValue := reflect.New(el).Elem() childValue := reflect.New(el).Elem()
+29 -3
View File
@@ -8,6 +8,16 @@ import (
"strings" "strings"
) )
type Error struct {
msg string
}
func (e Error) Error() string { return e.msg }
func fail(format string, args ...interface{}) {
panic(Error{fmt.Sprintf(format, args...)})
}
type Kong struct { type Kong struct {
Model *Application Model *Application
// Termination function (defaults to os.Exit) // Termination function (defaults to os.Exit)
@@ -35,7 +45,7 @@ func New(name, description string, ast interface{}) (*Kong, error) {
func (k *Kong) Parse(args []string) (command string, err error) { func (k *Kong) Parse(args []string) (command string, err error) {
defer func() { defer func() {
msg := recover() msg := recover()
if test, ok := msg.(TokenAssertionError); ok { if test, ok := msg.(Error); ok {
err = test err = test
} else if msg != nil { } else if msg != nil {
panic(msg) panic(msg)
@@ -68,7 +78,8 @@ func (k *Kong) reset(node *Node) {
} }
} }
func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error) { func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error) { // nolint: gocyclo
positional := 0
for token := scan.Pop(); token.Type != EOLToken; token = scan.Pop() { for token := scan.Pop(); token.Type != EOLToken; token = scan.Pop() {
switch token.Type { switch token.Type {
case UntypedToken: case UntypedToken:
@@ -98,6 +109,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
// Short flag. // Short flag.
case strings.HasPrefix(token.Value, "-"): case strings.HasPrefix(token.Value, "-"):
// Note: tokens must be pushed in reverse order.
scan.PushTyped(token.Value[2:], ShortFlagTailToken) scan.PushTyped(token.Value[2:], ShortFlagTailToken)
scan.PushTyped(token.Value[1:2], ShortFlagToken) scan.PushTyped(token.Value[1:2], ShortFlagToken)
@@ -106,6 +118,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
} }
case ShortFlagTailToken: case ShortFlagTailToken:
// Note: tokens must be pushed in reverse order.
scan.PushTyped(token.Value[1:], ShortFlagTailToken) scan.PushTyped(token.Value[1:], ShortFlagTailToken)
scan.PushTyped(token.Value[0:1], ShortFlagToken) scan.PushTyped(token.Value[0:1], ShortFlagToken)
@@ -128,6 +141,19 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
case PositionalArgumentToken: case PositionalArgumentToken:
scan.PushToken(token) scan.PushToken(token)
// Ensure we've consumed all positional arguments.
if positional < len(node.Positional) {
arg := node.Positional[positional]
err := arg.Decoder.Decode(&DecoderContext{Value: arg}, scan, arg.Value)
if err != nil {
return nil, err
}
command = append(command, "<"+arg.Name+">")
positional++
break
}
// After positional arguments have been consumed, handle commands and branching arguments.
for _, branch := range node.Children { for _, branch := range node.Children {
switch { switch {
case branch.Command != nil: case branch.Command != nil:
@@ -165,7 +191,7 @@ func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error
func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) {
defer func() { defer func() {
msg := recover() msg := recover()
if test, ok := msg.(TokenAssertionError); ok { if test, ok := msg.(Error); ok {
err = fmt.Errorf("%s %s", token, test) err = fmt.Errorf("%s %s", token, test)
} else if msg != nil { } else if msg != nil {
panic(msg) panic(msg)
+46 -21
View File
@@ -13,7 +13,23 @@ func mustNew(t *testing.T, cli interface{}) *Kong {
return parser return parser
} }
func TestArgument(t *testing.T) { func TestArgumentSequence(t *testing.T) {
var cli struct {
User struct {
Create struct {
ID int `arg:"" help:""`
First string `arg:"" help:""`
Last string `arg:"" help:""`
} `help:""`
} `help:""`
}
p := mustNew(t, &cli)
cmd, err := p.Parse([]string{"user", "create", "10", "Alec", "Thomas"})
require.NoError(t, err)
require.Equal(t, "user create <id> <first> <last>", cmd)
}
func TestBranchingArgument(t *testing.T) {
/* /*
app user create <id> <first> <last> app user create <id> <first> <last>
app user <id> delete app user <id> delete
@@ -21,33 +37,35 @@ func TestArgument(t *testing.T) {
*/ */
var cli struct { var cli struct {
Create struct { User struct {
Id string `arg:"true"` Create struct {
First string `arg:"true"` ID string `arg:"" help:""`
Last string `arg:"true"` First string `arg:"" help:""`
} Last string `arg:"" help:""`
} `help:""`
// Branching argument. // Branching argument.
Id struct { ID struct {
Id int `arg:"true"` ID int `arg:"" help:""`
Flag int Flag int `help:""`
Delete struct{} Delete struct{} `help:""`
Rename struct { Rename struct {
To string To string
} } `help:""`
} `arg:"true"` } `arg:"" help:""`
} `help:"Manage users."`
} }
p := mustNew(t, &cli) p := mustNew(t, &cli)
cmd, err := p.Parse([]string{"10", "delete"}) cmd, err := p.Parse([]string{"user", "10", "delete"})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, cli.Id.Id) require.Equal(t, 10, cli.User.ID.ID)
require.Equal(t, "<id> delete", cmd) require.Equal(t, "user <id> delete", cmd)
} }
func TestResetWithDefaults(t *testing.T) { func TestResetWithDefaults(t *testing.T) {
var cli struct { var cli struct {
Flag string Flag string `help:""`
FlagWithDefault string `default:"default"` FlagWithDefault string `default:"default" help:""`
} }
cli.Flag = "BLAH" cli.Flag = "BLAH"
cli.FlagWithDefault = "BLAH" cli.FlagWithDefault = "BLAH"
@@ -60,10 +78,17 @@ func TestResetWithDefaults(t *testing.T) {
func TestSlice(t *testing.T) { func TestSlice(t *testing.T) {
var cli struct { var cli struct {
Slice []int Slice []int `help:""`
} }
parser := mustNew(t, &cli) parser := mustNew(t, &cli)
_, err := parser.Parse([]string{"--slice=1,2,3"}) _, err := parser.Parse([]string{"--slice=1,2,3"})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []int{1, 2, 3}, cli.Slice) require.Equal(t, []int{1, 2, 3}, cli.Slice)
} }
func TestUnsupportedfieldErrors(t *testing.T) {
var cli struct {
Keys map[string]string `help:""`
}
require.Panics(t, func() { mustNew(t, &cli) })
}
+1
View File
@@ -27,6 +27,7 @@ type Value struct {
Field reflect.StructField Field reflect.StructField
Value reflect.Value Value reflect.Value
Required bool Required bool
Format string // Formatting directive, if applicable.
} }
type Positional = Value type Positional = Value
+2 -9
View File
@@ -1,7 +1,6 @@
package kong package kong
import ( import (
"fmt"
"strconv" "strconv"
) )
@@ -19,12 +18,6 @@ const (
PositionalArgumentToken // <arg> PositionalArgumentToken // <arg>
) )
type TokenAssertionError struct{ err error }
func (t TokenAssertionError) Error() string {
return t.err.Error()
}
type Token struct { type Token struct {
Value string Value string
Type TokenType Type TokenType
@@ -84,11 +77,11 @@ func (s *Scanner) Pop() Token {
return arg return arg
} }
// PopValue or panic with TokenAssertionError. // PopValue or panic with Error.
func (s *Scanner) PopValue(context string) string { func (s *Scanner) PopValue(context string) string {
t := s.Pop() t := s.Pop()
if !t.IsValue() { if !t.IsValue() {
panic(TokenAssertionError{fmt.Errorf("expected %s value but got %s", context, t)}) fail("expected %s value but got %s", context, t)
} }
return t.Value return t.Value
} }