diff --git a/build.go b/build.go index 2eabbf9..f59ef7e 100644 --- a/build.go +++ b/build.go @@ -7,15 +7,7 @@ import ( ) func build(ast interface{}) (app *Application, err error) { - defer func() { - msg := recover() - if test, ok := msg.(error); ok { - app = nil - err = test - } else if msg != nil { - panic(msg) - } - }() + defer catch(&err) v := reflect.ValueOf(ast) iv := reflect.Indirect(v) if v.Kind() != reflect.Ptr || iv.Kind() != reflect.Struct { diff --git a/context.go b/context.go index 4047672..3715c4b 100644 --- a/context.go +++ b/context.go @@ -19,12 +19,12 @@ type ParseTrace struct { } type ParseContext struct { - Trace []*ParseTrace // A trace through parsed nodes. - Error error // Error that occurred during trace, if any. + Trace []*ParseTrace // A trace through parsed nodes. + Error error // Error that occurred during trace, if any. + Flags []*Flag // Accumulated available flags. + Command []string // Full command path. - command []string // Full command path. - flags []*Flag // Accumulated available flags. - node *Node // Current node being parsed. + node *Node // Current node being parsed. args []string app *Application @@ -77,9 +77,15 @@ func (p *ParseContext) reset(node *Node) error { if err != nil { return err } - p.reset(&branch.Argument.Node) + err = p.reset(&branch.Argument.Node) + if err != nil { + return err + } } else { - p.reset(branch.Command) + err := p.reset(branch.Command) + if err != nil { + return err + } } } return nil @@ -88,7 +94,7 @@ func (p *ParseContext) reset(node *Node) error { func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo positional := 0 p.node = node - p.flags = append(p.flags, node.Flags...) + p.Flags = append(p.Flags, node.Flags...) for !p.scan.Peek().IsEOL() { token := p.scan.Peek() @@ -164,7 +170,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo if err != nil { return err } - p.command = append(p.command, "<"+arg.Name+">") + p.Command = append(p.Command, "<"+arg.Name+">") p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value}) positional++ break @@ -176,7 +182,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo case branch.Command != nil: if branch.Command.Name == token.Value { p.scan.Pop() - p.command = append(p.command, branch.Command.Name) + p.Command = append(p.Command, branch.Command.Name) p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command}) return p.trace(branch.Command) } @@ -184,7 +190,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo case branch.Argument != nil: arg := branch.Argument.Argument if value, err := arg.Parse(p.scan); err == nil { - p.command = append(p.command, "<"+arg.Name+">") + p.Command = append(p.Command, "<"+arg.Name+">") p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value}) return p.trace(&branch.Argument.Node) } @@ -205,7 +211,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo return err } - if err := checkMissingFlags(node.Children, p.flags); err != nil { + if err := checkMissingFlags(node.Children, p.Flags); err != nil { return err } @@ -291,15 +297,8 @@ func checkMissingPositionals(positional int, values []*Value) error { func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) { token := p.scan.Peek() - defer func() { - msg := recover() - if test, ok := msg.(Error); ok { - err = fmt.Errorf("%s %s", token, test) - } else if msg != nil { - panic(msg) - } - }() - for _, flag := range p.flags { + defer catch(&err) + for _, flag := range p.Flags { // Found a matching flag. if flag.Name == token.Value { p.scan.Pop() diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..08bd685 --- /dev/null +++ b/context_test.go @@ -0,0 +1,20 @@ +package kong + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTraceErrorPartiallySucceeds(t *testing.T) { + var cli struct { + One struct { + Two struct { + } `kong:"cmd"` + } `kong:"cmd"` + } + p := mustNew(t, &cli) + trace, err := Trace([]string{"one", "bad"}, p.Model) + require.NoError(t, err) + require.Error(t, trace.Error) +} diff --git a/help.go b/help.go index 9df8215..8a92146 100644 --- a/help.go +++ b/help.go @@ -18,8 +18,12 @@ usage: {{.Name}} var defaultHelpTemplate = template.Must(template.New("help").Parse(defaultHelp)) -// WriteHelp to w. If w is nil, the default stdout writer will be used. -func (k *Kong) WriteHelp(w io.Writer) error { +// WriteHelp to w. +// +// If w is nil, the default stdout writer will be used. +// +// If args are provided, help will be written in the context o +func (k *Kong) WriteHelp(w io.Writer, args ...interface{}) error { if w == nil { w = k.stdout } diff --git a/kong.go b/kong.go index 6498f69..70855ed 100644 --- a/kong.go +++ b/kong.go @@ -8,6 +8,7 @@ import ( "text/template" ) +// Error reported by Kong. type Error struct{ msg string } func (e Error) Error() string { return e.msg } @@ -57,14 +58,7 @@ func New(ast interface{}, options ...Option) (*Kong, error) { // Parse arguments into target. func (k *Kong) Parse(args []string) (command string, err error) { - defer func() { - msg := recover() - if test, ok := msg.(Error); ok { - err = test - } else if msg != nil { - panic(msg) - } - }() + defer catch(&err) ctx, err := Trace(args, k.Model) if err != nil { return "", err @@ -78,21 +72,6 @@ func (k *Kong) Parse(args []string) (command string, err error) { return ctx.Apply() } -// Trace through the command tree. -// -// The returned context will include a trace of all parsed objects encountered; flags, arguments, commands. -func (k *Kong) Trace(args []string) (ctx *ParseContext, err error) { - defer func() { - msg := recover() - if test, ok := msg.(Error); ok { - err = test - } else if msg != nil { - panic(msg) - } - }() - return Trace(args, k.Model) -} - func (k *Kong) Errorf(format string, args ...interface{}) { fmt.Fprintf(os.Stderr, k.Model.Name+": "+format, args...) } @@ -108,3 +87,12 @@ func (k *Kong) FatalIfErrorf(err error, args ...interface{}) { k.Errorf("%s\n", msg) k.terminate(1) } + +func catch(err *error) { + msg := recover() + if test, ok := msg.(Error); ok { + *err = test + } else if msg != nil { + panic(msg) + } +} diff --git a/kong_test.go b/kong_test.go index 55f3947..44233d3 100644 --- a/kong_test.go +++ b/kong_test.go @@ -352,16 +352,3 @@ func TestDuplicateFlagOnPeerCommandIsOkay(t *testing.T) { _, err := New(&cli) require.NoError(t, err) } - -func TestTraceErrorPartiallySucceeds(t *testing.T) { - var cli struct { - One struct { - Two struct { - } `kong:"cmd"` - } `kong:"cmd"` - } - p := mustNew(t, &cli) - trace, err := p.Trace([]string{"one", "bad"}) - require.NoError(t, err) - require.Error(t, trace.Error) -}