ApplyDefaults() now only applies defaults if the value is not already otherwise set.

This commit is contained in:
Alec Thomas
2019-06-12 13:25:33 +10:00
parent 0d256bb68a
commit 7be398e79f
6 changed files with 115 additions and 18 deletions
+27 -10
View File
@@ -60,8 +60,8 @@ type Context struct {
//
// The returned Context will include a Path of all commands, arguments, positionals and flags.
//
// This just constructs a new trace. To fully apply the trace you must call Resolve(), Validate() and
// Apply().
// This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(),
// Validate() and Apply().
func Trace(k *Kong, args []string) (*Context, error) {
c := &Context{
Kong: k,
@@ -74,11 +74,6 @@ func Trace(k *Kong, args []string) (*Context, error) {
bindings: bindings{},
}
c.Error = c.trace(c.Model.Node)
err := c.reset(c.Model.Node)
if err != nil {
return nil, err
}
return c, nil
}
@@ -236,9 +231,9 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
return flag.DefaultValue.Interface()
}
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
func (c *Context) reset(node *Node) error {
return Visit(node, func(node Visitable, next Next) error {
// Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
func (c *Context) Reset() error {
return Visit(c.Model.Node, func(node Visitable, next Next) error {
if value, ok := node.(*Value); ok {
return next(value.Reset())
}
@@ -441,6 +436,28 @@ func (c *Context) getValue(value *Value) reflect.Value {
return v
}
// ApplyDefaults if they are not already set.
func (c *Context) ApplyDefaults() error {
return Visit(c.Model.Node, func(node Visitable, next Next) error {
var value *Value
switch node := node.(type) {
case *Flag:
value = node.Value
case *Node:
value = node.Argument
case *Value:
value = node
default:
}
if value != nil {
if err := value.ApplyDefault(); err != nil {
return err
}
}
return next(nil)
})
}
// Apply traced context to the target grammar.
func (c *Context) Apply() (string, error) {
path := []string{}
+14 -3
View File
@@ -1,11 +1,22 @@
package kong
// ApplyDefaults applies defaults to a struct.
// ApplyDefaults if they are not already set.
func ApplyDefaults(target interface{}, options ...Option) error {
app, err := New(target, options...)
if err != nil {
return err
}
_, err = app.Parse(nil)
return err
ctx, err := Trace(app, nil)
if err != nil {
return err
}
err = ctx.Resolve()
if err != nil {
return err
}
err = ctx.Validate()
if err != nil {
return err
}
return ctx.ApplyDefaults()
}
+19 -4
View File
@@ -12,8 +12,23 @@ func TestApplyDefaults(t *testing.T) {
Str string `default:"str"`
Duration time.Duration `default:"30s"`
}
cli := &CLI{}
err := ApplyDefaults(cli)
require.NoError(t, err)
require.Equal(t, &CLI{Str: "str", Duration: time.Second * 30}, cli)
tests := []struct {
name string
target CLI
expected CLI
}{
{name: "DefaultsWhenNotSet",
expected: CLI{Str: "str", Duration: time.Second * 30}},
{name: "PartiallySetDefaults",
target: CLI{Duration: time.Second},
expected: CLI{Str: "str", Duration: time.Second}},
}
for _, tt := range tests {
// nolint: scopelint
t.Run(tt.name, func(t *testing.T) {
err := ApplyDefaults(&tt.target)
require.NoError(t, err)
require.Equal(t, tt.expected, tt.target)
})
}
}
+3
View File
@@ -197,6 +197,9 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) {
if ctx.Error != nil {
return nil, &ParseError{error: ctx.Error, Context: ctx}
}
if err = ctx.Reset(); err != nil {
return nil, &ParseError{error: err, Context: ctx}
}
if err = k.applyHook(ctx, "BeforeResolve"); err != nil {
return nil, &ParseError{error: err, Context: ctx}
}
+49
View File
@@ -2,6 +2,7 @@ package kong
import (
"fmt"
"math"
"os"
"reflect"
"strconv"
@@ -308,6 +309,15 @@ func (v *Value) Apply(value reflect.Value) {
v.Set = true
}
// ApplyDefault value to field if it is not already set.
func (v *Value) ApplyDefault() error {
if reflectValueIsZero(v.Target) {
return v.Reset()
}
v.Set = true
return nil
}
// Reset this value to its default, either the zero value or the parsed result of its envar,
// or its "default" tag.
//
@@ -376,3 +386,42 @@ func (f *Flag) FormatPlaceHolder() string {
}
return strings.ToUpper(f.Name) + tail
}
// This is directly from the Go 1.13 source code.
func reflectValueIsZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.Complex64, reflect.Complex128:
c := v.Complex()
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !reflectValueIsZero(v.Index(i)) {
return false
}
}
return true
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return v.IsNil()
case reflect.String:
return v.Len() == 0
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !reflectValueIsZero(v.Field(i)) {
return false
}
}
return true
default:
// This should never happens, but will act as a safeguard for
// later, as a default value doesn't makes sense here.
panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()})
}
}
+3 -1
View File
@@ -5,7 +5,9 @@ import (
)
// Next should be called by Visitor to proceed with the walk.
type Next func(error) error
//
// The walk will terminate if "err" is non-nil.
type Next func(err error) error
// Visitor can be used to walk all nodes in the model.
type Visitor func(node Visitable, next Next) error