From 77a613fb8b3b8afefb4363fdd9ec362a4a10c975 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Fri, 11 Oct 2019 16:55:10 +1100 Subject: [PATCH] Much more thorough checking of enum values. --- context.go | 37 +++++++++++++++++++++++++++++++------ kong_test.go | 10 ++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/context.go b/context.go index 1b11f20..7279f63 100644 --- a/context.go +++ b/context.go @@ -135,13 +135,10 @@ func (c *Context) Empty() bool { func (c *Context) Validate() error { err := Visit(c.Model, func(node Visitable, next Next) error { if value, ok := node.(*Value); ok { - if value.Enum != "" && !value.EnumMap()[fmt.Sprintf("%v", value.Target.Interface())] { - enums := []string{} - for enum := range value.EnumMap() { - enums = append(enums, fmt.Sprintf("%q", enum)) + if value.Enum != "" { + if err := checkEnum(value, value.Target); err != nil { + return err } - sort.Strings(enums) - return fmt.Errorf("%s must be one of %s but got %q", value.ShortSummary(), strings.Join(enums, ","), value.Target.Interface()) } } return next(nil) @@ -646,6 +643,34 @@ func checkMissingPositionals(positional int, values []*Value) error { return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } +func checkEnum(value *Value, target reflect.Value) error { + switch target.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < target.Len(); i++ { + if err := checkEnum(value, target.Index(i)); err != nil { + return err + } + } + return nil + + case reflect.Map, reflect.Struct: + return errors.Errorf("enum can only be applied to a slice or value") + + default: + enumMap := value.EnumMap() + v := fmt.Sprintf("%v", target) + if enumMap[v] { + return nil + } + enums := []string{} + for enum := range enumMap { + enums = append(enums, fmt.Sprintf("%q", enum)) + } + sort.Strings(enums) + return fmt.Errorf("%s must be one of %s but got %q", value.ShortSummary(), strings.Join(enums, ","), target.Interface()) + } +} + func checkXorDuplicates(paths []*Path) error { for _, path := range paths { seen := map[string]*Flag{} diff --git a/kong_test.go b/kong_test.go index 03fd2cb..b6fcc5f 100644 --- a/kong_test.go +++ b/kong_test.go @@ -781,3 +781,13 @@ func TestXorChild(t *testing.T) { _, err = p.Parse([]string{"--two=hi", "cmd", "--three"}) require.Error(t, err, "--two and --three can't be used together") } + +func TestEnumSequence(t *testing.T) { + var cli struct { + State []string `enum:"a,b,c" default:"a"` + } + p := mustNew(t, &cli) + _, err := p.Parse(nil) + require.NoError(t, err) + require.Equal(t, []string{"a"}, cli.State) +}