Move .Set = true into Decode().
This commit is contained in:
@@ -110,6 +110,7 @@ func buildNode(v reflect.Value, cmd bool) *Node {
|
||||
Name: name,
|
||||
Flag: flag,
|
||||
Help: help,
|
||||
Default: dflt,
|
||||
Decoder: decoder,
|
||||
Value: fv,
|
||||
Field: ft,
|
||||
@@ -124,7 +125,6 @@ func buildNode(v reflect.Value, cmd bool) *Node {
|
||||
node.Flags = append(node.Flags, &Flag{
|
||||
Value: value,
|
||||
Short: short,
|
||||
Default: dflt,
|
||||
Placeholder: placeholder,
|
||||
Env: env,
|
||||
})
|
||||
@@ -135,12 +135,8 @@ func buildNode(v reflect.Value, cmd bool) *Node {
|
||||
// Scan through argument positionals to ensure optional is never before a required
|
||||
last := true
|
||||
for _, p := range node.Positional {
|
||||
if p.Flag {
|
||||
continue
|
||||
}
|
||||
|
||||
if !last && p.Required {
|
||||
fail("arguments can not be required after an optional: %v", p.Name)
|
||||
fail("argument %q can not be required after an optional", p.Name)
|
||||
}
|
||||
|
||||
last = p.Required
|
||||
|
||||
@@ -75,10 +75,15 @@ func (k *Kong) reset(node *Node) {
|
||||
flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type()))
|
||||
if flag.Default != "" {
|
||||
flag.Decode(Scan(flag.Default))
|
||||
flag.Set = false
|
||||
}
|
||||
}
|
||||
for _, pos := range node.Positional {
|
||||
pos.Value.Set(reflect.Zero(pos.Value.Type()))
|
||||
if pos.Default != "" {
|
||||
pos.Decode(Scan(pos.Default))
|
||||
pos.Set = false
|
||||
}
|
||||
}
|
||||
for _, branch := range node.Children {
|
||||
if branch.Argument != nil {
|
||||
@@ -210,14 +215,14 @@ func (k *Kong) applyNode(scan *Scanner, node *Node, flags []*Flag) (command []st
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := chickMissingFlags(node.Children, flags); err != nil {
|
||||
if err := checkMissingFlags(node.Children, flags); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func chickMissingFlags(children []*Branch, flags []*Flag) error {
|
||||
func checkMissingFlags(children []*Branch, flags []*Flag) error {
|
||||
// Only check required missing fields at the last child.
|
||||
if len(children) > 0 {
|
||||
return nil
|
||||
@@ -290,7 +295,6 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
flag.Set = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,3 +241,13 @@ func TestMixedRequiredArgs(t *testing.T) {
|
||||
require.Equal(t, "gak", cli.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultValueForOptionalArg(t *testing.T) {
|
||||
var cli struct {
|
||||
Arg string `arg:"" optional:"" default:"default"`
|
||||
}
|
||||
p := mustNew(t, &cli)
|
||||
_, err := p.Parse(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "default", cli.Arg)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ type Value struct {
|
||||
Flag bool // True if flag, false if positional argument.
|
||||
Name string
|
||||
Help string
|
||||
Default string
|
||||
Decoder Decoder
|
||||
Field reflect.StructField
|
||||
Value reflect.Value
|
||||
@@ -34,7 +35,11 @@ type Value struct {
|
||||
}
|
||||
|
||||
func (v *Value) Decode(scan *Scanner) error {
|
||||
return v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value)
|
||||
err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value)
|
||||
if err == nil {
|
||||
v.Set = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type Positional = Value
|
||||
@@ -49,5 +54,4 @@ type Flag struct {
|
||||
Placeholder string
|
||||
Env string
|
||||
Short rune
|
||||
Default string
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user