ApplyDefaults() now only applies defaults if the value is not already otherwise set.
This commit is contained in:
+27
-10
@@ -60,8 +60,8 @@ type Context struct {
|
||||
//
|
||||
// The returned Context will include a Path of all commands, arguments, positionals and flags.
|
||||
//
|
||||
// This just constructs a new trace. To fully apply the trace you must call Resolve(), Validate() and
|
||||
// Apply().
|
||||
// This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(),
|
||||
// Validate() and Apply().
|
||||
func Trace(k *Kong, args []string) (*Context, error) {
|
||||
c := &Context{
|
||||
Kong: k,
|
||||
@@ -74,11 +74,6 @@ func Trace(k *Kong, args []string) (*Context, error) {
|
||||
bindings: bindings{},
|
||||
}
|
||||
c.Error = c.trace(c.Model.Node)
|
||||
err := c.reset(c.Model.Node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -236,9 +231,9 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
|
||||
return flag.DefaultValue.Interface()
|
||||
}
|
||||
|
||||
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
|
||||
func (c *Context) reset(node *Node) error {
|
||||
return Visit(node, func(node Visitable, next Next) error {
|
||||
// Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
|
||||
func (c *Context) Reset() error {
|
||||
return Visit(c.Model.Node, func(node Visitable, next Next) error {
|
||||
if value, ok := node.(*Value); ok {
|
||||
return next(value.Reset())
|
||||
}
|
||||
@@ -441,6 +436,28 @@ func (c *Context) getValue(value *Value) reflect.Value {
|
||||
return v
|
||||
}
|
||||
|
||||
// ApplyDefaults if they are not already set.
|
||||
func (c *Context) ApplyDefaults() error {
|
||||
return Visit(c.Model.Node, func(node Visitable, next Next) error {
|
||||
var value *Value
|
||||
switch node := node.(type) {
|
||||
case *Flag:
|
||||
value = node.Value
|
||||
case *Node:
|
||||
value = node.Argument
|
||||
case *Value:
|
||||
value = node
|
||||
default:
|
||||
}
|
||||
if value != nil {
|
||||
if err := value.ApplyDefault(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return next(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// Apply traced context to the target grammar.
|
||||
func (c *Context) Apply() (string, error) {
|
||||
path := []string{}
|
||||
|
||||
+14
-3
@@ -1,11 +1,22 @@
|
||||
package kong
|
||||
|
||||
// ApplyDefaults applies defaults to a struct.
|
||||
// ApplyDefaults if they are not already set.
|
||||
func ApplyDefaults(target interface{}, options ...Option) error {
|
||||
app, err := New(target, options...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = app.Parse(nil)
|
||||
return err
|
||||
ctx, err := Trace(app, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ctx.Resolve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ctx.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.ApplyDefaults()
|
||||
}
|
||||
|
||||
+19
-4
@@ -12,8 +12,23 @@ func TestApplyDefaults(t *testing.T) {
|
||||
Str string `default:"str"`
|
||||
Duration time.Duration `default:"30s"`
|
||||
}
|
||||
cli := &CLI{}
|
||||
err := ApplyDefaults(cli)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &CLI{Str: "str", Duration: time.Second * 30}, cli)
|
||||
tests := []struct {
|
||||
name string
|
||||
target CLI
|
||||
expected CLI
|
||||
}{
|
||||
{name: "DefaultsWhenNotSet",
|
||||
expected: CLI{Str: "str", Duration: time.Second * 30}},
|
||||
{name: "PartiallySetDefaults",
|
||||
target: CLI{Duration: time.Second},
|
||||
expected: CLI{Str: "str", Duration: time.Second}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
// nolint: scopelint
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ApplyDefaults(&tt.target)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, tt.target)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,6 +197,9 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) {
|
||||
if ctx.Error != nil {
|
||||
return nil, &ParseError{error: ctx.Error, Context: ctx}
|
||||
}
|
||||
if err = ctx.Reset(); err != nil {
|
||||
return nil, &ParseError{error: err, Context: ctx}
|
||||
}
|
||||
if err = k.applyHook(ctx, "BeforeResolve"); err != nil {
|
||||
return nil, &ParseError{error: err, Context: ctx}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package kong
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -308,6 +309,15 @@ func (v *Value) Apply(value reflect.Value) {
|
||||
v.Set = true
|
||||
}
|
||||
|
||||
// ApplyDefault value to field if it is not already set.
|
||||
func (v *Value) ApplyDefault() error {
|
||||
if reflectValueIsZero(v.Target) {
|
||||
return v.Reset()
|
||||
}
|
||||
v.Set = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset this value to its default, either the zero value or the parsed result of its envar,
|
||||
// or its "default" tag.
|
||||
//
|
||||
@@ -376,3 +386,42 @@ func (f *Flag) FormatPlaceHolder() string {
|
||||
}
|
||||
return strings.ToUpper(f.Name) + tail
|
||||
}
|
||||
|
||||
// This is directly from the Go 1.13 source code.
|
||||
func reflectValueIsZero(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return math.Float64bits(v.Float()) == 0
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
c := v.Complex()
|
||||
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
|
||||
case reflect.Array:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !reflectValueIsZero(v.Index(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
|
||||
return v.IsNil()
|
||||
case reflect.String:
|
||||
return v.Len() == 0
|
||||
case reflect.Struct:
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !reflectValueIsZero(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
// This should never happens, but will act as a safeguard for
|
||||
// later, as a default value doesn't makes sense here.
|
||||
panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
)
|
||||
|
||||
// Next should be called by Visitor to proceed with the walk.
|
||||
type Next func(error) error
|
||||
//
|
||||
// The walk will terminate if "err" is non-nil.
|
||||
type Next func(err error) error
|
||||
|
||||
// Visitor can be used to walk all nodes in the model.
|
||||
type Visitor func(node Visitable, next Next) error
|
||||
|
||||
Reference in New Issue
Block a user