ApplyDefaults() now only applies defaults if the value is not already otherwise set.
This commit is contained in:
+27
-10
@@ -60,8 +60,8 @@ type Context struct {
|
|||||||
//
|
//
|
||||||
// The returned Context will include a Path of all commands, arguments, positionals and flags.
|
// The returned Context will include a Path of all commands, arguments, positionals and flags.
|
||||||
//
|
//
|
||||||
// This just constructs a new trace. To fully apply the trace you must call Resolve(), Validate() and
|
// This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(),
|
||||||
// Apply().
|
// Validate() and Apply().
|
||||||
func Trace(k *Kong, args []string) (*Context, error) {
|
func Trace(k *Kong, args []string) (*Context, error) {
|
||||||
c := &Context{
|
c := &Context{
|
||||||
Kong: k,
|
Kong: k,
|
||||||
@@ -74,11 +74,6 @@ func Trace(k *Kong, args []string) (*Context, error) {
|
|||||||
bindings: bindings{},
|
bindings: bindings{},
|
||||||
}
|
}
|
||||||
c.Error = c.trace(c.Model.Node)
|
c.Error = c.trace(c.Model.Node)
|
||||||
err := c.reset(c.Model.Node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,9 +231,9 @@ func (c *Context) FlagValue(flag *Flag) interface{} {
|
|||||||
return flag.DefaultValue.Interface()
|
return flag.DefaultValue.Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recursively reset values to defaults (as specified in the grammar) or the zero value.
|
// Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
|
||||||
func (c *Context) reset(node *Node) error {
|
func (c *Context) Reset() error {
|
||||||
return Visit(node, func(node Visitable, next Next) error {
|
return Visit(c.Model.Node, func(node Visitable, next Next) error {
|
||||||
if value, ok := node.(*Value); ok {
|
if value, ok := node.(*Value); ok {
|
||||||
return next(value.Reset())
|
return next(value.Reset())
|
||||||
}
|
}
|
||||||
@@ -441,6 +436,28 @@ func (c *Context) getValue(value *Value) reflect.Value {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplyDefaults if they are not already set.
|
||||||
|
func (c *Context) ApplyDefaults() error {
|
||||||
|
return Visit(c.Model.Node, func(node Visitable, next Next) error {
|
||||||
|
var value *Value
|
||||||
|
switch node := node.(type) {
|
||||||
|
case *Flag:
|
||||||
|
value = node.Value
|
||||||
|
case *Node:
|
||||||
|
value = node.Argument
|
||||||
|
case *Value:
|
||||||
|
value = node
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
if err := value.ApplyDefault(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Apply traced context to the target grammar.
|
// Apply traced context to the target grammar.
|
||||||
func (c *Context) Apply() (string, error) {
|
func (c *Context) Apply() (string, error) {
|
||||||
path := []string{}
|
path := []string{}
|
||||||
|
|||||||
+14
-3
@@ -1,11 +1,22 @@
|
|||||||
package kong
|
package kong
|
||||||
|
|
||||||
// ApplyDefaults applies defaults to a struct.
|
// ApplyDefaults if they are not already set.
|
||||||
func ApplyDefaults(target interface{}, options ...Option) error {
|
func ApplyDefaults(target interface{}, options ...Option) error {
|
||||||
app, err := New(target, options...)
|
app, err := New(target, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = app.Parse(nil)
|
ctx, err := Trace(app, nil)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = ctx.Resolve()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = ctx.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ctx.ApplyDefaults()
|
||||||
}
|
}
|
||||||
|
|||||||
+19
-4
@@ -12,8 +12,23 @@ func TestApplyDefaults(t *testing.T) {
|
|||||||
Str string `default:"str"`
|
Str string `default:"str"`
|
||||||
Duration time.Duration `default:"30s"`
|
Duration time.Duration `default:"30s"`
|
||||||
}
|
}
|
||||||
cli := &CLI{}
|
tests := []struct {
|
||||||
err := ApplyDefaults(cli)
|
name string
|
||||||
require.NoError(t, err)
|
target CLI
|
||||||
require.Equal(t, &CLI{Str: "str", Duration: time.Second * 30}, cli)
|
expected CLI
|
||||||
|
}{
|
||||||
|
{name: "DefaultsWhenNotSet",
|
||||||
|
expected: CLI{Str: "str", Duration: time.Second * 30}},
|
||||||
|
{name: "PartiallySetDefaults",
|
||||||
|
target: CLI{Duration: time.Second},
|
||||||
|
expected: CLI{Str: "str", Duration: time.Second}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
// nolint: scopelint
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := ApplyDefaults(&tt.target)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.expected, tt.target)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ func (k *Kong) Parse(args []string) (ctx *Context, err error) {
|
|||||||
if ctx.Error != nil {
|
if ctx.Error != nil {
|
||||||
return nil, &ParseError{error: ctx.Error, Context: ctx}
|
return nil, &ParseError{error: ctx.Error, Context: ctx}
|
||||||
}
|
}
|
||||||
|
if err = ctx.Reset(); err != nil {
|
||||||
|
return nil, &ParseError{error: err, Context: ctx}
|
||||||
|
}
|
||||||
if err = k.applyHook(ctx, "BeforeResolve"); err != nil {
|
if err = k.applyHook(ctx, "BeforeResolve"); err != nil {
|
||||||
return nil, &ParseError{error: err, Context: ctx}
|
return nil, &ParseError{error: err, Context: ctx}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package kong
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -308,6 +309,15 @@ func (v *Value) Apply(value reflect.Value) {
|
|||||||
v.Set = true
|
v.Set = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplyDefault value to field if it is not already set.
|
||||||
|
func (v *Value) ApplyDefault() error {
|
||||||
|
if reflectValueIsZero(v.Target) {
|
||||||
|
return v.Reset()
|
||||||
|
}
|
||||||
|
v.Set = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Reset this value to its default, either the zero value or the parsed result of its envar,
|
// Reset this value to its default, either the zero value or the parsed result of its envar,
|
||||||
// or its "default" tag.
|
// or its "default" tag.
|
||||||
//
|
//
|
||||||
@@ -376,3 +386,42 @@ func (f *Flag) FormatPlaceHolder() string {
|
|||||||
}
|
}
|
||||||
return strings.ToUpper(f.Name) + tail
|
return strings.ToUpper(f.Name) + tail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is directly from the Go 1.13 source code.
|
||||||
|
func reflectValueIsZero(v reflect.Value) bool {
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return !v.Bool()
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return v.Int() == 0
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
return v.Uint() == 0
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return math.Float64bits(v.Float()) == 0
|
||||||
|
case reflect.Complex64, reflect.Complex128:
|
||||||
|
c := v.Complex()
|
||||||
|
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
|
||||||
|
case reflect.Array:
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
if !reflectValueIsZero(v.Index(i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
|
||||||
|
return v.IsNil()
|
||||||
|
case reflect.String:
|
||||||
|
return v.Len() == 0
|
||||||
|
case reflect.Struct:
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
if !reflectValueIsZero(v.Field(i)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
// This should never happens, but will act as a safeguard for
|
||||||
|
// later, as a default value doesn't makes sense here.
|
||||||
|
panic(&reflect.ValueError{"reflect.Value.IsZero", v.Kind()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Next should be called by Visitor to proceed with the walk.
|
// Next should be called by Visitor to proceed with the walk.
|
||||||
type Next func(error) error
|
//
|
||||||
|
// The walk will terminate if "err" is non-nil.
|
||||||
|
type Next func(err error) error
|
||||||
|
|
||||||
// Visitor can be used to walk all nodes in the model.
|
// Visitor can be used to walk all nodes in the model.
|
||||||
type Visitor func(node Visitable, next Next) error
|
type Visitor func(node Visitable, next Next) error
|
||||||
|
|||||||
Reference in New Issue
Block a user