Files
kong/context.go
T
Alec Thomas 6408010083 Clean up disparity between Context and Kong.
Previously, there was a confusing mix of functionality shared between
the two, wherein you would need to use the Kong type for printing errors,
etc., but it did not have access to the context in order to print
context-sensitive usage information. This has been fixed.

Additionally, there are now fuzzy correction suggestions for flags and
commands.

Also added a server example which shows how Kong can be used for parsing
in interactive shells. Run with:

    $ go run ./_examples/server/*.go

Then interact with:

    $ ssh -p 6740 127.0.0.1
2018-06-27 09:11:11 +10:00

627 lines
15 KiB
Go

package kong
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// Path records the nodes and parsed values from the current command-line.
//
// A Context's trace is a slice of Paths with one element per recorded item:
// the application root, a command, an argument, a positional or a flag.
type Path struct {
	// Parent is the node that was being traced when this element was matched.
	// The trace in this file sets it for positional, command and argument
	// elements; root, flag and resolver-inserted elements leave it nil.
	Parent *Node

	// One of these will be non-nil, identifying what this element records.
	App        *Application
	Positional *Positional
	Flag       *Flag
	Argument   *Argument
	Command    *Command

	// Flags added by this node, i.e. the flags that become available once
	// this element has been matched.
	Flags []*Flag

	// True if this Path element was created as the result of a resolver
	// rather than from the command-line arguments.
	Resolved bool
}
// Node returns the grammar Node this Path element refers to: the application
// root's node, an argument, or a command. Flag and positional elements have
// no associated Node, so nil is returned for them.
func (p *Path) Node() *Node {
	if p.App != nil {
		return p.App.Node
	}
	if p.Argument != nil {
		return p.Argument
	}
	if p.Command != nil {
		return p.Command
	}
	return nil
}
// Context contains the current parse context.
//
// Trace creates a Context. It embeds *Kong, so Kong's fields and methods are
// accessible directly on a Context.
type Context struct {
	*Kong

	// A trace through parsed nodes. Trace starts it with the application
	// root; flags supplied by resolvers are prepended ahead of that.
	Path []*Path
	// Original command-line arguments.
	Args []string
	// Error that occurred during trace, if any. Trace records the parse error
	// here instead of returning it.
	Error error

	values map[*Value]reflect.Value // Temporary values during tracing; written into the grammar by Apply.
	scan   *Scanner                 // Token stream consumed while tracing.
}
// Trace walks args through the grammar tree of k.
//
// The returned Context's Path contains the application root plus an element
// for every command, argument, positional and flag encountered, including
// flags whose values were supplied by resolvers. A failure while parsing args
// is stored in Context.Error rather than returned; the returned error comes
// from running the resolvers afterwards.
//
// Note that this will not modify the target grammar. Call Apply() to do so.
func Trace(k *Kong, args []string) (*Context, error) {
	root := &Path{App: k.Model, Flags: k.Model.Flags}
	ctx := &Context{
		Kong:   k,
		Args:   args,
		Path:   []*Path{root},
		values: map[*Value]reflect.Value{},
		scan:   Scan(args...),
	}
	ctx.Error = ctx.trace(ctx.Model.Node)
	resolveErr := ctx.traceResolvers()
	return ctx, resolveErr
}
// Value returns the traced value for a flag, argument or positional path
// element. It panics if path refers to anything else, such as a command or
// the application root.
func (c *Context) Value(path *Path) reflect.Value {
	var key *Value
	switch {
	case path.Positional != nil:
		key = path.Positional
	case path.Flag != nil:
		key = path.Flag.Value
	case path.Argument != nil:
		key = path.Argument.Argument
	default:
		panic("can only retrieve value for flag, argument or positional")
	}
	return c.values[key]
}
// Selected returns the last command or argument node recorded in the traced
// Path, or nil if none was selected.
func (c *Context) Selected() *Node {
	for i := len(c.Path) - 1; i >= 0; i-- {
		p := c.Path[i]
		if p.Command != nil {
			return p.Command
		}
		if p.Argument != nil {
			return p.Argument
		}
	}
	return nil
}
// Empty reports whether no arguments were provided: every element of the
// traced Path is either the application root or was inserted by a resolver.
func (c *Context) Empty() bool {
	for _, p := range c.Path {
		fromArgs := p.App == nil && !p.Resolved
		if fromArgs {
			return false
		}
	}
	return true
}
// Validate checks the traced context for missing required input.
//
// It returns an error if:
//   - any flag made available along the Path is marked Required but not Set;
//   - the terminal node (the selected command/argument, or the application
//     root if none was selected) still expects a child command or a required
//     argument/positional (see checkMissingChildren);
//   - required positional arguments of the terminal node were not supplied;
//   - the terminal node is itself a required argument that is not Set.
func (c *Context) Validate() error {
	for _, path := range c.Path {
		if err := checkMissingFlags(path.Flags); err != nil {
			return err
		}
	}
	// Check the terminal node.
	node := c.Selected()
	if node == nil {
		node = c.Model.Node
	}
	// Find the deepest of the terminal node's own positionals that was
	// supplied. Positional indices are relative to the node that declares
	// them, so positionals consumed by ancestor nodes must not be counted.
	positionals := 0
	for _, path := range c.Path {
		if path.Positional != nil && path.Parent == node {
			positionals = path.Positional.Position + 1
		}
	}
	if err := checkMissingChildren(node); err != nil {
		return err
	}
	if err := checkMissingPositionals(positionals, node.Positional); err != nil {
		return err
	}
	if node.Type == ArgumentNode {
		value := node.Argument
		if value.Required && !value.Set {
			return fmt.Errorf("%s is required", node.Summary())
		}
	}
	return nil
}
// Flags returns every flag made available by the elements of the traced
// Path, in path order.
func (c *Context) Flags() []*Flag {
	var all []*Flag
	for _, p := range c.Path {
		all = append(all, p.Flags...)
	}
	return all
}
// Command returns the full command path as a space-separated string: command
// names verbatim, and arguments and positionals as "<name>" placeholders.
func (c *Context) Command() string {
	placeholder := func(name string) string { return "<" + name + ">" }
	words := make([]string, 0, len(c.Path))
	for _, p := range c.Path {
		switch {
		case p.Positional != nil:
			words = append(words, placeholder(p.Positional.Name))
		case p.Argument != nil:
			words = append(words, placeholder(p.Argument.Name))
		case p.Command != nil:
			words = append(words, p.Command.Name)
		}
	}
	return strings.Join(words, " ")
}
// FlagValue returns the traced value of flag if it appears in the Path (set
// on the command-line or supplied by a resolver). Otherwise it returns the
// zero reflect.Value. It also returns the zero reflect.Value when flag is nil.
func (c *Context) FlagValue(flag *Flag) reflect.Value {
	if flag == nil {
		// Without this guard a nil flag would match the first Path element
		// that has no Flag and dereference a nil pointer below.
		return reflect.Value{}
	}
	for _, trace := range c.Path {
		if trace.Flag == flag {
			return c.values[flag.Value]
		}
	}
	return reflect.Value{}
}
// reset recursively restores every flag, positional and argument value in
// the subtree rooted at node to its default (as specified in the grammar) or
// the zero value. It stops at, and returns, the first error encountered.
func (c *Context) reset(node *Node) error {
	for _, f := range node.Flags {
		if err := f.Value.Reset(); err != nil {
			return err
		}
	}
	for _, p := range node.Positional {
		if err := p.Reset(); err != nil {
			return err
		}
	}
	for _, child := range node.Children {
		if child.Argument != nil {
			if err := child.Argument.Reset(); err != nil {
				return err
			}
		}
		if err := c.reset(child); err != nil {
			return err
		}
	}
	return nil
}
// trace consumes tokens from c.scan and matches them against the grammar
// rooted at node. For every flag, positional, command or argument it
// recognises, it appends a Path element to c.Path. When a command or
// argument child of node matches, trace recurses into it, so one call
// consumes the rest of the command-line.
//
// Untyped tokens are first classified and pushed back onto the scanner with
// a more specific type, to be handled on a later iteration:
//   - "--" ends flag parsing; every remaining token becomes positional.
//   - "--name" or "--name=value" becomes a FlagToken (plus a FlagValueToken).
//   - "-abc" becomes ShortFlagToken "a" followed by ShortFlagTailToken "bc".
//   - anything else, including a lone "-", becomes a positional argument.
//
// Within a node, positionals are matched first, then commands, then arguments.
func (c *Context) trace(node *Node) (err error) { // nolint: gocyclo
	positional := 0
	flags := []*Flag{}
	for _, group := range node.AllFlags(false) {
		flags = append(flags, group...)
	}
	for !c.scan.Peek().IsEOL() {
		token := c.scan.Peek()
		switch token.Type {
		case UntypedToken:
			switch {
			// Indicates end of parsing. All remaining arguments are treated as positional arguments only.
			case token.Value == "--":
				c.scan.Pop()
				args := []string{}
				for {
					token = c.scan.Pop()
					if token.Type == EOLToken {
						break
					}
					args = append(args, token.Value)
				}
				// Note: tokens must be pushed in reverse order.
				for i := range args {
					c.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken)
				}
			// Long flag.
			case strings.HasPrefix(token.Value, "--"):
				c.scan.Pop()
				// Parse it and push the tokens.
				parts := strings.SplitN(token.Value[2:], "=", 2)
				if len(parts) > 1 {
					c.scan.PushTyped(parts[1], FlagValueToken)
				}
				c.scan.PushTyped(parts[0], FlagToken)
			// Short flag(s). A lone "-" is deliberately excluded: it carries no
			// flag character (slicing it here used to panic) and is commonly a
			// positional argument (e.g. meaning stdin), so it falls through to
			// the default case.
			case strings.HasPrefix(token.Value, "-") && len(token.Value) > 1:
				c.scan.Pop()
				head, tail := splitFirstChar(token.Value[1:])
				// Note: tokens must be pushed in reverse order.
				if tail != "" {
					c.scan.PushTyped(tail, ShortFlagTailToken)
				}
				c.scan.PushTyped(head, ShortFlagToken)
			default:
				c.scan.Pop()
				c.scan.PushTyped(token.Value, PositionalArgumentToken)
			}
		case ShortFlagTailToken:
			// Remaining characters of a combined short-flag group such as "-abc".
			c.scan.Pop()
			head, tail := splitFirstChar(token.Value)
			// Note: tokens must be pushed in reverse order.
			if tail != "" {
				c.scan.PushTyped(tail, ShortFlagTailToken)
			}
			c.scan.PushTyped(head, ShortFlagToken)
		case FlagToken:
			if err := c.parseFlag(flags, "--"+token.Value); err != nil {
				return err
			}
		case ShortFlagToken:
			if err := c.parseFlag(flags, "-"+token.Value); err != nil {
				return err
			}
		case FlagValueToken:
			return fmt.Errorf("unexpected flag argument %q", token.Value)
		case PositionalArgumentToken:
			candidates := []string{}
			// Ensure we've consumed all positional arguments.
			if positional < len(node.Positional) {
				arg := node.Positional[positional]
				err := arg.Parse(c.scan, c.getValue(arg))
				if err != nil {
					return err
				}
				c.Path = append(c.Path, &Path{
					Parent:     node,
					Positional: arg,
				})
				positional++
				break
			}
			// After positional arguments have been consumed, check commands next...
			for _, branch := range node.Children {
				if branch.Type == CommandNode {
					candidates = append(candidates, branch.Name)
				}
				if branch.Type == CommandNode && branch.Name == token.Value {
					c.scan.Pop()
					c.Path = append(c.Path, &Path{
						Parent:  node,
						Command: branch,
						Flags:   branch.Flags,
					})
					return c.trace(branch)
				}
			}
			// Finally, check arguments.
			for _, branch := range node.Children {
				if branch.Type == ArgumentNode {
					arg := branch.Argument
					if err := arg.Parse(c.scan, c.getValue(arg)); err == nil {
						c.Path = append(c.Path, &Path{
							Parent:   node,
							Argument: branch,
							Flags:    branch.Flags,
						})
						return c.trace(branch)
					}
				}
			}
			return findPotentialCandidates(token.Value, candidates, "unexpected argument %s", token)
		default:
			return fmt.Errorf("unexpected token %s", token)
		}
	}
	return nil
}

// splitFirstChar splits s into its first character and the remainder. The
// first character is a whole UTF-8 sequence, or a single byte if s does not
// start with valid UTF-8. This keeps non-ASCII short flags from being cut
// mid-rune.
func splitFirstChar(s string) (head, tail string) {
	for i := range s {
		if i > 0 {
			return s[:i], s[i:]
		}
	}
	return s, ""
}
// findPotentialCandidates formats an error from format and args. It appends a
// "did you mean" suggestion listing every entry of haystack that either has
// needle as a prefix or is within a Levenshtein distance of 2 of it. With an
// empty haystack the formatted error is returned as-is.
func findPotentialCandidates(needle string, haystack []string, format string, args ...interface{}) error {
	if len(haystack) == 0 {
		return fmt.Errorf(format, args...)
	}
	var suggestions []string
	for _, candidate := range haystack {
		near := strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2
		if near {
			suggestions = append(suggestions, fmt.Sprintf("%q", candidate))
		}
	}
	msg := fmt.Sprintf(format, args...)
	switch n := len(suggestions); {
	case n == 1:
		return fmt.Errorf("%s, did you mean %s?", msg, suggestions[0])
	case n > 1:
		return fmt.Errorf("%s, did you mean one of %s?", msg, strings.Join(suggestions, ", "))
	default:
		return fmt.Errorf("%s", msg)
	}
}
// traceResolvers asks every configured resolver for a value for each flag
// that is available along the traced Path and has no traced value yet (i.e.
// was not set on the command-line).
//
// When several resolvers return a non-empty value for the same flag, each
// value is parsed in turn, so the last one wins, and each adds a Resolved
// Path element. All resolved elements are prepended to c.Path. The first
// resolver or parse error aborts the walk and is returned.
func (c *Context) traceResolvers() error {
	if len(c.resolvers) == 0 {
		return nil
	}
	resolved := make([]*Path, 0)
	for _, p := range c.Path {
		for _, f := range p.Flags {
			// Skip flags that were already set on the command-line.
			if _, seen := c.values[f.Value]; seen {
				continue
			}
			for _, resolve := range c.resolvers {
				raw, err := resolve(c, p, f)
				if err != nil {
					return err
				}
				if raw == "" {
					continue
				}
				tokens := Scan().PushTyped(raw, FlagValueToken)
				// Start from a fresh value so this resolver's result replaces
				// any value stored by an earlier resolver.
				delete(c.values, f.Value)
				if err := f.Parse(tokens, c.getValue(f.Value)); err != nil {
					return err
				}
				resolved = append(resolved, &Path{Flag: f, Resolved: true})
			}
		}
	}
	c.Path = append(resolved, c.Path...)
	return nil
}
// getValue returns the temporary traced value for value. On first access it
// creates, and remembers, a new zero value of the target's type.
func (c *Context) getValue(value *Value) reflect.Value {
	if existing, ok := c.values[value]; ok {
		return existing
	}
	fresh := reflect.New(value.Target.Type()).Elem()
	c.values[value] = fresh
	return fresh
}
// Apply resets the target grammar to its defaults and then writes every
// traced flag, argument and positional value into it.
//
// It returns the selected command path as a space-separated string: command
// names verbatim, arguments and positionals as "<name>". It panics on a Path
// element that records nothing it recognises.
func (c *Context) Apply() (string, error) {
	if err := c.reset(c.Model.Node); err != nil {
		return "", err
	}
	var words []string
	for _, p := range c.Path {
		var target *Value
		switch {
		case p.App != nil:
			// The application root contributes neither a word nor a value.
		case p.Argument != nil:
			words = append(words, "<"+p.Argument.Name+">")
			target = p.Argument.Argument
		case p.Command != nil:
			words = append(words, p.Command.Name)
		case p.Flag != nil:
			target = p.Flag.Value
		case p.Positional != nil:
			words = append(words, "<"+p.Positional.Name+">")
			target = p.Positional
		default:
			panic("unsupported path ?!")
		}
		if target != nil {
			target.Apply(c.getValue(target))
		}
	}
	return strings.Join(words, " "), nil
}
// parseFlag looks up match (a "--long" or "-s" spelling) among flags.
//
// On a hit it pops the flag token, parses the flag's value from the scanner
// and records a Path element for it. Otherwise it returns an "unknown flag"
// error with suggestions drawn from the known spellings. The deferred
// catch(&err) presumably turns panics raised while parsing into the returned
// error (catch is defined elsewhere).
func (c *Context) parseFlag(flags []*Flag, match string) (err error) {
	defer catch(&err)
	var known []string
	for _, f := range flags {
		long, short := "--"+f.Name, "-"+string(f.Short)
		known = append(known, long)
		if f.Short != 0 {
			known = append(known, short)
		}
		if match != long && match != short {
			continue
		}
		c.scan.Pop()
		if parseErr := f.Parse(c.scan, c.getValue(f.Value)); parseErr != nil {
			return parseErr
		}
		c.Path = append(c.Path, &Path{Flag: f})
		return nil
	}
	return findPotentialCandidates(match, known, "unknown flag %s", match)
}
// Run executes the corresponding Run(params...) method on the target command
// selected by the parsed args.
//
// Every node for which Leaf() is true must have a Run() method. All of them
// must share one signature, and that signature must return (error); Run
// checks this first. params must match that signature in number, and each
// param must be assignable to the corresponding parameter type (so a
// concrete value may be passed for an interface parameter). The traced
// values are applied to the grammar via Apply() before the method is called,
// and the method's error is returned.
func (c *Context) Run(params ...interface{}) (err error) {
	defer catch(&err)
	expectedRunSignature, err := c.validateRun(c.Model.Node, nil)
	if err != nil {
		return err
	}
	if expectedRunSignature.NumIn() != len(params) {
		return fmt.Errorf("expected %d params but received %d; does not match target Run() signature of %s",
			expectedRunSignature.NumIn(), len(params), expectedRunSignature)
	}
	for i, param := range params {
		want := expectedRunSignature.In(i)
		// reflect.Value.Call accepts any argument assignable to the parameter
		// type, so check exactly that rather than demanding identical types.
		// A nil param has no dynamic type and is still rejected.
		if got := reflect.TypeOf(param); got == nil || !got.AssignableTo(want) {
			return fmt.Errorf("param %d is of type %s but should be of type %s to match target Run() signature of %s",
				i, got, want, expectedRunSignature)
		}
	}
	node := c.Selected()
	if node == nil {
		return fmt.Errorf("no command selected")
	}
	method, err := getRunMethod(node.Target)
	if err != nil {
		return err
	}
	_, err = c.Apply()
	if err != nil {
		return err
	}
	reflectedParams := make([]reflect.Value, 0, len(params))
	for _, param := range params {
		reflectedParams = append(reflectedParams, reflect.ValueOf(param))
	}
	result := method.Call(reflectedParams)
	if result[0].IsNil() {
		return nil
	}
	return result[0].Interface().(error)
}
// PrintUsage prints usage information for the current context using Kong's
// help function (c.help) and its configured help options.
//
// If summary is true, a summarised version of the help will be output. Any
// error reported by the help function is returned instead of being
// discarded.
func (c *Context) PrintUsage(summary bool) error {
	options := c.helpOptions
	options.Summary = summary
	return c.help(options, c)
}
// validateRun checks the subtree rooted at node. Every node for which Leaf()
// is true must have a Run() method, all such methods must share one
// signature, and that signature must return exactly (error).
//
// signature is the type agreed so far (nil if none has been seen yet); the
// agreed type for the subtree is returned.
func (c *Context) validateRun(node *Node, signature reflect.Type) (reflect.Type, error) {
	if node.Leaf() {
		method, err := getRunMethod(node.Target)
		if err != nil {
			return nil, err
		}
		got := method.Type()
		switch {
		case signature == nil:
			signature = got
		case signature != got:
			return nil, fmt.Errorf("Run() methods are not consistent on %s, expected %s but got %s", node.Target.Type(), signature, got)
		}
		if signature.NumOut() != 1 || signature.Out(0) != expectedRunReturnSignature {
			return nil, fmt.Errorf("Run() method on %s should return (error)", node.Target.Type())
		}
	}
	for _, child := range node.Children {
		childSignature, err := c.validateRun(child, signature)
		if err != nil {
			return nil, err
		}
		if signature == nil {
			signature = childSignature
		}
	}
	return signature, nil
}
// getRunMethod returns the Run method bound to value. If value itself has no
// Run method but is addressable, the method set of its pointer is searched
// as well, which covers pointer-receiver methods. When no Run method is
// found, an invalid reflect.Value and an error naming value's type are
// returned.
func getRunMethod(value reflect.Value) (reflect.Value, error) {
	if run := value.MethodByName("Run"); run.IsValid() {
		return run, nil
	}
	if value.CanAddr() {
		if run := value.Addr().MethodByName("Run"); run.IsValid() {
			return run, nil
		}
	}
	return reflect.Value{}, fmt.Errorf("no Run() method on %s", value.Type())
}
// checkMissingFlags returns an error listing the summary of every flag that
// is marked Required but not Set, or nil if there are none.
func checkMissingFlags(flags []*Flag) error {
	var missing []string
	for _, f := range flags {
		if f.Required && !f.Set {
			missing = append(missing, f.Summary())
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("missing flags: %s", strings.Join(missing, ", "))
	}
	return nil
}
// checkMissingChildren reports what node still expects to be given.
//
// The list covers required positionals that are not Set, required argument
// children, and every non-argument (command) child. It returns nil if that
// list is empty. Otherwise the error names the single item, or up to five
// quoted items followed by "..." when there are more.
func checkMissingChildren(node *Node) error {
	var missing []string
	for _, p := range node.Positional {
		if !p.Required || p.Set {
			continue
		}
		missing = append(missing, strconv.Quote(p.Summary()))
	}
	for _, child := range node.Children {
		switch {
		case child.Argument == nil:
			missing = append(missing, strconv.Quote(child.Name))
		case child.Argument.Required:
			missing = append(missing, strconv.Quote(child.Summary()))
		}
	}
	switch n := len(missing); {
	case n == 0:
		return nil
	case n == 1:
		return fmt.Errorf("expected %s", missing[0])
	case n > 5:
		missing = append(missing[:5], "...")
	}
	return fmt.Errorf("expected one of %s", strings.Join(missing, ", "))
}
// checkMissingPositionals returns an error if positionals were under-supplied.
// positional is the number of values supplied so far. An error is returned
// only when fewer than len(values) were supplied and the first missing one is
// Required; the error lists every unsupplied positional as "<name>".
func checkMissingPositionals(positional int, values []*Value) error {
	// Either everything was supplied, or the first gap is optional.
	if positional >= len(values) || !values[positional].Required {
		return nil
	}
	names := make([]string, 0, len(values)-positional)
	for _, v := range values[positional:] {
		names = append(names, "<"+v.Name+">")
	}
	return fmt.Errorf("missing positional arguments %s", strings.Join(names, " "))
}