Move .Set = true into Decode().

This commit is contained in:
Alec Thomas
2018-05-19 21:02:49 +10:00
parent 8e96da517d
commit c27dd50be6
4 changed files with 25 additions and 11 deletions
+2 -6
View File
@@ -110,6 +110,7 @@ func buildNode(v reflect.Value, cmd bool) *Node {
Name: name,
Flag: flag,
Help: help,
Default: dflt,
Decoder: decoder,
Value: fv,
Field: ft,
@@ -124,7 +125,6 @@ func buildNode(v reflect.Value, cmd bool) *Node {
node.Flags = append(node.Flags, &Flag{
Value: value,
Short: short,
Default: dflt,
Placeholder: placeholder,
Env: env,
})
@@ -135,12 +135,8 @@ func buildNode(v reflect.Value, cmd bool) *Node {
// Scan through argument positionals to ensure optional is never before a required
last := true
for _, p := range node.Positional {
if p.Flag {
continue
}
if !last && p.Required {
fail("arguments can not be required after an optional: %v", p.Name)
fail("argument %q can not be required after an optional", p.Name)
}
last = p.Required
+7 -3
View File
@@ -75,10 +75,15 @@ func (k *Kong) reset(node *Node) {
flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type()))
if flag.Default != "" {
flag.Decode(Scan(flag.Default))
flag.Set = false
}
}
for _, pos := range node.Positional {
pos.Value.Set(reflect.Zero(pos.Value.Type()))
if pos.Default != "" {
pos.Decode(Scan(pos.Default))
pos.Set = false
}
}
for _, branch := range node.Children {
if branch.Argument != nil {
@@ -210,14 +215,14 @@ func (k *Kong) applyNode(scan *Scanner, node *Node, flags []*Flag) (command []st
return nil, err
}
if err := chickMissingFlags(node.Children, flags); err != nil {
if err := checkMissingFlags(node.Children, flags); err != nil {
return nil, err
}
return
}
func chickMissingFlags(children []*Branch, flags []*Flag) error {
func checkMissingFlags(children []*Branch, flags []*Flag) error {
// Only check required missing fields at the last child.
if len(children) > 0 {
return nil
@@ -290,7 +295,6 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag)
if err != nil {
return err
}
flag.Set = true
return nil
}
}
+10
View File
@@ -241,3 +241,13 @@ func TestMixedRequiredArgs(t *testing.T) {
require.Equal(t, "gak", cli.Name)
})
}
func TestDefaultValueForOptionalArg(t *testing.T) {
var cli struct {
Arg string `arg:"" optional:"" default:"default"`
}
p := mustNew(t, &cli)
_, err := p.Parse(nil)
require.NoError(t, err)
require.Equal(t, "default", cli.Arg)
}
+6 -2
View File
@@ -25,6 +25,7 @@ type Value struct {
Flag bool // True if flag, false if positional argument.
Name string
Help string
Default string
Decoder Decoder
Field reflect.StructField
Value reflect.Value
@@ -34,7 +35,11 @@ type Value struct {
}
func (v *Value) Decode(scan *Scanner) error {
return v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value)
err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value)
if err == nil {
v.Set = true
}
return err
}
type Positional = Value
@@ -49,5 +54,4 @@ type Flag struct {
Placeholder string
Env string
Short rune
Default string
}