From f861d83a17de567831a2a50ab9e4743aac335bbe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 8 Feb 2022 16:48:17 -0600 Subject: [PATCH] Fix range types not clearing unbounded or empty --- pgtype/range_codec.go | 3 ++- pgtype/range_codec_test.go | 52 ++++++++++++++++++++++++++++++++++++++ pgtype/range_types.go | 42 ++++++++++++++++++++++++++++++ pgtype/range_types.go.erb | 6 +++++ 4 files changed, 102 insertions(+), 1 deletion(-) diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 0dc63e6c..f5091c36 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -29,7 +29,8 @@ type RangeScanner interface { ScanBounds() (lowerTarget, upperTarget interface{}) // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned - // (if appropriate) before SetBoundTypes is called. + // (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must + // also set the bound values. SetBoundTypes(lower, upper BoundType) error } diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index 30095065..6597ab98 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -63,6 +63,58 @@ func TestRangeCodecTranscodeCompatibleRangeElementTypes(t *testing.T) { }) } +func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + var r pgtype.Int4range + + err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(context.Background(), `select '[1,)'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Unbounded, + Valid: true, + }, + r, + ) + + err = conn.QueryRow(context.Background(), `select 'empty'::int4range`).Scan(&r) + require.NoError(t, err) + + require.Equal( + t, + pgtype.Int4range{ + Lower: pgtype.Int4{}, + Upper: pgtype.Int4{}, + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Valid: true, + }, + r, + ) +} + func TestRangeCodecDecodeValue(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) diff --git a/pgtype/range_types.go b/pgtype/range_types.go index aa979d56..1496ca30 100644 --- a/pgtype/range_types.go +++ b/pgtype/range_types.go @@ -31,6 +31,12 @@ func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Int4{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Int4{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -67,6 +73,12 @@ func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Int8{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Int8{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -103,6 +115,12 @@ func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Numeric{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Numeric{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -139,6 +157,12 @@ func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Timestamp{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Timestamp{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -175,6 +199,12 @@ func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Timestamptz{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Timestamptz{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -211,6 +241,12 @@ func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Date{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Date{} + } r.LowerType = lower r.UpperType = upper r.Valid = true @@ -247,6 +283,12 @@ func (r *Float8range) ScanBounds() (lowerTarget, upperTarget interface{}) { } func (r *Float8range) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = Float8{} + } + if upper == Unbounded || upper == Empty { + r.Upper = Float8{} + } r.LowerType = lower r.UpperType = upper r.Valid = true diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb index dc796a1d..8b43f7f9 100644 --- a/pgtype/range_types.go.erb +++ b/pgtype/range_types.go.erb @@ -41,6 +41,12 @@ func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) } func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { + if lower == Unbounded || lower == Empty { + r.Lower = <%= element_type %>{} + } + if upper == Unbounded || upper == Empty { + r.Upper = <%= element_type %>{} + } r.LowerType = lower r.UpperType = upper r.Valid = true