diff --git a/context.go b/context.go index 9ca1f31..d4761dd 100644 --- a/context.go +++ b/context.go @@ -60,8 +60,8 @@ type Context struct { // // The returned Context will include a Path of all commands, arguments, positionals and flags. // -// This just constructs a new trace. To fully apply the trace you must call Resolve(), Validate() and -// Apply(). +// This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(), +// Validate() and Apply(). func Trace(k *Kong, args []string) (*Context, error) { c := &Context{ Kong: k, @@ -74,11 +74,6 @@ func Trace(k *Kong, args []string) (*Context, error) { bindings: bindings{}, } c.Error = c.trace(c.Model.Node) - err := c.reset(c.Model.Node) - if err != nil { - return nil, err - } - return c, nil } @@ -236,9 +231,9 @@ func (c *Context) FlagValue(flag *Flag) interface{} { return flag.DefaultValue.Interface() } -// Recursively reset values to defaults (as specified in the grammar) or the zero value. -func (c *Context) reset(node *Node) error { - return Visit(node, func(node Visitable, next Next) error { +// Reset recursively resets values to defaults (as specified in the grammar) or the zero value. +func (c *Context) Reset() error { + return Visit(c.Model.Node, func(node Visitable, next Next) error { if value, ok := node.(*Value); ok { return next(value.Reset()) } @@ -441,6 +436,28 @@ func (c *Context) getValue(value *Value) reflect.Value { return v } +// ApplyDefaults if they are not already set. +func (c *Context) ApplyDefaults() error { + return Visit(c.Model.Node, func(node Visitable, next Next) error { + var value *Value + switch node := node.(type) { + case *Flag: + value = node.Value + case *Node: + value = node.Argument + case *Value: + value = node + default: + } + if value != nil { + if err := value.ApplyDefault(); err != nil { + return err + } + } + return next(nil) + }) +} + // Apply traced context to the target grammar. func (c *Context) Apply() (string, error) { path := []string{} diff --git a/defaults.go b/defaults.go index 1c9fcc7..a4a314e 100644 --- a/defaults.go +++ b/defaults.go @@ -1,11 +1,22 @@ package kong -// ApplyDefaults applies defaults to a struct. +// ApplyDefaults if they are not already set. func ApplyDefaults(target interface{}, options ...Option) error { app, err := New(target, options...) if err != nil { return err } - _, err = app.Parse(nil) - return err + ctx, err := Trace(app, nil) + if err != nil { + return err + } + err = ctx.Resolve() + if err != nil { + return err + } + err = ctx.Validate() + if err != nil { + return err + } + return ctx.ApplyDefaults() } diff --git a/defaults_test.go b/defaults_test.go index 9481a05..11a2612 100644 --- a/defaults_test.go +++ b/defaults_test.go @@ -12,8 +12,23 @@ func TestApplyDefaults(t *testing.T) { Str string `default:"str"` Duration time.Duration `default:"30s"` } - cli := &CLI{} - err := ApplyDefaults(cli) - require.NoError(t, err) - require.Equal(t, &CLI{Str: "str", Duration: time.Second * 30}, cli) + tests := []struct { + name string + target CLI + expected CLI + }{ + {name: "DefaultsWhenNotSet", + expected: CLI{Str: "str", Duration: time.Second * 30}}, + {name: "PartiallySetDefaults", + target: CLI{Duration: time.Second}, + expected: CLI{Str: "str", Duration: time.Second}}, + } + for _, tt := range tests { + // nolint: scopelint + t.Run(tt.name, func(t *testing.T) { + err := ApplyDefaults(&tt.target) + require.NoError(t, err) + require.Equal(t, tt.expected, tt.target) + }) + } } diff --git a/kong.go b/kong.go index b4350cd..d1039d2 100644 --- a/kong.go +++ b/kong.go @@ -197,6 +197,9 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) { if ctx.Error != nil { return nil, &ParseError{error: ctx.Error, Context: ctx} } + if err = ctx.Reset(); err != nil { + return nil, &ParseError{error: err, Context: ctx} + } if err = k.applyHook(ctx, "BeforeResolve"); err != nil { return nil, &ParseError{error: err, Context: ctx} } diff --git a/model.go b/model.go index b2fd39c..5e72723 100644 --- a/model.go +++ b/model.go @@ -2,6 +2,7 @@ package kong import ( "fmt" + "math" "os" "reflect" "strconv" @@ -308,6 +309,15 @@ func (v *Value) Apply(value reflect.Value) { v.Set = true } +// ApplyDefault value to field if it is not already set. +func (v *Value) ApplyDefault() error { + if reflectValueIsZero(v.Target) { + return v.Reset() + } + v.Set = true + return nil +} + // Reset this value to its default, either the zero value or the parsed result of its envar, // or its "default" tag. // @@ -376,3 +386,42 @@ func (f *Flag) FormatPlaceHolder() string { } return strings.ToUpper(f.Name) + tail } + +// This is directly from the Go 1.13 source code. +func reflectValueIsZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return math.Float64bits(v.Float()) == 0 + case reflect.Complex64, reflect.Complex128: + c := v.Complex() + return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0 + case reflect.Array: + for i := 0; i < v.Len(); i++ { + if !reflectValueIsZero(v.Index(i)) { + return false + } + } + return true + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + return v.IsNil() + case reflect.String: + return v.Len() == 0 + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + if !reflectValueIsZero(v.Field(i)) { + return false + } + } + return true + default: + // This should never happens, but will act as a safeguard for + // later, as a default value doesn't makes sense here. + panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()}) + } +} diff --git a/visit.go b/visit.go index d00bc27..f7dab53 100644 --- a/visit.go +++ b/visit.go @@ -5,7 +5,9 @@ import ( ) // Next should be called by Visitor to proceed with the walk. -type Next func(error) error +// +// The walk will terminate if "err" is non-nil. +type Next func(err error) error // Visitor can be used to walk all nodes in the model. type Visitor func(node Visitable, next Next) error