Move command/flag accumulation into ParseTrace.

This commit is contained in:
Alec Thomas
2018-05-26 16:02:09 -04:00
parent d20b44baf4
commit f60fe01f08
+54 -25
View File
@@ -9,20 +9,22 @@ import (
// ParseTrace records the nodes and parsed values from the current command-line. // ParseTrace records the nodes and parsed values from the current command-line.
type ParseTrace struct { type ParseTrace struct {
// One of these will be non-nil. // One of these will be non-nil.
App *Application
Positional *Value Positional *Value
Flag *Flag Flag *Flag
Argument *Argument Argument *Argument
Command *Command Command *Command
// Flags added by this node.
Flags []*Flag
// Parsed value for non-commands. // Parsed value for non-commands.
Value reflect.Value Value reflect.Value
} }
type ParseContext struct { type ParseContext struct {
Trace []*ParseTrace // A trace through parsed nodes. Trace []*ParseTrace // A trace through parsed nodes.
Error error // Error that occurred during trace, if any. Error error // Error that occurred during trace, if any.
Flags []*Flag // Accumulated available flags.
Command []string // Full command path.
node *Node // Current node being parsed. node *Node // Current node being parsed.
@@ -31,17 +33,48 @@ type ParseContext struct {
scan *Scanner scan *Scanner
} }
// Flags returns the accumulated available flags.
func (p *ParseContext) Flags() (flags []*Flag) {
for _, trace := range p.Trace {
flags = append(flags, trace.Flags...)
}
return
}
// Command returns the full command path.
func (p *ParseContext) Command() (command []string) {
for _, trace := range p.Trace {
switch {
case trace.Positional != nil:
command = append(command, "<"+trace.Positional.Name+">")
case trace.Argument != nil:
command = append(command, "<"+trace.Argument.Name+">")
case trace.Command != nil:
command = append(command, trace.Command.Name)
}
}
return
}
// Trace parses the command-line, validating and collecting matching grammar nodes. // Trace parses the command-line, validating and collecting matching grammar nodes.
func Trace(args []string, app *Application) (*ParseContext, error) { func Trace(args []string, app *Application) (*ParseContext, error) {
p := &ParseContext{ p := &ParseContext{
app: app, app: app,
args: args, args: args,
} }
p.Trace = append(p.Trace, &ParseTrace{
App: app,
Flags: append([]*Flag{}, app.Flags...),
})
err := p.reset(&p.app.Node) err := p.reset(&p.app.Node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.Error = p.trace(&p.app.Node) p.Error = p.trace(&p.app.Node)
if err = checkMissingFlags(p.Flags()); err != nil {
return nil, err
}
return p, nil return p, nil
} }
@@ -94,7 +127,7 @@ func (p *ParseContext) reset(node *Node) error {
func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
positional := 0 positional := 0
p.node = node p.node = node
p.Flags = append(p.Flags, node.Flags...) flags := append(p.Flags(), node.Flags...)
for !p.scan.Peek().IsEOL() { for !p.scan.Peek().IsEOL() {
token := p.scan.Peek() token := p.scan.Peek()
@@ -146,14 +179,14 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
p.scan.PushTyped(token.Value[0:1], ShortFlagToken) p.scan.PushTyped(token.Value[0:1], ShortFlagToken)
case FlagToken: case FlagToken:
if err := p.matchFlags(func(f *Flag) bool { if err := p.matchFlags(flags, func(f *Flag) bool {
return f.Name == token.Value return f.Name == token.Value
}); err != nil { }); err != nil {
return err return err
} }
case ShortFlagToken: case ShortFlagToken:
if err := p.matchFlags(func(f *Flag) bool { if err := p.matchFlags(flags, func(f *Flag) bool {
return string(f.Name) == token.Value return string(f.Name) == token.Value
}); err != nil { }); err != nil {
return err return err
@@ -170,8 +203,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
if err != nil { if err != nil {
return err return err
} }
p.Command = append(p.Command, "<"+arg.Name+">") p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value, Flags: node.Flags})
p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value})
positional++ positional++
break break
} }
@@ -182,16 +214,21 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
case branch.Command != nil: case branch.Command != nil:
if branch.Command.Name == token.Value { if branch.Command.Name == token.Value {
p.scan.Pop() p.scan.Pop()
p.Command = append(p.Command, branch.Command.Name) p.Trace = append(p.Trace, &ParseTrace{
p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command}) Command: branch.Command,
Flags: node.Flags,
})
return p.trace(branch.Command) return p.trace(branch.Command)
} }
case branch.Argument != nil: case branch.Argument != nil:
arg := branch.Argument.Argument arg := branch.Argument.Argument
if value, err := arg.Parse(p.scan); err == nil { if value, err := arg.Parse(p.scan); err == nil {
p.Command = append(p.Command, "<"+arg.Name+">") p.Trace = append(p.Trace, &ParseTrace{
p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value}) Argument: branch.Argument,
Value: value,
Flags: node.Flags,
})
return p.trace(&branch.Argument.Node) return p.trace(&branch.Argument.Node)
} }
} }
@@ -211,10 +248,6 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
return err return err
} }
if err := checkMissingFlags(node.Children, p.Flags); err != nil {
return err
}
return nil return nil
} }
@@ -238,11 +271,7 @@ func (p *ParseContext) Apply() (string, error) {
return strings.Join(path, " "), nil return strings.Join(path, " "), nil
} }
func checkMissingFlags(children []*Branch, flags []*Flag) error { func checkMissingFlags(flags []*Flag) error {
// Only check required missing fields at the last child.
if len(children) > 0 {
return nil
}
missing := []string{} missing := []string{}
for _, flag := range flags { for _, flag := range flags {
if !flag.Required || flag.Set { if !flag.Required || flag.Set {
@@ -295,10 +324,10 @@ func checkMissingPositionals(positional int, values []*Value) error {
return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " "))
} }
func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) { func (p *ParseContext) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) {
token := p.scan.Peek()
defer catch(&err) defer catch(&err)
for _, flag := range p.Flags { token := p.scan.Peek()
for _, flag := range flags {
// Found a matching flag. // Found a matching flag.
if flag.Name == token.Value { if flag.Name == token.Value {
p.scan.Pop() p.scan.Pop()