ApplyDefaults() now only applies defaults if the value is not already otherwise set.

This commit is contained in:
Alec Thomas
2019-06-12 13:25:33 +10:00
parent 0d256bb68a
commit 7be398e79f
6 changed files with 115 additions and 18 deletions
+27 -10
View File
@@ -60,8 +60,8 @@ type Context struct {
// //
// The returned Context will include a Path of all commands, arguments, positionals and flags. // The returned Context will include a Path of all commands, arguments, positionals and flags.
// //
// This just constructs a new trace. To fully apply the trace you must call Resolve(), Validate() and // This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(),
// Apply(). // Validate() and Apply().
func Trace(k *Kong, args []string) (*Context, error) { func Trace(k *Kong, args []string) (*Context, error) {
c := &Context{ c := &Context{
Kong: k, Kong: k,
@@ -74,11 +74,6 @@ func Trace(k *Kong, args []string) (*Context, error) {
bindings: bindings{}, bindings: bindings{},
} }
c.Error = c.trace(c.Model.Node) c.Error = c.trace(c.Model.Node)
err := c.reset(c.Model.Node)
if err != nil {
return nil, err
}
return c, nil return c, nil
} }
@@ -236,9 +231,9 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
return flag.DefaultValue.Interface() return flag.DefaultValue.Interface()
} }
// Recursively reset values to defaults (as specified in the grammar) or the zero value. // Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
func (c *Context) reset(node *Node) error { func (c *Context) Reset() error {
return Visit(node, func(node Visitable, next Next) error { return Visit(c.Model.Node, func(node Visitable, next Next) error {
if value, ok := node.(*Value); ok { if value, ok := node.(*Value); ok {
return next(value.Reset()) return next(value.Reset())
} }
@@ -441,6 +436,28 @@ func (c *Context) getValue(value *Value) reflect.Value {
return v return v
} }
// ApplyDefaults if they are not already set.
func (c *Context) ApplyDefaults() error {
return Visit(c.Model.Node, func(node Visitable, next Next) error {
var value *Value
switch node := node.(type) {
case *Flag:
value = node.Value
case *Node:
value = node.Argument
case *Value:
value = node
default:
}
if value != nil {
if err := value.ApplyDefault(); err != nil {
return err
}
}
return next(nil)
})
}
// Apply traced context to the target grammar. // Apply traced context to the target grammar.
func (c *Context) Apply() (string, error) { func (c *Context) Apply() (string, error) {
path := []string{} path := []string{}
+14 -3
View File
@@ -1,11 +1,22 @@
package kong package kong
// ApplyDefaults applies defaults to a struct. // ApplyDefaults if they are not already set.
func ApplyDefaults(target interface{}, options ...Option) error { func ApplyDefaults(target interface{}, options ...Option) error {
app, err := New(target, options...) app, err := New(target, options...)
if err != nil { if err != nil {
return err return err
} }
_, err = app.Parse(nil) ctx, err := Trace(app, nil)
return err if err != nil {
return err
}
err = ctx.Resolve()
if err != nil {
return err
}
err = ctx.Validate()
if err != nil {
return err
}
return ctx.ApplyDefaults()
} }
+19 -4
View File
@@ -12,8 +12,23 @@ func TestApplyDefaults(t *testing.T) {
Str string `default:"str"` Str string `default:"str"`
Duration time.Duration `default:"30s"` Duration time.Duration `default:"30s"`
} }
cli := &CLI{} tests := []struct {
err := ApplyDefaults(cli) name string
require.NoError(t, err) target CLI
require.Equal(t, &CLI{Str: "str", Duration: time.Second * 30}, cli) expected CLI
}{
{name: "DefaultsWhenNotSet",
expected: CLI{Str: "str", Duration: time.Second * 30}},
{name: "PartiallySetDefaults",
target: CLI{Duration: time.Second},
expected: CLI{Str: "str", Duration: time.Second}},
}
for _, tt := range tests {
// nolint: scopelint
t.Run(tt.name, func(t *testing.T) {
err := ApplyDefaults(&tt.target)
require.NoError(t, err)
require.Equal(t, tt.expected, tt.target)
})
}
} }
+3
View File
@@ -197,6 +197,9 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) {
if ctx.Error != nil { if ctx.Error != nil {
return nil, &ParseError{error: ctx.Error, Context: ctx} return nil, &ParseError{error: ctx.Error, Context: ctx}
} }
if err = ctx.Reset(); err != nil {
return nil, &ParseError{error: err, Context: ctx}
}
if err = k.applyHook(ctx, "BeforeResolve"); err != nil { if err = k.applyHook(ctx, "BeforeResolve"); err != nil {
return nil, &ParseError{error: err, Context: ctx} return nil, &ParseError{error: err, Context: ctx}
} }
+49
View File
@@ -2,6 +2,7 @@ package kong
import ( import (
"fmt" "fmt"
"math"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@@ -308,6 +309,15 @@ func (v *Value) Apply(value reflect.Value) {
v.Set = true v.Set = true
} }
// ApplyDefault value to field if it is not already set.
func (v *Value) ApplyDefault() error {
if reflectValueIsZero(v.Target) {
return v.Reset()
}
v.Set = true
return nil
}
// Reset this value to its default, either the zero value or the parsed result of its envar, // Reset this value to its default, either the zero value or the parsed result of its envar,
// or its "default" tag. // or its "default" tag.
// //
@@ -376,3 +386,42 @@ func (f *Flag) FormatPlaceHolder() string {
} }
return strings.ToUpper(f.Name) + tail return strings.ToUpper(f.Name) + tail
} }
// This is directly from the Go 1.13 source code.
func reflectValueIsZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.Complex64, reflect.Complex128:
c := v.Complex()
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !reflectValueIsZero(v.Index(i)) {
return false
}
}
return true
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return v.IsNil()
case reflect.String:
return v.Len() == 0
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !reflectValueIsZero(v.Field(i)) {
return false
}
}
return true
default:
// This should never happens, but will act as a safeguard for
// later, as a default value doesn't makes sense here.
panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()})
}
}
+3 -1
View File
@@ -5,7 +5,9 @@ import (
) )
// Next should be called by Visitor to proceed with the walk. // Next should be called by Visitor to proceed with the walk.
type Next func(error) error //
// The walk will terminate if "err" is non-nil.
type Next func(err error) error
// Visitor can be used to walk all nodes in the model. // Visitor can be used to walk all nodes in the model.
type Visitor func(node Visitable, next Next) error type Visitor func(node Visitable, next Next) error