diff --git a/build.go b/build.go index f3838a7..4d0e2fa 100644 --- a/build.go +++ b/build.go @@ -211,7 +211,7 @@ func buildChild(k *Kong, node *Node, typ NodeType, v reflect.Value, ft reflect.S if child.Help == "" { child.Help = child.Argument.Help } - } else if tag.Default != "" { + } else if tag.HasDefault { if node.DefaultCmd != nil { return failField(v, ft, "can't have more than one default command under %s", node.Summary()) } @@ -239,6 +239,7 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv Name: name, Help: tag.Help, OrigHelp: tag.Help, + HasDefault: tag.HasDefault, Default: tag.Default, DefaultValue: reflect.New(fv.Type()).Elem(), Mapper: mapper, diff --git a/context.go b/context.go index 41a1ea7..6ad4dd0 100644 --- a/context.go +++ b/context.go @@ -167,7 +167,7 @@ func (c *Context) Validate() error { // nolint: gocyclo switch node := node.(type) { case *Value: _, ok := os.LookupEnv(node.Tag.Env) - if node.Enum != "" && (!node.Required || node.Default != "" || (node.Tag.Env != "" && ok)) { + if node.Enum != "" && (!node.Required || node.HasDefault || (node.Tag.Env != "" && ok)) { if err := checkEnum(node, node.Target); err != nil { return err } @@ -175,7 +175,7 @@ func (c *Context) Validate() error { // nolint: gocyclo case *Flag: _, ok := os.LookupEnv(node.Tag.Env) - if node.Enum != "" && (!node.Required || node.Default != "" || (node.Tag.Env != "" && ok)) { + if node.Enum != "" && (!node.Required || node.HasDefault || (node.Tag.Env != "" && ok)) { if err := checkEnum(node.Value, node.Target); err != nil { return err } diff --git a/kong.go b/kong.go index a4d230e..3123dc2 100644 --- a/kong.go +++ b/kong.go @@ -319,7 +319,7 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er } binds := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars)) for _, flag := range node.Flags { - if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { + if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { continue } method := getMethod(flag.Target, name) diff --git a/kong_test.go b/kong_test.go index a415946..0594565 100644 --- a/kong_test.go +++ b/kong_test.go @@ -1432,6 +1432,13 @@ func TestEnumValidation(t *testing.T) { }{}, false, }, + { + "EnumWithEmptyDefault", + &struct { + Flag string `enum:"one,two," default:""` + }{}, + false, + }, } for _, test := range tests { test := test diff --git a/model.go b/model.go index 5f7f80e..a86b510 100644 --- a/model.go +++ b/model.go @@ -232,6 +232,7 @@ type Value struct { Name string Help string OrigHelp string // Original help string, without interpolated variables. + HasDefault bool Default string DefaultValue reflect.Value Enum string @@ -357,7 +358,7 @@ func (v *Value) Reset() error { return nil } } - if v.Default != "" { + if v.HasDefault { return v.Parse(ScanFromTokens(Token{Type: FlagValueToken, Value: v.Default}), v.Target) } return nil @@ -404,7 +405,7 @@ func (f *Flag) FormatPlaceHolder() string { if f.PlaceHolder != "" { return f.PlaceHolder + tail } - if f.Default != "" { + if f.HasDefault { if f.Value.Target.Kind() == reflect.String { return strconv.Quote(f.Default) + tail } diff --git a/tag.go b/tag.go index 9f937d7..e4fae67 100644 --- a/tag.go +++ b/tag.go @@ -20,6 +20,7 @@ type Tag struct { Help string Type string TypeName string + HasDefault bool Default string Format string PlaceHolder string @@ -182,9 +183,10 @@ func hydrateTag(t *Tag, typ reflect.Type) error { // nolint: gocyclo } t.Required = required t.Optional = optional + t.HasDefault = t.Has("default") t.Default = t.Get("default") // Arguments with defaults are always optional. - if t.Arg && t.Default != "" { + if t.Arg && t.HasDefault { t.Optional = true } else if t.Arg && !optional { // Arguments are required unless explicitly made optional. t.Required = true @@ -229,7 +231,7 @@ func hydrateTag(t *Tag, typ reflect.Type) error { // nolint: gocyclo t.PlaceHolder = t.Get("placeholder") t.Enum = t.Get("enum") scalarType := (typ == nil || !(typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map)) - if t.Enum != "" && !(t.Required || t.Default != "") && scalarType { + if t.Enum != "" && !(t.Required || t.HasDefault) && scalarType { return fmt.Errorf("enum value is only valid if it is either required or has a valid default value") } passthrough := t.Has("passthrough")