From 8e96da517db24a7ce0bbe57f0834a30645bd78fd Mon Sep 17 00:00:00 2001 From: Gerald Kaszuba Date: Sat, 19 May 2018 20:54:26 +1000 Subject: [PATCH] Fixes #3 Required and optional flags+args (#6) --- build.go | 34 ++++++++++++++++---- kong.go | 85 +++++++++++++++++++++++++++++++++++++++--------- kong_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++ model.go | 1 + 4 files changed, 188 insertions(+), 23 deletions(-) diff --git a/build.go b/build.go index 656c452..03aae9f 100644 --- a/build.go +++ b/build.go @@ -103,19 +103,24 @@ func buildNode(v reflect.Value, cmd bool) *Node { if decoder == nil { fail("no decoder for %s.%s (of type %s)", v.Type(), ft.Name, ft.Type) } + + flag := !arg + value := Value{ - Name: name, - Help: help, - Decoder: decoder, - Value: fv, - Field: ft, - Required: !optional || required, + Name: name, + Flag: flag, + Help: help, + Decoder: decoder, + Value: fv, + Field: ft, + + // Flags are optional by default, and args are required by default. + Required: (flag && required) || (arg && !optional), Format: format, } if arg { node.Positional = append(node.Positional, &value) } else { - value.Flag = true node.Flags = append(node.Flags, &Flag{ Value: value, Short: short, @@ -126,5 +131,20 @@ func buildNode(v reflect.Value, cmd bool) *Node { } } } + + // Scan through argument positionals to ensure optional is never before a required + last := true + for _, p := range node.Positional { + if p.Flag { + continue + } + + if !last && p.Required { + fail("arguments can not be required after an optional: %v", p.Name) + } + + last = p.Required + } + return node } diff --git a/kong.go b/kong.go index b150b8e..f6c6b91 100644 --- a/kong.go +++ b/kong.go @@ -201,27 +201,79 @@ func (k *Kong) applyNode(scan *Scanner, node *Node, flags []*Flag) (command []st return nil, fmt.Errorf("unexpected token %s", token) } } - if positional < len(node.Positional) { - missing := []string{} - for ; positional < len(node.Positional); positional++ { - missing = append(missing, "<"+node.Positional[positional].Name+">") - } - return nil, fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) + + if err := checkMissingPositionals(positional, node.Positional); err != nil { + return nil, err } - if len(node.Children) > 0 { - missing := []string{} - for _, child := range node.Children { - if child.Argument != nil { - missing = append(missing, "<"+child.Argument.Name+">") - } else { - missing = append(missing, child.Command.Name) - } - } - return nil, fmt.Errorf("expected one of %s", strings.Join(missing, ", ")) + + if err := checkMissingChildren(node.Children); err != nil { + return nil, err } + + if err := chickMissingFlags(node.Children, flags); err != nil { + return nil, err + } + return } +func chickMissingFlags(children []*Branch, flags []*Flag) error { + // Only check required missing fields at the last child. + if len(children) > 0 { + return nil + } + missing := []string{} + for _, flag := range flags { + if !flag.Required || flag.Set { + continue + } + missing = append(missing, flag.Name) + } + if len(missing) == 0 { + return nil + } + + return fmt.Errorf("missing flags: %s", strings.Join(missing, ", ")) +} + +func checkMissingChildren(children []*Branch) error { + missing := []string{} + for _, child := range children { + if child.Argument != nil { + if !child.Argument.Argument.Required { + continue + } + missing = append(missing, "<"+child.Argument.Name+">") + } else { + missing = append(missing, child.Command.Name) + } + } + if len(missing) == 0 { + return nil + } + + return fmt.Errorf("expected one of %s", strings.Join(missing, ", ")) +} + +// If we're missing any positionals and they're required, return an error. +func checkMissingPositionals(positional int, values []*Value) error { + // All the positionals are in. + if positional == len(values) { + return nil + } + + // We're low on supplied positionals, but the missing one is optional. + if !values[positional].Required { + return nil + } + + missing := []string{} + for ; positional < len(values); positional++ { + missing = append(missing, "<"+values[positional].Name+">") + } + return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) +} + func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) bool) (err error) { defer func() { msg := recover() @@ -238,6 +290,7 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) if err != nil { return err } + flag.Set = true return nil } } diff --git a/kong_test.go b/kong_test.go index 35018f3..066453a 100644 --- a/kong_test.go +++ b/kong_test.go @@ -150,3 +150,94 @@ func TestPropagatedFlags(t *testing.T) { require.Equal(t, "moo", cli.Flag1) require.Equal(t, true, cli.Command1.Flag2) } + +func TestRequiredFlag(t *testing.T) { + var cli struct { + Flag string `required:""` + } + + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{}) + require.Error(t, err) +} + +func TestOptionalArg(t *testing.T) { + var cli struct { + Arg string `arg:"" optional:""` + } + + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{}) + require.NoError(t, err) +} + +func TestRequiredArg(t *testing.T) { + var cli struct { + Arg string `arg:""` + } + + parser := mustNew(t, &cli) + _, err := parser.Parse([]string{}) + require.Error(t, err) +} + +func TestInvalidRequiredAfterOptional(t *testing.T) { + var cli struct { + ID int `arg:"" optional:""` + Name string `arg:""` + } + + _, err := New(&cli) + require.Error(t, err) +} + +func TestOptionalStructArg(t *testing.T) { + var cli struct { + Name struct { + Name string `arg:"" optional:""` + Enabled bool + } `arg:"" optional:""` + } + + parser := mustNew(t, &cli) + + t.Run("WithFlag", func(t *testing.T) { + _, err := parser.Parse([]string{"gak", "--enabled"}) + require.NoError(t, err) + require.Equal(t, "gak", cli.Name.Name) + require.Equal(t, true, cli.Name.Enabled) + }) + + t.Run("WithoutFlag", func(t *testing.T) { + _, err := parser.Parse([]string{"gak"}) + require.NoError(t, err) + require.Equal(t, "gak", cli.Name.Name) + }) + + t.Run("WithNothing", func(t *testing.T) { + _, err := parser.Parse([]string{}) + require.NoError(t, err) + }) +} + +func TestMixedRequiredArgs(t *testing.T) { + var cli struct { + Name string `arg:""` + ID int `arg:"" optional:""` + } + + parser := mustNew(t, &cli) + + t.Run("SingleRequired", func(t *testing.T) { + _, err := parser.Parse([]string{"gak", "5"}) + require.NoError(t, err) + require.Equal(t, "gak", cli.Name) + require.Equal(t, 5, cli.ID) + }) + + t.Run("ExtraOptional", func(t *testing.T) { + _, err := parser.Parse([]string{"gak"}) + require.NoError(t, err) + require.Equal(t, "gak", cli.Name) + }) +} diff --git a/model.go b/model.go index e7beda6..8d0bd2e 100644 --- a/model.go +++ b/model.go @@ -29,6 +29,7 @@ type Value struct { Field reflect.StructField Value reflect.Value Required bool + Set bool // Used with Required to test if a value has been given. Format string // Formatting directive, if applicable. }