Remove a bunch of duplicate recover code.

This commit is contained in:
Alec Thomas
2018-05-26 15:32:45 -04:00
parent 0bb304449c
commit d20b44baf4
6 changed files with 58 additions and 68 deletions
+1 -9
View File
@@ -7,15 +7,7 @@ import (
) )
func build(ast interface{}) (app *Application, err error) { func build(ast interface{}) (app *Application, err error) {
defer func() { defer catch(&err)
msg := recover()
if test, ok := msg.(error); ok {
app = nil
err = test
} else if msg != nil {
panic(msg)
}
}()
v := reflect.ValueOf(ast) v := reflect.ValueOf(ast)
iv := reflect.Indirect(v) iv := reflect.Indirect(v)
if v.Kind() != reflect.Ptr || iv.Kind() != reflect.Struct { if v.Kind() != reflect.Ptr || iv.Kind() != reflect.Struct {
+17 -18
View File
@@ -21,9 +21,9 @@ type ParseTrace struct {
type ParseContext struct { type ParseContext struct {
Trace []*ParseTrace // A trace through parsed nodes. Trace []*ParseTrace // A trace through parsed nodes.
Error error // Error that occurred during trace, if any. Error error // Error that occurred during trace, if any.
Flags []*Flag // Accumulated available flags.
Command []string // Full command path.
command []string // Full command path.
flags []*Flag // Accumulated available flags.
node *Node // Current node being parsed. node *Node // Current node being parsed.
args []string args []string
@@ -77,9 +77,15 @@ func (p *ParseContext) reset(node *Node) error {
if err != nil { if err != nil {
return err return err
} }
p.reset(&branch.Argument.Node) err = p.reset(&branch.Argument.Node)
if err != nil {
return err
}
} else { } else {
p.reset(branch.Command) err := p.reset(branch.Command)
if err != nil {
return err
}
} }
} }
return nil return nil
@@ -88,7 +94,7 @@ func (p *ParseContext) reset(node *Node) error {
func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
positional := 0 positional := 0
p.node = node p.node = node
p.flags = append(p.flags, node.Flags...) p.Flags = append(p.Flags, node.Flags...)
for !p.scan.Peek().IsEOL() { for !p.scan.Peek().IsEOL() {
token := p.scan.Peek() token := p.scan.Peek()
@@ -164,7 +170,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
if err != nil { if err != nil {
return err return err
} }
p.command = append(p.command, "<"+arg.Name+">") p.Command = append(p.Command, "<"+arg.Name+">")
p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value}) p.Trace = append(p.Trace, &ParseTrace{Positional: arg, Value: value})
positional++ positional++
break break
@@ -176,7 +182,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
case branch.Command != nil: case branch.Command != nil:
if branch.Command.Name == token.Value { if branch.Command.Name == token.Value {
p.scan.Pop() p.scan.Pop()
p.command = append(p.command, branch.Command.Name) p.Command = append(p.Command, branch.Command.Name)
p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command}) p.Trace = append(p.Trace, &ParseTrace{Command: branch.Command})
return p.trace(branch.Command) return p.trace(branch.Command)
} }
@@ -184,7 +190,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
case branch.Argument != nil: case branch.Argument != nil:
arg := branch.Argument.Argument arg := branch.Argument.Argument
if value, err := arg.Parse(p.scan); err == nil { if value, err := arg.Parse(p.scan); err == nil {
p.command = append(p.command, "<"+arg.Name+">") p.Command = append(p.Command, "<"+arg.Name+">")
p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value}) p.Trace = append(p.Trace, &ParseTrace{Argument: branch.Argument, Value: value})
return p.trace(&branch.Argument.Node) return p.trace(&branch.Argument.Node)
} }
@@ -205,7 +211,7 @@ func (p *ParseContext) trace(node *Node) (err error) { // nolint: gocyclo
return err return err
} }
if err := checkMissingFlags(node.Children, p.flags); err != nil { if err := checkMissingFlags(node.Children, p.Flags); err != nil {
return err return err
} }
@@ -291,15 +297,8 @@ func checkMissingPositionals(positional int, values []*Value) error {
func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) { func (p *ParseContext) matchFlags(matcher func(f *Flag) bool) (err error) {
token := p.scan.Peek() token := p.scan.Peek()
defer func() { defer catch(&err)
msg := recover() for _, flag := range p.Flags {
if test, ok := msg.(Error); ok {
err = fmt.Errorf("%s %s", token, test)
} else if msg != nil {
panic(msg)
}
}()
for _, flag := range p.flags {
// Found a matching flag. // Found a matching flag.
if flag.Name == token.Value { if flag.Name == token.Value {
p.scan.Pop() p.scan.Pop()
+20
View File
@@ -0,0 +1,20 @@
package kong
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestTraceErrorPartiallySucceeds(t *testing.T) {
var cli struct {
One struct {
Two struct {
} `kong:"cmd"`
} `kong:"cmd"`
}
p := mustNew(t, &cli)
trace, err := Trace([]string{"one", "bad"}, p.Model)
require.NoError(t, err)
require.Error(t, trace.Error)
}
+6 -2
View File
@@ -18,8 +18,12 @@ usage: {{.Name}}
var defaultHelpTemplate = template.Must(template.New("help").Parse(defaultHelp)) var defaultHelpTemplate = template.Must(template.New("help").Parse(defaultHelp))
// WriteHelp to w. If w is nil, the default stdout writer will be used. // WriteHelp to w.
func (k *Kong) WriteHelp(w io.Writer) error { //
// If w is nil, the default stdout writer will be used.
//
// If args are provided, help will be written in the context o
func (k *Kong) WriteHelp(w io.Writer, args ...interface{}) error {
if w == nil { if w == nil {
w = k.stdout w = k.stdout
} }
+11 -23
View File
@@ -8,6 +8,7 @@ import (
"text/template" "text/template"
) )
// Error reported by Kong.
type Error struct{ msg string } type Error struct{ msg string }
func (e Error) Error() string { return e.msg } func (e Error) Error() string { return e.msg }
@@ -57,14 +58,7 @@ func New(ast interface{}, options ...Option) (*Kong, error) {
// Parse arguments into target. // Parse arguments into target.
func (k *Kong) Parse(args []string) (command string, err error) { func (k *Kong) Parse(args []string) (command string, err error) {
defer func() { defer catch(&err)
msg := recover()
if test, ok := msg.(Error); ok {
err = test
} else if msg != nil {
panic(msg)
}
}()
ctx, err := Trace(args, k.Model) ctx, err := Trace(args, k.Model)
if err != nil { if err != nil {
return "", err return "", err
@@ -78,21 +72,6 @@ func (k *Kong) Parse(args []string) (command string, err error) {
return ctx.Apply() return ctx.Apply()
} }
// Trace through the command tree.
//
// The returned context will include a trace of all parsed objects encountered; flags, arguments, commands.
func (k *Kong) Trace(args []string) (ctx *ParseContext, err error) {
defer func() {
msg := recover()
if test, ok := msg.(Error); ok {
err = test
} else if msg != nil {
panic(msg)
}
}()
return Trace(args, k.Model)
}
func (k *Kong) Errorf(format string, args ...interface{}) { func (k *Kong) Errorf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, k.Model.Name+": "+format, args...) fmt.Fprintf(os.Stderr, k.Model.Name+": "+format, args...)
} }
@@ -108,3 +87,12 @@ func (k *Kong) FatalIfErrorf(err error, args ...interface{}) {
k.Errorf("%s\n", msg) k.Errorf("%s\n", msg)
k.terminate(1) k.terminate(1)
} }
func catch(err *error) {
msg := recover()
if test, ok := msg.(Error); ok {
*err = test
} else if msg != nil {
panic(msg)
}
}
-13
View File
@@ -352,16 +352,3 @@ func TestDuplicateFlagOnPeerCommandIsOkay(t *testing.T) {
_, err := New(&cli) _, err := New(&cli)
require.NoError(t, err) require.NoError(t, err)
} }
func TestTraceErrorPartiallySucceeds(t *testing.T) {
var cli struct {
One struct {
Two struct {
} `kong:"cmd"`
} `kong:"cmd"`
}
p := mustNew(t, &cli)
trace, err := p.Trace([]string{"one", "bad"})
require.NoError(t, err)
require.Error(t, trace.Error)
}