From c27dd50be683a5d60517894e428cf46db73afeac Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sat, 19 May 2018 21:02:49 +1000 Subject: [PATCH] Move .Set = true into Decode(). --- build.go | 8 ++------ kong.go | 10 +++++++--- kong_test.go | 10 ++++++++++ model.go | 8 ++++++-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/build.go b/build.go index 03aae9f..e7e3a96 100644 --- a/build.go +++ b/build.go @@ -110,6 +110,7 @@ func buildNode(v reflect.Value, cmd bool) *Node { Name: name, Flag: flag, Help: help, + Default: dflt, Decoder: decoder, Value: fv, Field: ft, @@ -124,7 +125,6 @@ func buildNode(v reflect.Value, cmd bool) *Node { node.Flags = append(node.Flags, &Flag{ Value: value, Short: short, - Default: dflt, Placeholder: placeholder, Env: env, }) @@ -135,12 +135,8 @@ func buildNode(v reflect.Value, cmd bool) *Node { // Scan through argument positionals to ensure optional is never before a required last := true for _, p := range node.Positional { - if p.Flag { - continue - } - if !last && p.Required { - fail("arguments can not be required after an optional: %v", p.Name) + fail("argument %q can not be required after an optional", p.Name) } last = p.Required diff --git a/kong.go b/kong.go index f6c6b91..8246db7 100644 --- a/kong.go +++ b/kong.go @@ -75,10 +75,15 @@ func (k *Kong) reset(node *Node) { flag.Value.Value.Set(reflect.Zero(flag.Value.Value.Type())) if flag.Default != "" { flag.Decode(Scan(flag.Default)) + flag.Set = false } } for _, pos := range node.Positional { pos.Value.Set(reflect.Zero(pos.Value.Type())) + if pos.Default != "" { + pos.Decode(Scan(pos.Default)) + pos.Set = false + } } for _, branch := range node.Children { if branch.Argument != nil { @@ -210,14 +215,14 @@ func (k *Kong) applyNode(scan *Scanner, node *Node, flags []*Flag) (command []st return nil, err } - if err := chickMissingFlags(node.Children, flags); err != nil { + if err := checkMissingFlags(node.Children, flags); err != nil { return nil, err } return } -func chickMissingFlags(children []*Branch, flags []*Flag) error { +func checkMissingFlags(children []*Branch, flags []*Flag) error { // Only check required missing fields at the last child. if len(children) > 0 { return nil @@ -290,7 +295,6 @@ func matchFlags(flags []*Flag, token Token, scan *Scanner, matcher func(f *Flag) if err != nil { return err } - flag.Set = true return nil } } diff --git a/kong_test.go b/kong_test.go index 066453a..b9f4f76 100644 --- a/kong_test.go +++ b/kong_test.go @@ -241,3 +241,13 @@ func TestMixedRequiredArgs(t *testing.T) { require.Equal(t, "gak", cli.Name) }) } + +func TestDefaultValueForOptionalArg(t *testing.T) { + var cli struct { + Arg string `arg:"" optional:"" default:"default"` + } + p := mustNew(t, &cli) + _, err := p.Parse(nil) + require.NoError(t, err) + require.Equal(t, "default", cli.Arg) +} diff --git a/model.go b/model.go index 8d0bd2e..6a0fcf5 100644 --- a/model.go +++ b/model.go @@ -25,6 +25,7 @@ type Value struct { Flag bool // True if flag, false if positional argument. Name string Help string + Default string Decoder Decoder Field reflect.StructField Value reflect.Value @@ -34,7 +35,11 @@ type Value struct { } func (v *Value) Decode(scan *Scanner) error { - return v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value) + err := v.Decoder.Decode(&DecoderContext{Value: v}, scan, v.Value) + if err == nil { + v.Set = true + } + return err } type Positional = Value @@ -49,5 +54,4 @@ type Flag struct { Placeholder string Env string Short rune - Default string }