From 5d7703774f4e8a9a4cb7c3d031e66771125b8b20 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sun, 23 Sep 2018 19:59:25 +1000 Subject: [PATCH] Validate enums (finally). --- context.go | 11 +++++++++++ kong_test.go | 8 ++++++++ model.go | 5 +++++ 3 files changed, 24 insertions(+) diff --git a/context.go b/context.go index fca1e81..7f67199 100644 --- a/context.go +++ b/context.go @@ -119,6 +119,17 @@ func (c *Context) Empty() bool { // Validate the current context. func (c *Context) Validate() error { + err := Visit(c.Model, func(node Visitable, next Next) error { + if value, ok := node.(*Value); ok { + if value.Enum != "" && !value.EnumMap()[fmt.Sprintf("%v", value.Target.Interface())] { + return fmt.Errorf("%s must be one of %s but got %q", value.Summary(), value.Enum, value.Target.Interface()) + } + } + return next(nil) + }) + if err != nil { + return err + } for _, resolver := range c.combineResolvers() { if err := resolver.Validate(c.Model); err != nil { return err diff --git a/kong_test.go b/kong_test.go index b3e8d16..5278c54 100644 --- a/kong_test.go +++ b/kong_test.go @@ -654,3 +654,11 @@ func TestHooksCalledForDefault(t *testing.T) { require.Equal(t, "default", string(cli.Flag)) require.Equal(t, []string{"before:default", "after:default"}, ctx.values) } + +func TestEnum(t *testing.T) { + var cli struct { + Flag string `enum:"a,b,c"` + } + _, err := mustNew(t, &cli).Parse([]string{"--flag", "d"}) + require.EqualError(t, err, "--flag=STRING must be one of a,b,c but got \"\"") +} diff --git a/model.go b/model.go index 5b90741..6256379 100644 --- a/model.go +++ b/model.go @@ -269,6 +269,11 @@ func (v *Value) IsBool() bool { // Parse tokens into value, parse, and validate, but do not write to the field. func (v *Value) Parse(scan *Scanner, target reflect.Value) error { + defer func() { + if err := recover(); err != nil { + panic(fmt.Sprintf("mapper %T failed to apply to %s: %s", v.Mapper, v.Summary(), err)) + } + }() err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target) if err != nil { return fmt.Errorf("%s: %s", v.Summary(), err)