From 77e4b01553277c03e56a3303f3f8f935509beccd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jan 2022 18:46:28 -0600 Subject: [PATCH] Convert Interval to Codec --- pgtype/builtin_wrappers.go | 16 +++ pgtype/interval.go | 274 ++++++++++++++++++++++--------------- pgtype/interval_test.go | 188 ++++++++++++++++--------- pgtype/pgtype.go | 32 ++++- 4 files changed, 331 insertions(+), 179 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 15d4e083..5689b321 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -342,6 +342,22 @@ func (w timeWrapper) DateValue() (Date, error) { return Date{Time: time.Time(w), Valid: true}, nil } +type durationWrapper time.Duration + +func (w *durationWrapper) ScanInterval(v Interval) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Interval") + } + + us := int64(v.Months)*microsecondsPerMonth + int64(v.Days)*microsecondsPerDay + v.Microseconds + *w = durationWrapper(time.Duration(us) * time.Microsecond) + return nil +} + +func (w durationWrapper) IntervalValue() (Interval, error) { + return Interval{Microseconds: int64(w) / 1000, Valid: true}, nil +} + type netIPNetWrapper net.IPNet func (w *netIPNetWrapper) ScanInet(v Inet) error { diff --git a/pgtype/interval.go b/pgtype/interval.go index a92cd41f..41216f37 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -6,7 +6,6 @@ import ( "fmt" "strconv" "strings" - "time" "github.com/jackc/pgio" ) @@ -19,6 +18,14 @@ const ( microsecondsPerMonth = 30 * microsecondsPerDay ) +type IntervalScanner interface { + ScanInterval(v Interval) error +} + +type IntervalValuer interface { + IntervalValue() (Interval, error) +} + type Interval struct { Microseconds int64 Days int32 @@ -26,61 +33,169 @@ type Interval struct { Valid bool } -func (dst *Interval) Set(src interface{}) error { +func (interval *Interval) ScanInterval(v Interval) error { + *interval = v + return nil +} + +func (interval Interval) IntervalValue() (Interval, error) { + return interval, nil +} + +// Scan implements the database/sql Scanner interface. +func (interval *Interval) Scan(src interface{}) error { if src == nil { - *dst = Interval{} + *interval = Interval{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToIntervalScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), interval) } - switch value := src.(type) { - case time.Duration: - *dst = Interval{Microseconds: int64(value) / 1000, Valid: true} - default: - if originalSrc, ok := underlyingPtrType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Interval", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (interval Interval) Value() (driver.Value, error) { + if !interval.Valid { + return nil, nil + } + + buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type IntervalCodec struct{} + +func (IntervalCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (IntervalCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (IntervalCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(IntervalValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanIntervalCodecBinary{} + case TextFormatCode: + return encodePlanIntervalCodecText{} } return nil } -func (dst Interval) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanIntervalCodecBinary struct{} + +func (encodePlanIntervalCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err } - return dst + + if !interval.Valid { + return nil, nil + } + + buf = pgio.AppendInt64(buf, interval.Microseconds) + buf = pgio.AppendInt32(buf, interval.Days) + buf = pgio.AppendInt32(buf, interval.Months) + return buf, nil } -func (src *Interval) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanIntervalCodecText struct{} + +func (encodePlanIntervalCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + interval, err := value.(IntervalValuer).IntervalValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *time.Duration: - us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds - *v = time.Duration(us) * time.Microsecond - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) + if !interval.Valid { + return nil, nil + } + + if interval.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) + buf = append(buf, " mon "...) + } + + if interval.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + buf = append(buf, timeStr...) + return buf, nil +} + +func (IntervalCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanBinaryIntervalToIntervalScanner{} + } + case TextFormatCode: + switch target.(type) { + case IntervalScanner: + return scanPlanTextAnyToIntervalScanner{} } - return fmt.Errorf("unable to assign to %T", dst) } + + return nil } -func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { +type scanPlanBinaryIntervalToIntervalScanner struct{} + +func (scanPlanBinaryIntervalToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(IntervalScanner) + if src == nil { - *dst = Interval{} - return nil + return scanner.ScanInterval(Interval{}) + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) +} + +type scanPlanTextAnyToIntervalScanner struct{} + +func (scanPlanTextAnyToIntervalScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(IntervalScanner) + + if src == nil { + return scanner.ScanInterval(Interval{}) } var microseconds int64 @@ -156,89 +271,22 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true} - return nil + return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) } -func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c IntervalCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c IntervalCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Interval{} - return nil - } - - if len(src) != 16 { - return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) - } - - microseconds := int64(binary.BigEndian.Uint64(src)) - days := int32(binary.BigEndian.Uint32(src[8:])) - months := int32(binary.BigEndian.Uint32(src[12:])) - - *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true} - return nil -} - -func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - if src.Months != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) - buf = append(buf, " mon "...) + var interval Interval + err := codecScan(c, ci, oid, format, src, &interval) + if err != nil { + return nil, err } - - if src.Days != 0 { - buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) - buf = append(buf, " day "...) - } - - absMicroseconds := src.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } - - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - microseconds := absMicroseconds % microsecondsPerSecond - - timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - return append(buf, timeStr...), nil -} - -// EncodeBinary encodes src into w. -func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendInt64(buf, src.Microseconds) - buf = pgio.AppendInt32(buf, src.Days) - return pgio.AppendInt32(buf, src.Months), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Interval) Scan(src interface{}) error { - if src == nil { - *dst = Interval{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Interval) Value() (driver.Value, error) { - return EncodeValueText(src) + return interval, nil } diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index a8241bf6..75733ff1 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -5,70 +5,132 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestIntervalTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - &pgtype.Interval{Microseconds: 1, Valid: true}, - &pgtype.Interval{Microseconds: 1000000, Valid: true}, - &pgtype.Interval{Microseconds: 1000001, Valid: true}, - &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, - &pgtype.Interval{Days: 1, Valid: true}, - &pgtype.Interval{Months: 1, Valid: true}, - &pgtype.Interval{Months: 12, Valid: true}, - &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, - &pgtype.Interval{Microseconds: -1, Valid: true}, - &pgtype.Interval{Microseconds: -1000000, Valid: true}, - &pgtype.Interval{Microseconds: -1000001, Valid: true}, - &pgtype.Interval{Microseconds: -123202800000000, Valid: true}, - &pgtype.Interval{Days: -1, Valid: true}, - &pgtype.Interval{Months: -1, Valid: true}, - &pgtype.Interval{Months: -12, Valid: true}, - &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, - &pgtype.Interval{}, +func TestIntervalCodec(t *testing.T) { + testPgxCodec(t, "interval", []PgxTranscodeTestCase{ + { + pgtype.Interval{Microseconds: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: 123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + pgtype.Interval{Months: 12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000000, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -1000001, Valid: true}), + }, + { + pgtype.Interval{Microseconds: -123202800000000, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: -123202800000000, Valid: true}), + }, + { + pgtype.Interval{Days: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -1, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -1, Valid: true}), + }, + { + pgtype.Interval{Months: -12, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -12, Valid: true}), + }, + { + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}, + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Valid: true}), + }, + { + "1 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000000, Valid: true}), + }, + { + "1.000001 second", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 1000001, Valid: true}), + }, + { + "34223 hours", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Microseconds: 123202800000000, Valid: true}), + }, + { + "1 day", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Days: 1, Valid: true}), + }, + { + "1 month", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 1, Valid: true}), + }, + { + "1 year", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: 12, Valid: true}), + }, + { + "-13 mon", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{Months: -13, Valid: true}), + }, + {time.Hour, new(time.Duration), isExpectedEq(time.Hour)}, + { + pgtype.Interval{Months: 1, Days: 1, Valid: true}, + new(time.Duration), + isExpectedEq(time.Duration(2678400000000000)), + }, + {pgtype.Interval{}, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, + {nil, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, }) } - -func TestIntervalNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ - { - SQL: "select '1 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000000, Valid: true}, - }, - { - SQL: "select '1.000001 second'::interval", - Value: &pgtype.Interval{Microseconds: 1000001, Valid: true}, - }, - { - SQL: "select '34223 hours'::interval", - Value: &pgtype.Interval{Microseconds: 123202800000000, Valid: true}, - }, - { - SQL: "select '1 day'::interval", - Value: &pgtype.Interval{Days: 1, Valid: true}, - }, - { - SQL: "select '1 month'::interval", - Value: &pgtype.Interval{Months: 1, Valid: true}, - }, - { - SQL: "select '1 year'::interval", - Value: &pgtype.Interval{Months: 12, Valid: true}, - }, - { - SQL: "select '-13 mon'::interval", - Value: &pgtype.Interval{Months: -13, Valid: true}, - }, - }) -} - -func TestIntervalLossyConversionToDuration(t *testing.T) { - interval := &pgtype.Interval{Months: 1, Days: 1, Valid: true} - var d time.Duration - err := interval.AssignTo(&d) - require.NoError(t, err) - assert.EqualValues(t, int64(2678400000000000), d.Nanoseconds()) -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index c0d02197..5ac3b50d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -305,7 +305,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) ci.RegisterDataType(DataType{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) - ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) @@ -858,6 +858,8 @@ func tryWrapBuiltinTypeScanPlan(dst interface{}) (plan WrappedScanPlanNextSetter return &wrapStringScanPlan{}, (*stringWrapper)(dst), true case *time.Time: return &wrapTimeScanPlan{}, (*timeWrapper)(dst), true + case *time.Duration: + return &wrapDurationScanPlan{}, (*durationWrapper)(dst), true case *net.IPNet: return &wrapNetIPNetScanPlan{}, (*netIPNetWrapper)(dst), true case *net.IP: @@ -1011,6 +1013,16 @@ func (plan *wrapTimeScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, s return plan.next.Scan(ci, oid, formatCode, src, (*timeWrapper)(dst.(*time.Time))) } +type wrapDurationScanPlan struct { + next ScanPlan +} + +func (plan *wrapDurationScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapDurationScanPlan) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.next.Scan(ci, oid, formatCode, src, (*durationWrapper)(dst.(*time.Duration))) +} + type wrapNetIPNetScanPlan struct { next ScanPlan } @@ -1143,8 +1155,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan for _, f := range tryWrappers { if wrapperPlan, nextDst, ok := f(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan + if _, ok := nextPlan.(*scanPlanDataTypeAssignTo); !ok { // avoid fallthrough -- this will go away when old system removed. + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } } } } @@ -1381,6 +1395,8 @@ func tryWrapBuiltinTypeEncodePlan(value interface{}) (plan WrappedEncodePlanNext return &wrapStringEncodePlan{}, stringWrapper(value), true case time.Time: return &wrapTimeEncodePlan{}, timeWrapper(value), true + case time.Duration: + return &wrapDurationEncodePlan{}, durationWrapper(value), true case net.IPNet: return &wrapNetIPNetEncodePlan{}, netIPNetWrapper(value), true case net.IP: @@ -1534,6 +1550,16 @@ func (plan *wrapTimeEncodePlan) Encode(value interface{}, buf []byte) (newBuf [] return plan.next.Encode(timeWrapper(value.(time.Time)), buf) } +type wrapDurationEncodePlan struct { + next EncodePlan +} + +func (plan *wrapDurationEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapDurationEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + return plan.next.Encode(durationWrapper(value.(time.Duration)), buf) +} + type wrapNetIPNetEncodePlan struct { next EncodePlan }