Convert resolvers to an interface with a Validate() method.

This commit is contained in:
Alec Thomas
2018-09-21 15:18:17 +10:00
parent 25399cde9e
commit c112a076e7
5 changed files with 82 additions and 37 deletions
+36 -26
View File
@@ -51,7 +51,7 @@ type Context struct {
Error error Error error
values map[*Value]reflect.Value // Temporary values during tracing. values map[*Value]reflect.Value // Temporary values during tracing.
resolvers []ResolverFunc // Extra context-specific resolvers. resolvers []Resolver // Extra context-specific resolvers.
scan *Scanner scan *Scanner
} }
@@ -119,6 +119,11 @@ func (c *Context) Empty() bool {
// Validate the current context. // Validate the current context.
func (c *Context) Validate() error { func (c *Context) Validate() error {
for _, resolver := range c.combineResolvers() {
if err := resolver.Validate(c.Model); err != nil {
return err
}
}
for _, path := range c.Path { for _, path := range c.Path {
if err := checkMissingFlags(path.Flags); err != nil { if err := checkMissingFlags(path.Flags); err != nil {
return err return err
@@ -183,7 +188,7 @@ func (c *Context) Command() string {
// AddResolver adds a context-specific resolver. // AddResolver adds a context-specific resolver.
// //
// This is most useful in the BeforeResolve() hook. // This is most useful in the BeforeResolve() hook.
func (c *Context) AddResolver(resolver ResolverFunc) { func (c *Context) AddResolver(resolver Resolver) {
c.resolvers = append(c.resolvers, resolver) c.resolvers = append(c.resolvers, resolver)
} }
@@ -345,31 +350,9 @@ func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
return nil return nil
} }
func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error {
if len(haystack) == 0 {
return fmt.Errorf(format, args...)
}
closestCandidates := []string{}
for _, candidate := range haystack {
if strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2 {
closestCandidates = append(closestCandidates, fmt.Sprintf("%q", candidate))
}
}
prefix := fmt.Sprintf(format, args...)
if len(closestCandidates) == 1 {
return fmt.Errorf("%s, did you mean %s?", prefix, closestCandidates[0])
} else if len(closestCandidates) > 1 {
return fmt.Errorf("%s, did you mean one of %s?", prefix, strings.Join(closestCandidates, ", "))
}
return fmt.Errorf("%s", prefix)
}
// Resolve walks through the traced path, applying resolvers to any unset flags. // Resolve walks through the traced path, applying resolvers to any unset flags.
func (c *Context) Resolve() error { func (c *Context) Resolve() error {
// Combine application-level resolvers and context resolvers. resolvers := c.combineResolvers()
resolvers := []ResolverFunc{}
resolvers = append(resolvers, c.Kong.resolvers...)
resolvers = append(resolvers, c.resolvers...)
if len(resolvers) == 0 { if len(resolvers) == 0 {
return nil return nil
} }
@@ -382,7 +365,7 @@ func (c *Context) Resolve() error {
continue continue
} }
for _, resolver := range resolvers { for _, resolver := range resolvers {
s, err := resolver(c, path, flag) s, err := resolver.Resolve(c, path, flag)
if err != nil { if err != nil {
return err return err
} }
@@ -407,6 +390,14 @@ func (c *Context) Resolve() error {
return nil return nil
} }
// Combine application-level resolvers and context resolvers.
func (c *Context) combineResolvers() []Resolver {
resolvers := []Resolver{}
resolvers = append(resolvers, c.Kong.resolvers...)
resolvers = append(resolvers, c.resolvers...)
return resolvers
}
func (c *Context) getValue(value *Value) reflect.Value { func (c *Context) getValue(value *Value) reflect.Value {
v, ok := c.values[value] v, ok := c.values[value]
if !ok { if !ok {
@@ -573,3 +564,22 @@ func checkMissingPositionals(positional int, values []*Value) error {
} }
return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " "))
} }
func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error {
if len(haystack) == 0 {
return fmt.Errorf(format, args...)
}
closestCandidates := []string{}
for _, candidate := range haystack {
if strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2 {
closestCandidates = append(closestCandidates, fmt.Sprintf("%q", candidate))
}
}
prefix := fmt.Sprintf(format, args...)
if len(closestCandidates) == 1 {
return fmt.Errorf("%s, did you mean %s?", prefix, closestCandidates[0])
} else if len(closestCandidates) > 1 {
return fmt.Errorf("%s, did you mean one of %s?", prefix, strings.Join(closestCandidates, ", "))
}
return fmt.Errorf("%s", prefix)
}
+1 -1
View File
@@ -44,7 +44,7 @@ type Kong struct {
bindings bindings bindings bindings
loader ConfigurationFunc loader ConfigurationFunc
resolvers []ResolverFunc resolvers []Resolver
registry *Registry registry *Registry
noDefaultHelp bool noDefaultHelp bool
+2 -2
View File
@@ -180,8 +180,8 @@ func ClearResolvers() OptionFunc {
} }
} }
// Resolver registers flag resolvers. // Resolvers registers flag resolvers.
func Resolver(resolvers ...ResolverFunc) OptionFunc { func Resolvers(resolvers ...Resolver) OptionFunc {
return func(k *Kong) error { return func(k *Kong) error {
k.resolvers = append(k.resolvers, resolvers...) k.resolvers = append(k.resolvers, resolvers...)
return nil return nil
+19 -1
View File
@@ -7,9 +7,27 @@ import (
"strings" "strings"
) )
// ResolverFunc resolves a Flag value from an external source. // A Resolver resolves a Flag value from an external source.
type Resolver interface {
// Validate configuration against Application.
//
// This can be used to validate that all provided configuration is valid within this application.
Validate(app *Application) error
// Resolve the value for a Flag.
Resolve(context *Context, parent *Path, flag *Flag) (string, error)
}
// ResolverFunc is a convenience type for non-validating Resolvers.
type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error) type ResolverFunc func(context *Context, parent *Path, flag *Flag) (string, error)
var _ Resolver = ResolverFunc(nil)
func (r ResolverFunc) Resolve(context *Context, parent *Path, flag *Flag) (string, error) { // nolint: golint
return r(context, parent, flag)
}
func (r ResolverFunc) Validate(app *Application) error { return nil } // nolint: golint
// JSON returns a Resolver that retrieves values from a JSON source. // JSON returns a Resolver that retrieves values from a JSON source.
// //
// Hyphens in flag names are replaced with underscores. // Hyphens in flag names are replaced with underscores.
+24 -7
View File
@@ -1,6 +1,7 @@
package kong_test package kong_test
import ( import (
"errors"
"os" "os"
"reflect" "reflect"
"strings" "strings"
@@ -91,7 +92,7 @@ func TestJSONBasic(t *testing.T) {
r, err := kong.JSON(strings.NewReader(json)) r, err := kong.JSON(strings.NewReader(json))
require.NoError(t, err) require.NoError(t, err)
parser := mustNew(t, &cli, kong.Resolver(r)) parser := mustNew(t, &cli, kong.Resolvers(r))
_, err = parser.Parse([]string{}) _, err = parser.Parse([]string{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "🍕", cli.String) require.Equal(t, "🍕", cli.String)
@@ -129,14 +130,14 @@ func TestResolverWithBool(t *testing.T) {
Bool bool Bool bool
} }
resolver := func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) {
if flag.Name == "bool" { if flag.Name == "bool" {
return "true", nil return "true", nil
} }
return "", nil return "", nil
} }
p := mustNew(t, &cli, kong.Resolver(resolver)) p := mustNew(t, &cli, kong.Resolvers(resolver))
_, err := p.Parse(nil) _, err := p.Parse(nil)
require.NoError(t, err) require.NoError(t, err)
@@ -162,7 +163,7 @@ func TestLastResolverWins(t *testing.T) {
return "", nil return "", nil
} }
p := mustNew(t, &cli, kong.Resolver(first), kong.Resolver(second)) p := mustNew(t, &cli, kong.Resolvers(first, second))
_, err := p.Parse(nil) _, err := p.Parse(nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []int{2}, cli.Int) require.Equal(t, []int{2}, cli.Int)
@@ -173,13 +174,13 @@ func TestResolverSatisfiesRequired(t *testing.T) {
var cli struct { var cli struct {
Int int `required` Int int `required`
} }
resolver := func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) { var resolver kong.ResolverFunc = func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) {
if flag.Name == "int" { if flag.Name == "int" {
return "1", nil return "1", nil
} }
return "", nil return "", nil
} }
_, err := mustNew(t, &cli, kong.Resolver(resolver)).Parse(nil) _, err := mustNew(t, &cli, kong.Resolvers(resolver)).Parse(nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, cli.Int) require.Equal(t, 1, cli.Int)
} }
@@ -198,9 +199,25 @@ func TestResolverTriggersHooks(t *testing.T) {
return "", nil return "", nil
} }
_, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolver(first)).Parse(nil) _, err := mustNew(t, &cli, kong.Bind(ctx), kong.Resolvers(first)).Parse(nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "1", string(cli.Flag)) require.Equal(t, "1", string(cli.Flag))
require.Equal(t, []string{"before:", "after:1"}, ctx.values) require.Equal(t, []string{"before:", "after:1"}, ctx.values)
} }
type validatingResolver struct {
err error
}
func (v *validatingResolver) Validate(app *kong.Application) error { return v.err }
func (v *validatingResolver) Resolve(context *kong.Context, parent *kong.Path, flag *kong.Flag) (string, error) {
return "", nil
}
func TestValidatingResolverErrors(t *testing.T) {
resolver := &validatingResolver{err: errors.New("invalid")}
var cli struct{}
_, err := mustNew(t, &cli, kong.Resolvers(resolver)).Parse(nil)
require.EqualError(t, err, "invalid")
}