diff --git a/mapper.go b/mapper.go index a605dea..d038cb0 100644 --- a/mapper.go +++ b/mapper.go @@ -3,6 +3,7 @@ package kong import ( "encoding" "encoding/json" + "fmt" "io/ioutil" "math/bits" "net/url" @@ -325,23 +326,22 @@ func intDecoder(bits int) MapperFunc { // nolint: dupl if err != nil { return err } + var sv string switch v := t.Value.(type) { case string: - n, err := strconv.ParseInt(v, 10, bits) - if err != nil { - return errors.Errorf("expected an int but got %q (%T)", t, t.Value) - } - target.SetInt(n) + sv = v - case float64: - target.SetInt(int64(v)) - - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - target.Set(reflect.ValueOf(v)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + sv = fmt.Sprintf("%v", v) default: return errors.Errorf("expected an int but got %q (%T)", t, t.Value) } + n, err := strconv.ParseInt(sv, 10, bits) + if err != nil { + return errors.Errorf("expected a valid %d bit int but got %q", bits, sv) + } + target.SetInt(n) return nil } } @@ -352,23 +352,22 @@ func uintDecoder(bits int) MapperFunc { // nolint: dupl if err != nil { return err } + var sv string switch v := t.Value.(type) { case string: - n, err := strconv.ParseUint(v, 10, bits) - if err != nil { - return errors.Errorf("expected a uint but got %q (%T)", t, t.Value) - } - target.SetUint(n) + sv = v - case float64: - target.SetUint(uint64(v)) - - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - target.Set(reflect.ValueOf(v)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + sv = fmt.Sprintf("%v", v) default: return errors.Errorf("expected an int but got %q (%T)", t, t.Value) } + n, err := strconv.ParseUint(sv, 10, bits) + if err != nil { + return errors.Errorf("expected a valid %d bit uint but got %q", bits, sv) + } + target.SetUint(n) return nil } } diff --git a/mapper_test.go b/mapper_test.go index 25e09c4..9996b24 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "math" "net/url" "os" "reflect" @@ -260,3 +261,66 @@ func TestCounter(t *testing.T) { require.NoError(t, err) require.Equal(t, 3., cli.Float) } + +func TestNumbers(t *testing.T) { + type CLI struct { + F32 float32 + F64 float64 + I8 int8 + I16 int16 + I32 int32 + I64 int64 + U8 uint8 + U16 uint16 + U32 uint32 + U64 uint64 + } + var cli CLI + p := mustNew(t, &cli) + t.Run("Max", func(t *testing.T) { + _, err := p.Parse([]string{ + "--f-32", fmt.Sprintf("%v", math.MaxFloat32), + "--f-64", fmt.Sprintf("%v", math.MaxFloat64), + "--i-8", fmt.Sprintf("%v", math.MaxInt8), + "--i-16", fmt.Sprintf("%v", math.MaxInt16), + "--i-32", fmt.Sprintf("%v", math.MaxInt32), + "--i-64", fmt.Sprintf("%v", math.MaxInt64), + "--u-8", fmt.Sprintf("%v", math.MaxUint8), + "--u-16", fmt.Sprintf("%v", math.MaxUint16), + "--u-32", fmt.Sprintf("%v", math.MaxUint32), + "--u-64", fmt.Sprintf("%v", uint64(math.MaxUint64)), + }) + require.NoError(t, err) + require.Equal(t, CLI{ + F32: math.MaxFloat32, + F64: math.MaxFloat64, + I8: math.MaxInt8, + I16: math.MaxInt16, + I32: math.MaxInt32, + I64: math.MaxInt64, + U8: math.MaxUint8, + U16: math.MaxUint16, + U32: math.MaxUint32, + U64: math.MaxUint64, + }, cli) + }) + t.Run("Min", func(t *testing.T) { + _, err := p.Parse([]string{ + fmt.Sprintf("--i-8=%v", math.MinInt8), + fmt.Sprintf("--i-16=%v", math.MinInt16), + fmt.Sprintf("--i-32=%v", math.MinInt32), + fmt.Sprintf("--i-64=%v", math.MinInt64), + fmt.Sprintf("--u-8=%v", 0), + fmt.Sprintf("--u-16=%v", 0), + fmt.Sprintf("--u-32=%v", 0), + fmt.Sprintf("--u-64=%v", 0), + }) + require.NoError(t, err) + require.Equal(t, CLI{ + I8: math.MinInt8, + I16: math.MinInt16, + I32: math.MinInt32, + I64: math.MinInt64, + }, cli) + }) +}