diff --git a/context.go b/context.go index e70af31..fca1e81 100644 --- a/context.go +++ b/context.go @@ -51,7 +51,7 @@ type Context struct { Error error values map[*Value]reflect.Value // Temporary values during tracing. - resolvers []ResolverFunc // Extra context-specific resolvers. + resolvers []Resolver // Extra context-specific resolvers. scan *Scanner } @@ -119,6 +119,11 @@ func (c *Context) Empty() bool { // Validate the current context. func (c *Context) Validate() error { + for _, resolver := range c.combineResolvers() { + if err := resolver.Validate(c.Model); err != nil { + return err + } + } for _, path := range c.Path { if err := checkMissingFlags(path.Flags); err != nil { return err @@ -183,7 +188,7 @@ func (c *Context) Command() string { // AddResolver adds a context-specific resolver. // // This is most useful in the BeforeResolve() hook. -func (c *Context) AddResolver(resolver ResolverFunc) { +func (c *Context) AddResolver(resolver Resolver) { c.resolvers = append(c.resolvers, resolver) } @@ -345,31 +350,9 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo return nil } -func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error { - if len(haystack) == 0 { - return fmt.Errorf(format, args...) - } - closestCandidates := []string{} - for _, candidate := range haystack { - if strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2 { - closestCandidates = append(closestCandidates, fmt.Sprintf("%q", candidate)) - } - } - prefix := fmt.Sprintf(format, args...) - if len(closestCandidates) == 1 { - return fmt.Errorf("%s, did you mean %s?", prefix, closestCandidates[0]) - } else if len(closestCandidates) > 1 { - return fmt.Errorf("%s, did you mean one of %s?", prefix, strings.Join(closestCandidates, ", ")) - } - return fmt.Errorf("%s", prefix) -} - // Resolve walks through the traced path, applying resolvers to any unset flags. func (c *Context) Resolve() error { - // Combine application-level resolvers and context resolvers. - resolvers := []ResolverFunc{} - resolvers = append(resolvers, c.Kong.resolvers...) - resolvers = append(resolvers, c.resolvers...) + resolvers := c.combineResolvers() if len(resolvers) == 0 { return nil } @@ -382,7 +365,7 @@ func (c *Context) Resolve() error { continue } for _, resolver := range resolvers { - s, err := resolver(c, path, flag) + s, err := resolver.Resolve(c, path, flag) if err != nil { return err } @@ -407,6 +390,14 @@ func (c *Context) Resolve() error { return nil } +// Combine application-level resolvers and context resolvers. +func (c *Context) combineResolvers() []Resolver { + resolvers := []Resolver{} + resolvers = append(resolvers, c.Kong.resolvers...) + resolvers = append(resolvers, c.resolvers...) + return resolvers +} + func (c *Context) getValue(value *Value) reflect.Value { v, ok := c.values[value] if !ok { @@ -573,3 +564,22 @@ func checkMissingPositionals(positional int, values []*Value) error { } return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } + +func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error { + if len(haystack) == 0 { + return fmt.Errorf(format, args...) + } + closestCandidates := []string{} + for _, candidate := range haystack { + if strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2 { + closestCandidates = append(closestCandidates, fmt.Sprintf("%q", candidate)) + } + } + prefix := fmt.Sprintf(format, args...) + if len(closestCandidates) == 1 { + return fmt.Errorf("%s, did you mean %s?", prefix, closestCandidates[0]) + } else if len(closestCandidates) > 1 { + return fmt.Errorf("%s, did you mean one of %s?", prefix, strings.Join(closestCandidates, ", ")) + } + return fmt.Errorf("%s", prefix) +} diff --git a/kong.go b/kong.go index b1c45c4..5e31009 100644 --- a/kong.go +++ b/kong.go @@ -44,7 +44,7 @@ type Kong struct { bindings bindings loader ConfigurationFunc - resolvers []ResolverFunc + resolvers []Resolver registry *Registry noDefaultHelp bool diff --git a/options.go b/options.go index 5bcec7e..140c1e0 100644 --- a/options.go +++ b/options.go @@ -180,8 +180,8 @@ func ClearResolvers() OptionFunc { } } -// Resolver registers flag resolvers. -func Resolver(resolvers ...ResolverFunc) OptionFunc { +// Resolvers registers flag resolvers. +func Resolvers(resolvers ...Resolver) OptionFunc { return func(k *Kong) error { k.resolvers = append(k.resolvers, resolvers...) return nil diff --git a/resolver.go b/resolver.go index 9de21a5..5cc92ce 100644 --- a/resolver.go +++ b/resolver.go @@ -7,9 +7,27 @@ import ( "strings" ) -// ResolverFunc resolves a Flag value from an external source. +// A Resolver resolves a Flag value from an external source. +type Resolver interface { + // Validate configuration against Application. + // + // This can be used to validate that all provided configuration is valid within this application. + Validate(app *Application) error + + // Resolve the value for a Flag. + Resolve(context *Context, parent *Path, flag *Flag) (string, error) +} + +// ResolverFunc is a convenience type for non-validating Resolvers. type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error) +var _ Resolver = ResolverFunc(nil) + +func (r ResolverFunc) Resolve(context *Context, parent *Path, flag *Flag) (string, error) { // nolint: golint + return r(context, parent, flag) +} +func (r ResolverFunc) Validate(app *Application) error { return nil } // nolint: golint + // JSON returns a Resolver that retrieves values from a JSON source. // // Hyphens in flag names are replaced with underscores. diff --git a/resolver_test.go b/resolver_test.go index ad0cebe..c23c4f3 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -1,6 +1,7 @@ package kong_test import ( + "errors" "os" "reflect" "strings" @@ -91,7 +92,7 @@ func TestJSONBasic(t *testing.T) { r, err := kong.JSON(strings.NewReader(json)) require.NoError(t, err) - parser := mustNew(t, &cli, kong.Resolver(r)) + parser := mustNew(t, &cli, kong.Resolvers(r)) _, err = parser.Parse([]string{}) require.NoError(t, err) require.Equal(t, "🍕", cli.String) @@ -129,14 +130,14 @@ func TestResolverWithBool(t *testing.T) { Bool bool } - resolver := func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { if flag.Name == "bool" { return "true", nil } return "", nil } - p := mustNew(t, &cli, kong.Resolver(resolver)) + p := mustNew(t, &cli, kong.Resolvers(resolver)) _, err := p.Parse(nil) require.NoError(t, err) @@ -162,7 +163,7 @@ func TestLastResolverWins(t *testing.T) { return "", nil } - p := mustNew(t, &cli, kong.Resolver(first), kong.Resolver(second)) + p := mustNew(t, &cli, kong.Resolvers(first, second)) _, err := p.Parse(nil) require.NoError(t, err) require.Equal(t, []int{2}, cli.Int) @@ -173,13 +174,13 @@ func TestResolverSatisfiesRequired(t *testing.T) { var cli struct { Int int `required` } - resolver := func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { if flag.Name == "int" { return "1", nil } return "", nil } - _, err := mustNew(t, &cli, kong.Resolver(resolver)).Parse(nil) + _, err := mustNew(t, &cli, kong.Resolvers(resolver)).Parse(nil) require.NoError(t, err) require.Equal(t, 1, cli.Int) } @@ -198,9 +199,25 @@ func TestResolverTriggersHooks(t *testing.T) { return "", nil } - _, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolver(first)).Parse(nil) + _, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolvers(first)).Parse(nil) require.NoError(t, err) require.Equal(t, "1", string(cli.Flag)) require.Equal(t, []string{"before:", "after:1"}, ctx.values) } + +type validatingResolver struct { + err error +} + +func (v *validatingResolver) Validate(app *kong.Application) error { return v.err } +func (v *validatingResolver) Resolve(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { + return "", nil +} + +func TestValidatingResolverErrors(t *testing.T) { + resolver := &validatingResolver{err: errors.New("invalid")} + var cli struct{} + _, err := mustNew(t, &cli, kong.Resolvers(resolver)).Parse(nil) + require.EqualError(t, err, "invalid") +}