diff --git a/build.go b/build.go index 0616785..2a6a21d 100644 --- a/build.go +++ b/build.go @@ -177,19 +177,27 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv } value := &Value{ - Name: name, - Help: tag.Help, - Default: tag.Default, - Mapper: mapper, - Tag: tag, - Target: fv, - Enum: tag.Enum, + Name: name, + Help: tag.Help, + Default: tag.Default, + DefaultValue: reflect.New(fv.Type()).Elem(), + Mapper: mapper, + Tag: tag, + Target: fv, + Enum: tag.Enum, // Flags are optional by default, and args are required by default. Required: (!tag.Arg && tag.Required) || (tag.Arg && !tag.Optional), Format: tag.Format, } + if value.Default != "" { + err := value.Parse(Scan(tag.Default), value.DefaultValue) + if err != nil { + fail("invalid default value %q for field type %s.%s (of type %s)", value.Default, v.Type(), ft.Name, ft.Type) + } + } + if tag.Arg { node.Positional = append(node.Positional, value) } else { diff --git a/context.go b/context.go index 6df5c36..8175eef 100644 --- a/context.go +++ b/context.go @@ -181,18 +181,18 @@ func (c *Context) AddResolver(resolver ResolverFunc) { c.resolvers = append(c.resolvers, resolver) } -// FlagValue returns the set value of a flag if it was encountered and exists. +// FlagValue returns the set value of a flag if it was encountered and exists, or its default value. func (c *Context) FlagValue(flag *Flag) interface{} { for _, trace := range c.Path { if trace.Flag == flag { v, ok := c.values[trace.Flag.Value] if !ok { - return nil + break } return v.Interface() } } - return nil + return flag.DefaultValue.Interface() } // Recursively reset values to defaults (as specified in the grammar) or the zero value. diff --git a/help.go b/help.go index 21b6de7..52b6580 100644 --- a/help.go +++ b/help.go @@ -13,6 +13,20 @@ const ( defaultColumnPadding = 4 ) +// Help flag. +type helpValue bool + +func (h helpValue) BeforeApply(ctx *Context) error { + options := ctx.Kong.helpOptions + options.Summary = false + err := ctx.Kong.help(options, ctx) + if err != nil { + return err + } + ctx.Kong.Exit(1) + return nil +} + // HelpOptions for HelpPrinters. type HelpOptions struct { // Don't print top-level usage summary. diff --git a/kong.go b/kong.go index bce8f94..2de2b12 100644 --- a/kong.go +++ b/kong.go @@ -151,19 +151,6 @@ func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) { return nil } -type helpValue bool - -func (h helpValue) BeforeApply(ctx *Context) error { - options := ctx.Kong.helpOptions - options.Summary = false - err := ctx.Kong.help(options, ctx) - if err != nil { - return err - } - ctx.Kong.Exit(1) - return nil -} - // Provide additional builtin flags, if any. func (k *Kong) extraFlags() []*Flag { if k.noDefaultHelp { @@ -248,6 +235,35 @@ func (k *Kong) applyHook(ctx *Context, name string) error { return err } } + // Path[0] will always be the app root. + return k.applyHookToDefaultFlags(ctx, ctx.Path[0].Node(), name) +} + +// Call hook on any unset flags with default values. +func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) error { + if node == nil { + return nil + } + bindings := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars)) + for _, flag := range node.Flags { + if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() { + continue + } + method := getMethod(flag.Target, name) + if !method.IsValid() { + continue + } + path := &Path{Flag: flag} + if err := callMethod(name, flag.Target, method, bindings.clone().add(path)); err != nil { + return err + } + } + for _, branch := range node.Children { + err := k.applyHookToDefaultFlags(ctx, branch, name) + if err != nil { + return err + } + } return nil } diff --git a/kong_test.go b/kong_test.go index 450f711..6112c53 100644 --- a/kong_test.go +++ b/kong_test.go @@ -642,3 +642,15 @@ func TestUnnamedFieldEmbeds(t *testing.T) { require.Contains(t, buf.String(), `--one-flag=STRING`) require.Contains(t, buf.String(), `--two-flag=STRING`) } + +func TestHooksCalledForDefault(t *testing.T) { + var cli struct { + Flag hookValue `default:"default"` + } + + ctx := &hookContext{} + _, err := mustNew(t, &cli, kong.Bind(ctx)).Parse(nil) + require.NoError(t, err) + require.Equal(t, "default", string(cli.Flag)) + require.Equal(t, []string{"before:", "after:default"}, ctx.values) +} diff --git a/model.go b/model.go index 8103bf5..f0180d8 100644 --- a/model.go +++ b/model.go @@ -196,18 +196,19 @@ func (n *Node) Path() (out string) { // A Value is either a flag or a variable positional argument. type Value struct { - Flag *Flag // Nil if positional argument. - Name string - Help string - Default string - Enum string - Mapper Mapper - Tag *Tag - Target reflect.Value - Required bool - Set bool // Set to true when this value is set through some mechanism. - Format string // Formatting directive, if applicable. - Position int // Position (for positional arguments). + Flag *Flag // Nil if positional argument. + Name string + Help string + Default string + DefaultValue reflect.Value + Enum string + Mapper Mapper + Tag *Tag + Target reflect.Value + Required bool + Set bool // Set to true when this value is set through some mechanism. + Format string // Formatting directive, if applicable. + Position int // Position (for positional arguments). } // EnumMap returns a map of the enums in this value. diff --git a/resolver_test.go b/resolver_test.go index 1caa942..501eae9 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -208,3 +208,24 @@ func TestResolverSatisfiesRequired(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, cli.Int) } + +func TestResolverTriggersHooks(t *testing.T) { + ctx := &hookContext{} + + var cli struct { + Flag hookValue + } + + var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + if flag.Name == "flag" { + return "1", nil + } + return "", nil + } + + _, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolver(first)).Parse(nil) + require.NoError(t, err) + + require.Equal(t, "1", string(cli.Flag)) + require.Equal(t, []string{"before:", "after:1"}, ctx.values) +}