Apply hooks to default values.

This commit is contained in:
Alec Thomas
2018-09-20 18:16:45 +10:00
parent 54338bd8b1
commit 6fa83bdc0e
7 changed files with 107 additions and 35 deletions
+8
View File
@@ -180,6 +180,7 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv
Name: name, Name: name,
Help: tag.Help, Help: tag.Help,
Default: tag.Default, Default: tag.Default,
DefaultValue: reflect.New(fv.Type()).Elem(),
Mapper: mapper, Mapper: mapper,
Tag: tag, Tag: tag,
Target: fv, Target: fv,
@@ -190,6 +191,13 @@ func buildField(k *Kong, node *Node, v reflect.Value, ft reflect.StructField, fv
Format: tag.Format, Format: tag.Format,
} }
if value.Default != "" {
err := value.Parse(Scan(tag.Default), value.DefaultValue)
if err != nil {
fail("invalid default value %q for field type %s.%s (of type %s)", value.Default, v.Type(), ft.Name, ft.Type)
}
}
if tag.Arg { if tag.Arg {
node.Positional = append(node.Positional, value) node.Positional = append(node.Positional, value)
} else { } else {
+3 -3
View File
@@ -181,18 +181,18 @@ func (c *Context) AddResolver(resolver ResolverFunc) {
c.resolvers = append(c.resolvers, resolver) c.resolvers = append(c.resolvers, resolver)
} }
// FlagValue returns the set value of a flag if it was encountered and exists. // FlagValue returns the set value of a flag if it was encountered and exists, or its default value.
func (c *Context) FlagValue(flag *Flag) interface{} { func (c *Context) FlagValue(flag *Flag) interface{} {
for _, trace := range c.Path { for _, trace := range c.Path {
if trace.Flag == flag { if trace.Flag == flag {
v, ok := c.values[trace.Flag.Value] v, ok := c.values[trace.Flag.Value]
if !ok { if !ok {
return nil break
} }
return v.Interface() return v.Interface()
} }
} }
return nil return flag.DefaultValue.Interface()
} }
// Recursively reset values to defaults (as specified in the grammar) or the zero value. // Recursively reset values to defaults (as specified in the grammar) or the zero value.
+14
View File
@@ -13,6 +13,20 @@ const (
defaultColumnPadding = 4 defaultColumnPadding = 4
) )
// Help flag.
type helpValue bool
func (h helpValue) BeforeApply(ctx *Context) error {
options := ctx.Kong.helpOptions
options.Summary = false
err := ctx.Kong.help(options, ctx)
if err != nil {
return err
}
ctx.Kong.Exit(1)
return nil
}
// HelpOptions for HelpPrinters. // HelpOptions for HelpPrinters.
type HelpOptions struct { type HelpOptions struct {
// Don't print top-level usage summary. // Don't print top-level usage summary.
+29 -13
View File
@@ -151,19 +151,6 @@ func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) {
return nil return nil
} }
type helpValue bool
func (h helpValue) BeforeApply(ctx *Context) error {
options := ctx.Kong.helpOptions
options.Summary = false
err := ctx.Kong.help(options, ctx)
if err != nil {
return err
}
ctx.Kong.Exit(1)
return nil
}
// Provide additional builtin flags, if any. // Provide additional builtin flags, if any.
func (k *Kong) extraFlags() []*Flag { func (k *Kong) extraFlags() []*Flag {
if k.noDefaultHelp { if k.noDefaultHelp {
@@ -248,6 +235,35 @@ func (k *Kong) applyHook(ctx *Context, name string) error {
return err return err
} }
} }
// Path[0] will always be the app root.
return k.applyHookToDefaultFlags(ctx, ctx.Path[0].Node(), name)
}
// Call hook on any unset flags with default values.
func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) error {
if node == nil {
return nil
}
bindings := k.bindings.clone().add(ctx).add(node.Vars().CloneWith(k.vars))
for _, flag := range node.Flags {
if flag.Default == "" || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
}
path := &Path{Flag: flag}
if err := callMethod(name, flag.Target, method, bindings.clone().add(path)); err != nil {
return err
}
}
for _, branch := range node.Children {
err := k.applyHookToDefaultFlags(ctx, branch, name)
if err != nil {
return err
}
}
return nil return nil
} }
+12
View File
@@ -642,3 +642,15 @@ func TestUnnamedFieldEmbeds(t *testing.T) {
require.Contains(t, buf.String(), `--one-flag=STRING`) require.Contains(t, buf.String(), `--one-flag=STRING`)
require.Contains(t, buf.String(), `--two-flag=STRING`) require.Contains(t, buf.String(), `--two-flag=STRING`)
} }
func TestHooksCalledForDefault(t *testing.T) {
var cli struct {
Flag hookValue `default:"default"`
}
ctx := &hookContext{}
_, err := mustNew(t, &cli, kong.Bind(ctx)).Parse(nil)
require.NoError(t, err)
require.Equal(t, "default", string(cli.Flag))
require.Equal(t, []string{"before:", "after:default"}, ctx.values)
}
+1
View File
@@ -200,6 +200,7 @@ type Value struct {
Name string Name string
Help string Help string
Default string Default string
DefaultValue reflect.Value
Enum string Enum string
Mapper Mapper Mapper Mapper
Tag *Tag Tag *Tag
+21
View File
@@ -208,3 +208,24 @@ func TestResolverSatisfiesRequired(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, cli.Int) require.Equal(t, 1, cli.Int)
} }
func TestResolverTriggersHooks(t *testing.T) {
ctx := &hookContext{}
var cli struct {
Flag hookValue
}
var first kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) {
if flag.Name == "flag" {
return "1", nil
}
return "", nil
}
_, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolver(first)).Parse(nil)
require.NoError(t, err)
require.Equal(t, "1", string(cli.Flag))
require.Equal(t, []string{"before:", "after:1"}, ctx.values)
}