Make invalid value error messages more useful.

This commit is contained in:
Alec Thomas
2019-04-01 10:06:02 +11:00
parent 4e9878074f
commit 2ac3d43124
4 changed files with 34 additions and 18 deletions
-1
View File
@@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"github.com/alecthomas/kong" "github.com/alecthomas/kong"
) )
+1 -1
View File
@@ -715,7 +715,7 @@ func TestNumericParamErrors(t *testing.T) {
} }
parser := mustNew(t, &cli) parser := mustNew(t, &cli)
_, err := parser.Parse([]string{"--name", "-10"}) _, err := parser.Parse([]string{"--name", "-10"})
require.EqualError(t, err, `expected string value but got "-10" (short flag)`) require.EqualError(t, err, `--name: expected string value but got "-10" (short flag)`)
} }
func TestDefaultValueIsHyphen(t *testing.T) { func TestDefaultValueIsHyphen(t *testing.T) {
+11 -9
View File
@@ -75,7 +75,7 @@ type BoolMapper interface {
// A MapperFunc is a single function that complies with the Mapper interface. // A MapperFunc is a single function that complies with the Mapper interface.
type MapperFunc func(ctx *DecodeContext, target reflect.Value) error type MapperFunc func(ctx *DecodeContext, target reflect.Value) error
func (m MapperFunc) Decode(ctx *DecodeContext, target reflect.Value) error { //nolint: golint func (m MapperFunc) Decode(ctx *DecodeContext, target reflect.Value) error { // nolint: golint
return m(ctx, target) return m(ctx, target)
} }
@@ -224,9 +224,10 @@ func (boolMapper) IsBool() bool { return true }
func durationDecoder() MapperFunc { func durationDecoder() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error { return func(ctx *DecodeContext, target reflect.Value) error {
r, err := time.ParseDuration(ctx.Scan.PopValue("duration")) value := ctx.Scan.PopValue("duration")
r, err := time.ParseDuration(value)
if err != nil { if err != nil {
return err return fmt.Errorf("expected duration but got %q: %s", value, err)
} }
target.Set(reflect.ValueOf(r)) target.Set(reflect.ValueOf(r))
return nil return nil
@@ -235,11 +236,12 @@ func durationDecoder() MapperFunc {
func timeDecoder() MapperFunc { func timeDecoder() MapperFunc {
return func(ctx *DecodeContext, target reflect.Value) error { return func(ctx *DecodeContext, target reflect.Value) error {
fmt := time.RFC3339 format := time.RFC3339
if ctx.Value.Format != "" { if ctx.Value.Format != "" {
fmt = ctx.Value.Format format = ctx.Value.Format
} }
t, err := time.Parse(fmt, ctx.Scan.PopValue("time")) value := ctx.Scan.PopValue("time")
t, err := time.Parse(format, value)
if err != nil { if err != nil {
return err return err
} }
@@ -253,7 +255,7 @@ func intDecoder(bits int) MapperFunc {
value := ctx.Scan.PopValue("int") value := ctx.Scan.PopValue("int")
n, err := strconv.ParseInt(value, 10, bits) n, err := strconv.ParseInt(value, 10, bits)
if err != nil { if err != nil {
return fmt.Errorf("invalid int %q", value) return fmt.Errorf("expected int but got %q", value)
} }
target.SetInt(n) target.SetInt(n)
return nil return nil
@@ -265,7 +267,7 @@ func uintDecoder(bits int) MapperFunc {
value := ctx.Scan.PopValue("uint") value := ctx.Scan.PopValue("uint")
n, err := strconv.ParseUint(value, 10, bits) n, err := strconv.ParseUint(value, 10, bits)
if err != nil { if err != nil {
return fmt.Errorf("invalid uint %q", value) return fmt.Errorf("expected unsigned int but got %q", value)
} }
target.SetUint(n) target.SetUint(n)
return nil return nil
@@ -277,7 +279,7 @@ func floatDecoder(bits int) MapperFunc {
value := ctx.Scan.PopValue("float") value := ctx.Scan.PopValue("float")
n, err := strconv.ParseFloat(value, bits) n, err := strconv.ParseFloat(value, bits)
if err != nil { if err != nil {
return fmt.Errorf("invalid float %q", value) return fmt.Errorf("expected float but got %q", value)
} }
target.SetFloat(n) target.SetFloat(n)
return nil return nil
+22 -7
View File
@@ -226,6 +226,21 @@ func (v *Value) EnumMap() map[string]bool {
return out return out
} }
// ShortSummary returns a human-readable summary of the value, not including any placeholders/defaults.
func (v *Value) ShortSummary() string {
if v.Flag != nil {
return fmt.Sprintf("--%s", v.Name)
}
argText := "<" + v.Name + ">"
if v.IsCumulative() {
argText += " ..."
}
if !v.Required {
argText = "[" + argText + "]"
}
return argText
}
// Summary returns a human-readable summary of the value. // Summary returns a human-readable summary of the value.
func (v *Value) Summary() string { func (v *Value) Summary() string {
if v.Flag != nil { if v.Flag != nil {
@@ -268,20 +283,20 @@ func (v *Value) IsBool() bool {
} }
// Parse tokens into value, parse, and validate, but do not write to the field. // Parse tokens into value, parse, and validate, but do not write to the field.
func (v *Value) Parse(scan *Scanner, target reflect.Value) error { func (v *Value) Parse(scan *Scanner, target reflect.Value) (err error) {
defer func() { defer func() {
if err := recover(); err != nil { if rerr := recover(); rerr != nil {
switch err := err.(type) { switch rerr := rerr.(type) {
case Error: case Error:
panic(err) err = fmt.Errorf("%s: %s", v.ShortSummary(), rerr)
default: default:
panic(fmt.Sprintf("mapper %T failed to apply to %s: %s", v.Mapper, v.Summary(), err)) panic(fmt.Sprintf("mapper %T failed to apply to %s: %s", v.Mapper, v.Summary(), rerr))
} }
} }
}() }()
err := v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target) err = v.Mapper.Decode(&DecodeContext{Value: v, Scan: scan}, target)
if err != nil { if err != nil {
return fmt.Errorf("%s: %s", v.Summary(), err) return fmt.Errorf("%s: %s", v.ShortSummary(), err)
} }
v.Set = true v.Set = true
return nil return nil