diff --git a/mapper.go b/mapper.go index 4c1bb2b..657b8ee 100644 --- a/mapper.go +++ b/mapper.go @@ -308,15 +308,23 @@ func (boolMapper) IsBool() bool { return true } func durationDecoder() MapperFunc { return func(ctx *DecodeContext, target reflect.Value) error { - var value string - if err := ctx.Scan.PopValueInto("duration", &value); err != nil { + t, err := ctx.Scan.PopValue("duration") + if err != nil { return err } - r, err := time.ParseDuration(value) - if err != nil { - return errors.Errorf("expected duration but got %q: %s", value, err) + var d time.Duration + switch v := t.Value.(type) { + case string: + d, err = time.ParseDuration(v) + if err != nil { + return errors.Errorf("expected duration but got %q: %s", v, err) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + d = reflect.ValueOf(v).Convert(reflect.TypeOf(time.Duration(0))).Interface().(time.Duration) // nolint: forcetypeassert + default: + return errors.Errorf("expected duration but got %q", v) } - target.Set(reflect.ValueOf(r)) + target.Set(reflect.ValueOf(d)) return nil } } diff --git a/mapper_test.go b/mapper_test.go index 0a81f9c..3f9b87a 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -127,6 +127,18 @@ func TestDurationMapper(t *testing.T) { require.Equal(t, time.Second*5, cli.Flag) } +func TestDurationMapperJSONResolver(t *testing.T) { + var cli struct { + Flag time.Duration + } + resolver, err := kong.JSON(strings.NewReader(`{"flag": 5000000000}`)) + require.NoError(t, err) + k := mustNew(t, &cli, kong.Resolvers(resolver)) + _, err = k.Parse(nil) + require.NoError(t, err) + require.Equal(t, time.Second*5, cli.Flag) +} + func TestSplitEscaped(t *testing.T) { require.Equal(t, []string{"a", "b"}, kong.SplitEscaped("a,b", ',')) require.Equal(t, []string{"a,b", "c"}, kong.SplitEscaped(`a\,b,c`, ','))