diff --git a/context.go b/context.go index 3715c4b..108873a 100644 --- a/context.go +++ b/context.go @@ -9,20 +9,22 @@ import ( // ParseTrace records the nodes and parsed values from the current command-line. type ParseTrace struct { // One of these will be non-nil. + App *Application Positional *Value Flag *Flag Argument *Argument Command *Command + // Flags added by this node. + Flags []*Flag + // Parsed value for non-commands. Value reflect.Value } type ParseContext struct { - Trace []*ParseTrace // A trace through parsed nodes. - Error error // Error that occurred during trace, if any. - Flags []*Flag // Accumulated available flags. - Command []string // Full command path. + Trace []*ParseTrace // A trace through parsed nodes. + Error error // Error that occurred during trace, if any. node *Node // Current node being parsed. @@ -31,17 +33,48 @@ type ParseContext struct { scan *Scanner } +// Flags returns the accumulated available flags. +func (p *ParseContext) Flags() (flags []*Flag) { + for _, trace := range p.Trace { + flags = append(flags, trace.Flags...) + } + return +} + +// Command returns the full command path. +func (p *ParseContext) Command() (command []string) { + for _, trace := range p.Trace { + switch { + case trace.Positional != nil: + command = append(command, "<"+trace.Positional.Name+">") + case trace.Argument != nil: + command = append(command, "<"+trace.Argument.Name+">") + case trace.Command != nil: + command = append(command, trace.Command.Name) + } + } + return +} + // Trace parses the command-line, validating and collecting matching grammar nodes. func Trace(args []string, app *Application) (*ParseContext, error) { p := &ParseContext{ app: app, args: args, } + p.Trace = append(p.Trace, &ParseTrace{ + App: app, + Flags: append([]*Flag{}, app.Flags...), + }) err := p.reset(&p.app.Node) if err != nil { return nil, err } p.Error = p.trace(&p.app.Node) + if err = checkMissingFlags(p.Flags()); err != nil { + return nil, err + } + return p, nil } @@ -94,7 +127,7 @@ func (p *ParseContext) reset(node *Node) error { func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo positional := 0 p.node = node - p.Flags = append(p.Flags, node.Flags...) + flags := append(p.Flags(), node.Flags...) for !p.scan.Peek().IsEOL() { token := p.scan.Peek() @@ -146,14 +179,14 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo p.scan.PushTyped(token.Value[0:1], ShortFlagToken) case FlagToken: - if err := p.matchFlags(func(f *Flag) bool { + if err := p.matchFlags(flags, func(f *Flag) bool { return f.Name == token.Value }); err != nil { return err } case ShortFlagToken: - if err := p.matchFlags(func(f *Flag) bool { + if err := p.matchFlags(flags, func(f *Flag) bool { return string(f.Name) == token.Value }); err != nil { return err @@ -170,8 +203,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo if err != nil { return err } - p.Command = append(p.Command, "<"+arg.Name+">") - p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value}) + p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value, Flags: node.Flags}) positional++ break } @@ -182,16 +214,21 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo case branch.Command != nil: if branch.Command.Name == token.Value { p.scan.Pop() - p.Command = append(p.Command, branch.Command.Name) - p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command}) + p.Trace = append(p.Trace, &ParseTrace{ + Command: branch.Command, + Flags: node.Flags, + }) return p.trace(branch.Command) } case branch.Argument != nil: arg := branch.Argument.Argument if value, err := arg.Parse(p.scan); err == nil { - p.Command = append(p.Command, "<"+arg.Name+">") - p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value}) + p.Trace = append(p.Trace, &ParseTrace{ + Argument: branch.Argument, + Value: value, + Flags: node.Flags, + }) return p.trace(&branch.Argument.Node) } } @@ -211,10 +248,6 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo return err } - if err := checkMissingFlags(node.Children, p.Flags); err != nil { - return err - } - return nil } @@ -238,11 +271,7 @@ func (p *ParseContext) Apply() (string, error) { return strings.Join(path, " "), nil } -func checkMissingFlags(children []*Branch, flags []*Flag) error { - // Only check required missing fields at the last child. - if len(children) > 0 { - return nil - } +func checkMissingFlags(flags []*Flag) error { missing := []string{} for _, flag := range flags { if !flag.Required || flag.Set { @@ -295,10 +324,10 @@ func checkMissingPositionals(positional int, values []*Value) error { return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } -func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) { - token := p.scan.Peek() +func (p *ParseContext) matchFlags(flags []*Flag, matcher func(f *Flag) bool) (err error) { defer catch(&err) - for _, flag := range p.Flags { + token := p.scan.Peek() + for _, flag := range flags { // Found a matching flag. if flag.Name == token.Value { p.scan.Pop()