From 1d00dfef7b1cb798caa03b4b395fa44f9dfbc295 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Wed, 16 May 2018 20:33:18 +1000 Subject: [PATCH] Implemented most of the base parser. This includes branching arguments as well as commands, eg. app user create app user delete app user rename Of note, required/optional flags and positional arguments are not currently enforced. --- build.go | 105 ++++++++++++++++++++++++++++ decoders.go | 92 ++++++++++++++++++++++-- decoders_test.go | 6 ++ global.go | 2 +- kong.go | 165 +++++++++++++++++++++++++++++++++++++++++--- kong_test.go | 63 +++++++++++++++++ model.go | 62 ++++++++--------- scanner.go | 110 +++++++++++++++++++++++++++++ scanner_test.go | 26 +++++++ tokentype_string.go | 16 +++++ 10 files changed, 601 insertions(+), 46 deletions(-) create mode 100644 build.go create mode 100644 decoders_test.go create mode 100644 kong_test.go create mode 100644 scanner.go create mode 100644 scanner_test.go create mode 100644 tokentype_string.go diff --git a/build.go b/build.go new file mode 100644 index 0000000..5433358 --- /dev/null +++ b/build.go @@ -0,0 +1,105 @@ +package kong + +import ( + "fmt" + "reflect" + "strings" + "unicode/utf8" +) + +func build(ast interface{}) (app *Application, err error) { + defer func() { + msg := recover() + if test, ok := recover().(error); ok { + app = nil + err = test + } else if msg != nil { + panic(msg) + } + }() + v := reflect.ValueOf(ast) + iv := reflect.Indirect(v) + if v.Kind() != reflect.Ptr || iv.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected a pointer to a struct but got %T", ast) + } + + return buildNode(iv), nil +} + +func buildNode(v reflect.Value) *Node { + node := &Node{} + for i := 0; i < v.NumField(); i++ { + ft := v.Type().Field(i) + fv := v.Field(i) + + name := ft.Tag.Get("name") + if name == "" { + name = strings.ToLower(strings.Join(camelCase(ft.Name), "-")) + } + help := ft.Tag.Get("help") + decoder := DecoderForField(ft) + dflt := ft.Tag.Get("default") + placeholder := ft.Tag.Get("placeholder") + if placeholder == "" { + placeholder = strings.ToUpper(strings.Join(camelCase(fv.Type().Name()), "-")) + } + short, _ := utf8.DecodeRuneInString(ft.Tag.Get("short")) + if short == utf8.RuneError { + short = 0 + } + // group := ft.Tag.Get("group") + _, required := ft.Tag.Lookup("required") + _, optional := ft.Tag.Lookup("optional") + _, arg := ft.Tag.Lookup("arg") + env := ft.Tag.Get("env") + + // Nested structs are commands. + if ft.Type.Kind() == reflect.Struct { + child := buildNode(fv) + child.Help = help + + // A branching argument. This is a bit hairy, as we let buildNode() do the parsing, then check that + // a positional argument is provided to the child, and move it to the branching argument field. + if arg { + if len(child.Positional) == 0 { + panic(fmt.Errorf("positional branch %s.%s must have at least one child positional argument", + v.Type().Name(), ft.Name)) + } + value := child.Positional[0] + child.Positional = child.Positional[1:] + if child.Help == "" { + child.Help = value.Help + } + child.Name = value.Name + node.Children = append(node.Children, &Branch{Argument: &Argument{ + Node: *child, + Argument: value, + }}) + } else { + child.Name = name + node.Children = append(node.Children, &Branch{Command: child}) + } + } else { + value := Value{ + Name: name, + Help: help, + Decoder: decoder, + Value: fv, + Required: !optional || required, + } + if arg { + node.Positional = append(node.Positional, &value) + } else { + node.Flags = append(node.Flags, &Flag{ + Value: value, + Short: short, + Default: dflt, + Placeholder: placeholder, + Env: env, + }) + } + } + } + + return node +} diff --git a/decoders.go b/decoders.go index 514b140..1f12605 100644 --- a/decoders.go +++ b/decoders.go @@ -1,14 +1,18 @@ package kong -import "reflect" +import ( + "fmt" + "reflect" + "strconv" +) type Decoder interface { - Decode(input string, target reflect.Value) error + Decode(scan *Scanner, target reflect.Value) error } -type DecoderFunc func(input string, target reflect.Value) error +type DecoderFunc func(scan *Scanner, target reflect.Value) error -func (d DecoderFunc) Decode(input string, target reflect.Value) error { return d(input, target) } +func (d DecoderFunc) Decode(scan *Scanner, target reflect.Value) error { return d(scan, target) } var _ Decoder = DecoderFunc(nil) @@ -69,9 +73,30 @@ var _ NamedDecoder = &namedDecoder{} var ( namedDecoders = map[string]NamedDecoder{} typeDecoders = map[reflect.Type]TypeDecoder{} - kindDecoders = map[reflect.Kind]KindDecoder{} + kindDecoders map[reflect.Kind]KindDecoder ) +func DecoderForField(field reflect.StructField) Decoder { + name, ok := field.Tag.Lookup("type") + if ok { + if decoder, ok := namedDecoders[name]; ok { + return decoder + } + } + return DecoderForType(field.Type) +} + +func DecoderForType(typ reflect.Type) Decoder { + var decoder Decoder + var ok bool + if decoder, ok = typeDecoders[typ]; ok { + return decoder + } else if decoder, ok = kindDecoders[typ.Kind()]; ok { + return decoder + } + return missingDecoder +} + // RegisterDecoder registers decoders. // // Decoders must be one of TypeDecoder, KindDecoder or NamedDecoder. @@ -91,5 +116,60 @@ func RegisterDecoder(decoders ...Decoder) { } func init() { - RegisterDecoder() + intDecoder := NewKindDecoder(reflect.Int, func(scan *Scanner, target reflect.Value) error { + n, err := strconv.ParseInt(scan.PopValue("int"), 10, 64) + if err != nil { + return err + } + target.SetInt(n) + return nil + }) + uintDecoder := NewKindDecoder(reflect.Uint, func(scan *Scanner, target reflect.Value) error { + n, err := strconv.ParseUint(scan.PopValue("uint"), 10, 64) + if err != nil { + return err + } + target.SetUint(n) + return nil + }) + kindDecoders = map[reflect.Kind]KindDecoder{ + reflect.Int: intDecoder, + reflect.Int8: intDecoder, + reflect.Int16: intDecoder, + reflect.Int32: intDecoder, + reflect.Int64: intDecoder, + reflect.Uint: uintDecoder, + reflect.Uint8: uintDecoder, + reflect.Uint16: uintDecoder, + reflect.Uint32: uintDecoder, + reflect.Uint64: uintDecoder, + reflect.Float32: NewKindDecoder(reflect.Float32, func(scan *Scanner, target reflect.Value) error { + n, err := strconv.ParseFloat(scan.PopValue("float"), 32) + if err != nil { + return err + } + target.SetFloat(n) + return nil + }), + reflect.Float64: NewKindDecoder(reflect.Float64, func(scan *Scanner, target reflect.Value) error { + n, err := strconv.ParseFloat(scan.PopValue("float"), 64) + if err != nil { + return err + } + target.SetFloat(n) + return nil + }), + reflect.String: NewKindDecoder(reflect.String, func(scan *Scanner, target reflect.Value) error { + target.SetString(scan.PopValue("string")) + return nil + }), + reflect.Bool: NewKindDecoder(reflect.Bool, func(scan *Scanner, target reflect.Value) error { + target.SetBool(true) + return nil + }), + } +} + +var missingDecoder DecoderFunc = func(scan *Scanner, target reflect.Value) error { + return fmt.Errorf("no decoder for %q (of type %T)", target.String(), target.Type()) } diff --git a/decoders_test.go b/decoders_test.go new file mode 100644 index 0000000..0eaa188 --- /dev/null +++ b/decoders_test.go @@ -0,0 +1,6 @@ +package kong + +import "testing" + +func TestDecoders(t *testing.T) { +} diff --git a/global.go b/global.go index 2acabe9..180fd7f 100644 --- a/global.go +++ b/global.go @@ -5,6 +5,6 @@ import "os" func Parse(cli interface{}) { parser, err := New("", "", cli) parser.FatalIfErrorf(err) - err = parser.Parse(os.Args[1:]) + _, err = parser.Parse(os.Args[1:]) parser.FatalIfErrorf(err) } diff --git a/kong.go b/kong.go index 7cfbc6c..d8ece95 100644 --- a/kong.go +++ b/kong.go @@ -4,23 +4,27 @@ import ( "fmt" "os" "path/filepath" + "reflect" + "strings" ) type Kong struct { - Model *ApplicationModel + Model *Application // Termination function (defaults to os.Exit) Terminate func(int) } -// New creates a new Kong parser into grammar. -func New(name, description string, grammar interface{}) (*Kong, error) { +// New creates a new Kong parser into ast. +func New(name, description string, ast interface{}) (*Kong, error) { if name == "" { name = filepath.Base(os.Args[0]) } - model := &ApplicationModel{ - Name: name, - Description: description, + model, err := build(ast) + if err != nil { + return nil, err } + model.Name = name + model.Help = description return &Kong{ Model: model, Terminate: os.Exit, @@ -28,8 +32,153 @@ func New(name, description string, grammar interface{}) (*Kong, error) { } // Parse arguments into target. -func (k *Kong) Parse(args []string) error { - return nil +func (k *Kong) Parse(args []string) (command string, err error) { + defer func() { + msg := recover() + if test, ok := msg.(TokenAssertionError); ok { + err = test + } else if msg != nil { + panic(msg) + } + }() + k.reset(k.Model) + cmd, err := k.applyNode(Scan(args...), k.Model) + return strings.Join(cmd, " "), err +} + +// Recursively reset values to defaults (as specified in the grammar) or the zero value. +func (k *Kong) reset(node *Node) { + for _, flag := range node.Flags { + if flag.Default != "" { + flag.Decoder.Decode(Scan(flag.Default), flag.Value.Value) + } else { + flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type())) + } + } + for _, pos := range node.Positional { + pos.Value.Set(reflect.Zero(pos.Value.Type())) + } + for _, branch := range node.Children { + if branch.Argument != nil { + arg := branch.Argument.Argument + arg.Value.Set(reflect.Zero(arg.Value.Type())) + k.reset(&branch.Argument.Node) + } else { + k.reset(branch.Command) + } + } +} + +func (k *Kong) applyNode(scan *Scanner, node *Node) (command []string, err error) { + for token := scan.Pop(); token.Type != EOLToken; token = scan.Pop() { + switch token.Type { + case UntypedToken: + switch { + // -- indicates end of parsing. All remaining arguments are treated as positional arguments only. + case token.Value == "--": + for { + token = scan.Pop() + if token.Type == EOLToken { + break + } + scan.PushTyped(token.Value, PositionalArgumentToken) + } + + // Long flag. + case strings.HasPrefix(token.Value, "--"): + // Parse it and push the tokens. + parts := strings.SplitN(token.Value[2:], "=", 2) + scan.PushTyped(parts[0], FlagToken) + if len(parts) > 1 { + scan.PushTyped(parts[1], FlagValueToken) + } + + // Short flag. + case strings.HasPrefix(token.Value, "-"): + scan.PushTyped(token.Value[1:2], ShortFlagToken) + scan.PushTyped(token.Value[2:], ShortFlagTailToken) + + default: + scan.PushTyped(token.Value, PositionalArgumentToken) + } + + case ShortFlagTailToken: + scan.PushTyped(token.Value[0:1], ShortFlagToken) + scan.PushTyped(token.Value[1:], ShortFlagTailToken) + + case FlagToken: + if err := matchFlags(node.Flags, token, scan, func(f *Flag) bool { + return f.Name == token.Value + }); err != nil { + return nil, err + } + + case ShortFlagToken: + if err := matchFlags(node.Flags, token, scan, func(f *Flag) bool { + return string(f.Name) == token.Value + }); err != nil { + return nil, err + } + + case FlagValueToken: + return nil, fmt.Errorf("unexpected flag argument %q", token.Value) + + case PositionalArgumentToken: + scan.PushToken(token) + for _, branch := range node.Children { + switch { + case branch.Command != nil: + if branch.Command.Name == token.Value { + scan.Pop() + command = append(command, branch.Command.Name) + cmd, err := k.applyNode(scan, branch.Command) + if err != nil { + return nil, err + } + return append(command, cmd...), nil + } + + case branch.Argument != nil: + arg := branch.Argument.Argument + if err := arg.Decoder.Decode(scan, arg.Value); err == nil { + command = append(command, "<"+arg.Name+">") + cmd, err := k.applyNode(scan, &branch.Argument.Node) + if err != nil { + return nil, err + } + return append(command, cmd...), nil + } + } + } + return nil, fmt.Errorf("unexpected positional argument %s", token) + + default: + return nil, fmt.Errorf("unexpected token %s", token) + } + } + return +} + +func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { + defer func() { + msg := recover() + if test, ok := msg.(TokenAssertionError); ok { + err = fmt.Errorf("%s %s", token, test) + } else if msg != nil { + panic(msg) + } + }() + for _, flag := range flags { + // Found a matching flag. + if flag.Name == token.Value { + err := flag.Decoder.Decode(scan, flag.Value.Value) + if err != nil { + return err + } + return nil + } + } + return fmt.Errorf("unknown flag --%s", token.Value) } func (k *Kong) Errorf(format string, args ...interface{}) { diff --git a/kong_test.go b/kong_test.go new file mode 100644 index 0000000..d6a5a80 --- /dev/null +++ b/kong_test.go @@ -0,0 +1,63 @@ +package kong + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/alecthomas/repr" +) + +func mustNew(t *testing.T, cli interface{}) *Kong { + t.Helper() + parser, err := New("", "", cli) + require.NoError(t, err) + return parser +} + +func TestArgument(t *testing.T) { + /* + app user create + app user delete + app user rename + + */ + var cli struct { + Create struct { + Id string `arg:"true"` + First string `arg:"true"` + Last string `arg:"true"` + } + + // Branching argument. + Id struct { + Id int `arg:"true"` + Flag int + Delete struct{} + Rename struct { + To string + } + } `arg:"true"` + } + p := mustNew(t, &cli) + repr.Println(p.Model, repr.Hide(reflect.Value{})) + cmd, err := p.Parse([]string{"10", "delete"}) + require.NoError(t, err) + require.Equal(t, 10, cli.Id.Id) + require.Equal(t, " delete", cmd) +} + +func TestResetWithDefaults(t *testing.T) { + var cli struct { + Flag string + FlagWithDefault string `default:"default"` + } + cli.Flag = "BLAH" + cli.FlagWithDefault = "BLAH" + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{}) + require.NoError(t, err) + require.Equal(t, "", cli.Flag) + require.Equal(t, "default", cli.FlagWithDefault) +} diff --git a/model.go b/model.go index d8a5647..285f3f0 100644 --- a/model.go +++ b/model.go @@ -1,44 +1,44 @@ package kong -import "flag" +import "reflect" -type Value = flag.Getter +type Application = Node -type ApplicationModel struct { - Name string - Description string - - NodeModel +// A Branch is a command or positional argument that results in a branch in the command tree. +type Branch struct { + Command *Command + Argument *Argument } -type NodeModel struct { - Groups []*GroupModel - // Positional arguments. - Arguments []*ArgumentModel +type Command = Node + +type Node struct { + Name string + Help string + Flags []*Flag + Positional []*Value + Children []*Branch } -type GroupModel struct { - // Flags. - Flags []*FlagModel - // Command hierarchy. - Commands []*CommandModel +type Value struct { + Name string + Help string + Decoder Decoder + Value reflect.Value + Required bool } -type ValueModel struct { - Name string - Help string - Value flag.Value +type Positional = Value + +type Argument struct { + Node + Argument *Value } -type CommandModel struct { - NodeModel -} - -type ArgumentModel struct { - ValueModel -} - -type FlagModel struct { - ValueModel - Short rune +type Flag struct { + Value + Placeholder string + Env string + Short rune + Default string } diff --git a/scanner.go b/scanner.go new file mode 100644 index 0000000..541016b --- /dev/null +++ b/scanner.go @@ -0,0 +1,110 @@ +package kong + +import ( + "fmt" + "strconv" +) + +//go:generate stringer -type=TokenType + +type TokenType int + +const ( + UntypedToken TokenType = iota + EOLToken + FlagToken // -- + FlagValueToken // = + ShortFlagToken // -[ + PositionalArgumentToken // +) + +type TokenAssertionError struct{ err error } + +func (t TokenAssertionError) Error() string { + return t.err.Error() +} + +type Token struct { + Value string + Type TokenType +} + +func (t Token) String() string { + switch t.Type { + case FlagToken: + return "--" + t.Value + + case ShortFlagToken: + return "-" + t.Value + + case EOLToken: + return "EOL" + + default: + return strconv.Quote(t.Value) + } +} + +func (t Token) IsAny(types ...TokenType) bool { + for _, typ := range types { + if t.Type == typ { + return true + } + } + return false +} + +func (t Token) IsValue() bool { + return t.IsAny(FlagValueToken, ShortFlagTailToken, PositionalArgumentToken, UntypedToken) +} + +type Scanner struct { + raw []string + args []Token +} + +func Scan(args ...string) *Scanner { + s := &Scanner{raw: args} + for _, arg := range args { + s.args = append(s.args, Token{Value: arg}) + } + return s +} + +func (s *Scanner) Pop() Token { + if len(s.args) == 0 { + return Token{Type: EOLToken} + } + arg := s.args[0] + s.args = s.args[1:] + return arg +} + +// PopValue or panic with TokenAssertionError. +func (s *Scanner) PopValue(context string) string { + t := s.Pop() + if !t.IsValue() { + panic(TokenAssertionError{fmt.Errorf("expected %s value but got %s", context, t)}) + } + return t.Value +} + +func (s *Scanner) Peek() Token { + if len(s.args) == 0 { + return Token{Type: EOLToken} + } + return s.args[0] +} + +func (s *Scanner) Push(arg string) { + s.PushToken(Token{Value: arg}) +} + +func (s *Scanner) PushTyped(arg string, typ TokenType) { + s.PushToken(Token{Value: arg, Type: typ}) +} + +func (s *Scanner) PushToken(token Token) { + s.args = append([]Token{token}, s.args...) +} diff --git a/scanner_test.go b/scanner_test.go new file mode 100644 index 0000000..3e37d38 --- /dev/null +++ b/scanner_test.go @@ -0,0 +1,26 @@ +package kong + +import ( + "testing" + + "github.com/gotestyourself/gotestyourself/assert" +) + +func TestScannerTake(t *testing.T) { + s := Scan("a", "b", "c") + assert.Assert(t, s.Pop().Value == "a") + assert.Assert(t, s.Pop().Value == "b") + assert.Assert(t, s.Pop().Value == "c") + assert.Assert(t, s.Pop().Type == EOLToken) +} + +func TestScannerPeek(t *testing.T) { + s := Scan("a", "b", "c") + assert.Assert(t, s.Peek().Value == "a") + assert.Assert(t, s.Pop().Value == "a") + assert.Assert(t, s.Peek().Value == "b") + assert.Assert(t, s.Pop().Value == "b") + assert.Assert(t, s.Peek().Value == "c") + assert.Assert(t, s.Pop().Value == "c") + assert.Assert(t, s.Peek().Type == EOLToken) +} diff --git a/tokentype_string.go b/tokentype_string.go new file mode 100644 index 0000000..7b68434 --- /dev/null +++ b/tokentype_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type=TokenType"; DO NOT EDIT. + +package kong + +import "strconv" + +const _TokenType_name = "UntypedTokenEOLTokenFlagTokenFlagValueTokenShortFlagTokenShortFlagTailTokenPositionalArgumentToken" + +var _TokenType_index = [...]uint8{0, 12, 20, 29, 43, 57, 75, 98} + +func (i TokenType) String() string { + if i < 0 || i >= TokenType(len(_TokenType_index)-1) { + return "TokenType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TokenType_name[_TokenType_index[i]:_TokenType_index[i+1]] +}