Call Validate() functions on nodes if present.

This commit is contained in:
Alec Thomas
2020-10-21 19:13:02 +11:00
parent a062611ecf
commit 38db823367
3 changed files with 294 additions and 138 deletions
+78 -4
View File
@@ -44,6 +44,27 @@ func (p *Path) Node() *Node {
return nil
}
// Visitable returns the Visitable for this path element.
func (p *Path) Visitable() Visitable {
switch {
case p.App != nil:
return p.App
case p.Argument != nil:
return p.Argument
case p.Command != nil:
return p.Command
case p.Flag != nil:
return p.Flag
case p.Positional != nil:
return p.Positional
}
return nil
}
// Context contains the current parse context.
type Context struct {
*Kong
@@ -136,10 +157,19 @@ func (c *Context) Empty() bool {
// Validate the current context.
func (c *Context) Validate() error { // nolint: gocyclo
err := Visit(c.Model, func(node Visitable, next Next) error {
if value, ok := node.(*Value); ok {
_, ok := os.LookupEnv(value.Tag.Env)
if value.Enum != "" && (!value.Required || value.Default != "" || (value.Tag.Env != "" && ok)) {
if err := checkEnum(value, value.Target); err != nil {
switch node := node.(type) {
case *Value:
_, ok := os.LookupEnv(node.Tag.Env)
if node.Enum != "" && (!node.Required || node.Default != "" || (node.Tag.Env != "" && ok)) {
if err := checkEnum(node, node.Target); err != nil {
return err
}
}
case *Flag:
_, ok := os.LookupEnv(node.Tag.Env)
if node.Enum != "" && (!node.Required || node.Default != "" || (node.Tag.Env != "" && ok)) {
if err := checkEnum(node.Value, node.Target); err != nil {
return err
}
}
@@ -149,6 +179,35 @@ func (c *Context) Validate() error { // nolint: gocyclo
if err != nil {
return err
}
for _, el := range c.Path {
var (
value reflect.Value
desc string
)
switch node := el.Visitable().(type) {
case *Value:
value = node.Target
desc = node.ShortSummary()
case *Flag:
value = node.Target
desc = node.ShortSummary()
case *Application:
value = node.Target
desc = node.Name
case *Node:
value = node.Target
desc = node.Path()
}
if validate := isValidatable(value); validate != nil {
err := validate.Validate()
if err != nil {
return errors.Wrap(err, desc)
}
}
}
for _, resolver := range c.combineResolvers() {
if err := resolver.Validate(c.Model); err != nil {
return err
@@ -787,3 +846,18 @@ func findPotentialCandidates(needle string, haystack []string, format string, ar
}
return fmt.Errorf("%s", prefix)
}
type validatable interface{ Validate() error }
func isValidatable(v reflect.Value) validatable {
if !v.IsValid() || (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice || v.Kind() == reflect.Map) && v.IsNil() {
return nil
}
if validate, ok := v.Interface().(validatable); ok {
return validate
}
if v.CanAddr() {
return isValidatable(v.Addr())
}
return nil
}