diff --git a/Rakefile b/Rakefile index 4579034d..f3a61a09 100644 --- a/Rakefile +++ b/Rakefile @@ -6,5 +6,14 @@ rule '.go' => '.go.erb' do |task| sh "goimports", "-w", task.name end +generated_code_files = [ + "pgtype/int.go", + "pgtype/int_test.go", + "pgtype/integration_benchmark_test.go", + "pgtype/range_types.go", + "pgtype/zeronull/int.go", + "pgtype/zeronull/int_test.go" +] + desc "Generate code" -task generate: ["pgtype/int.go", "pgtype/int_test.go", "pgtype/integration_benchmark_test.go", "pgtype/zeronull/int.go", "pgtype/zeronull/int_test.go"] +task generate: generated_code_files diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb index c98f6488..8858ce90 100644 --- a/pgtype/int_test.go.erb +++ b/pgtype/int_test.go.erb @@ -10,7 +10,7 @@ import ( <% [2, 4, 8].each do |pg_byte_size| %> <% pg_bit_size = pg_byte_size * 8 %> func TestInt<%= pg_byte_size %>Codec(t *testing.T) { - testPgxCodec(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ + testutil.RunTranscodeTests(t, "int<%= pg_byte_size %>", []testutil.TranscodeTestCase{ {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index ab317f6e..ec6d3ec9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -208,14 +208,6 @@ func NewConnInfo() *ConnInfo { }, } - // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) - // ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) - // ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) - // ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) - // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) - // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) - // ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) - // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) ci.RegisterDataType(DataType{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) ci.RegisterDataType(DataType{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) @@ -257,6 +249,16 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + ci.RegisterDataType(DataType{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[DateOID]}}) + ci.RegisterDataType(DataType{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int4OID]}}) + ci.RegisterDataType(DataType{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[Int8OID]}}) + ci.RegisterDataType(DataType{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[NumericOID]}}) + ci.RegisterDataType(DataType{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestampOID]}}) + ci.RegisterDataType(DataType{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementDataType: ci.oidToDataType[TimestamptzOID]}}) + + // ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + // ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) + ci.RegisterDataType(DataType{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[ACLItemOID]}}) ci.RegisterDataType(DataType{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BitOID]}}) ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementDataType: ci.oidToDataType[BoolOID]}}) diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..e999f6a9 --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,277 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +func (bt BoundType) String() string { + return string(bt) +} + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = Empty + utr.UpperType = Empty + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + buf.UnreadRune() + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go new file mode 100644 index 00000000..0dc63e6c --- /dev/null +++ b/pgtype/range_codec.go @@ -0,0 +1,414 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +// RangeValuer is a type that can be converted into a PostgreSQL range. +type RangeValuer interface { + // IsNull returns true if the value is SQL NULL. + IsNull() bool + + // BoundTypes returns the lower and upper bound types. + BoundTypes() (lower, upper BoundType) + + // Bounds returns the lower and upper range values. + Bounds() (lower, upper interface{}) +} + +// RangeScanner is a type can be scanned from a PostgreSQL range. +type RangeScanner interface { + // ScanNull sets the value to SQL NULL. + ScanNull() error + + // ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or + // the bound type is unbounded. + ScanBounds() (lowerTarget, upperTarget interface{}) + + // SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned + // (if appropriate) before SetBoundTypes is called. + SetBoundTypes(lower, upper BoundType) error +} + +type GenericRange struct { + Lower interface{} + Upper interface{} + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r GenericRange) IsNull() bool { + return !r.Valid +} + +func (r GenericRange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r GenericRange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *GenericRange) ScanNull() error { + *r = GenericRange{} + return nil +} + +func (r *GenericRange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *GenericRange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +// RangeCodec is a codec for any range type. +type RangeCodec struct { + ElementDataType *DataType +} + +func (c *RangeCodec) FormatSupported(format int16) bool { + return c.ElementDataType.Codec.FormatSupported(format) +} + +func (c *RangeCodec) PreferredFormat() int16 { + if c.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode +} + +func (c *RangeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(RangeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return &encodePlanRangeCodecRangeValuerToBinary{rc: c, ci: ci} + case TextFormatCode: + return &encodePlanRangeCodecRangeValuerToText{rc: c, ci: ci} + } + + return nil +} + +type encodePlanRangeCodecRangeValuerToBinary struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + var rangeType byte + switch lowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", lowerType) + } + + switch upperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", upperType) + } + + buf = append(buf, rangeType) + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, BinaryFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +type encodePlanRangeCodecRangeValuerToText struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + getter := value.(RangeValuer) + + if getter.IsNull() { + return nil, nil + } + + lowerType, upperType := getter.BoundTypes() + lower, upper := getter.Bounds() + + switch lowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", lowerType) + } + + if lowerType != Unbounded { + if lower == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + lowerPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, lower) + if lowerPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", lower) + } + + buf, err = lowerPlan.Encode(lower, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if upperType != Unbounded { + if upper == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + upperPlan := plan.ci.PlanEncode(plan.rc.ElementDataType.OID, TextFormatCode, upper) + if upperPlan == nil { + return nil, fmt.Errorf("cannot encode %v as element of range", upper) + } + + buf, err = upperPlan.Encode(upper, buf) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch upperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", upperType) + } + + return buf, nil +} + +func (c *RangeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanBinaryRangeToRangeScanner{rc: c, ci: ci} + } + case TextFormatCode: + switch target.(type) { + case RangeScanner: + return &scanPlanTextRangeToRangeScanner{rc: c, ci: ci} + } + } + + return nil +} + +type scanPlanBinaryRangeToRangeScanner struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target interface{}) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + if ubr.LowerType == Empty { + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive { + lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan(ubr.Lower, lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive { + upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, BinaryFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan(ubr.Upper, upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType) +} + +type scanPlanTextRangeToRangeScanner struct { + rc *RangeCodec + ci *ConnInfo +} + +func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target interface{}) error { + rangeScanner := (target).(RangeScanner) + + if src == nil { + return rangeScanner.ScanNull() + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + if utr.LowerType == Empty { + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) + } + + lowerTarget, upperTarget := rangeScanner.ScanBounds() + + if utr.LowerType == Inclusive || utr.LowerType == Exclusive { + lowerPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, lowerTarget) + if lowerPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", lowerTarget) + } + + err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + } + } + + if utr.UpperType == Inclusive || utr.UpperType == Exclusive { + upperPlan := plan.ci.PlanScan(plan.rc.ElementDataType.OID, TextFormatCode, upperTarget) + if upperPlan == nil { + return fmt.Errorf("cannot scan into %v from range element", upperTarget) + } + + err = upperPlan.Scan([]byte(utr.Upper), upperTarget) + if err != nil { + return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + } + } + + return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType) +} + +func (c *RangeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + switch format { + case TextFormatCode: + return string(src), nil + case BinaryFormatCode: + buf := make([]byte, len(src)) + copy(buf, src) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code %d", format) + } +} + +func (c *RangeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var r GenericRange + err := c.PlanScan(ci, oid, format, &r, true).Scan(src, &r) + return r, err +} diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go new file mode 100644 index 00000000..b4cc9e8e --- /dev/null +++ b/pgtype/range_codec_test.go @@ -0,0 +1,72 @@ +package pgtype_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" + "github.com/stretchr/testify/require" +) + +func TestRangeCodecTranscode(t *testing.T) { + testutil.RunTranscodeTests(t, "int4range", []testutil.TranscodeTestCase{ + { + pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}, + new(pgtype.Int4range), + isExpectedEq(pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Valid: true}), + }, + { + pgtype.Int4range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }, + new(pgtype.Int4range), + isExpectedEq(pgtype.Int4range{ + LowerType: pgtype.Inclusive, + Lower: pgtype.Int4{Int: 1, Valid: true}, + Upper: pgtype.Int4{Int: 5, Valid: true}, + UpperType: pgtype.Exclusive, Valid: true, + }), + }, + {pgtype.Int4range{}, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + {nil, new(pgtype.Int4range), isExpectedEq(pgtype.Int4range{})}, + }) +} + +func TestRangeCodecDecodeValue(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for _, tt := range []struct { + sql string + expected interface{} + }{ + { + sql: `select '[1,5)'::int4range`, + expected: pgtype.GenericRange{ + Lower: int32(1), + Upper: int32(5), + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(context.Background(), tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/pgtype/range_types.go b/pgtype/range_types.go new file mode 100644 index 00000000..3f1e7d8a --- /dev/null +++ b/pgtype/range_types.go @@ -0,0 +1,218 @@ +// Do not edit. Generated from pgtype/range_types.go.erb +package pgtype + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Int4range) IsNull() bool { + return !r.Valid +} + +func (r Int4range) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Int4range) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int4range) ScanNull() error { + *r = Int4range{} + return nil +} + +func (r *Int4range) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int4range) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Int8range) IsNull() bool { + return !r.Valid +} + +func (r Int8range) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Int8range) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int8range) ScanNull() error { + *r = Int8range{} + return nil +} + +func (r *Int8range) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Int8range) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Numrange) IsNull() bool { + return !r.Valid +} + +func (r Numrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Numrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Numrange) ScanNull() error { + *r = Numrange{} + return nil +} + +func (r *Numrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Numrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Tsrange) IsNull() bool { + return !r.Valid +} + +func (r Tsrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Tsrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tsrange) ScanNull() error { + *r = Tsrange{} + return nil +} + +func (r *Tsrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tsrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Tstzrange) IsNull() bool { + return !r.Valid +} + +func (r Tstzrange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Tstzrange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tstzrange) ScanNull() error { + *r = Tstzrange{} + return nil +} + +func (r *Tstzrange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Tstzrange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r Daterange) IsNull() bool { + return !r.Valid +} + +func (r Daterange) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r Daterange) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Daterange) ScanNull() error { + *r = Daterange{} + return nil +} + +func (r *Daterange) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *Daterange) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} diff --git a/pgtype/range_types.go.erb b/pgtype/range_types.go.erb new file mode 100644 index 00000000..11b12822 --- /dev/null +++ b/pgtype/range_types.go.erb @@ -0,0 +1,49 @@ +package pgtype + +<% + [ + ["Int4range", "Int4"], + ["Int8range", "Int8"], + ["Numrange", "Numeric"], + ["Tsrange", "Timestamp"], + ["Tstzrange", "Timestamptz"], + ["Daterange", "Date"] + ].each do |range_type, element_type| +%> +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Valid bool +} + +func (r <%= range_type %>) IsNull() bool { + return !r.Valid +} + +func (r <%= range_type %>) BoundTypes() (lower, upper BoundType) { + return r.LowerType, r.UpperType +} + +func (r <%= range_type %>) Bounds() (lower, upper interface{}) { + return &r.Lower, &r.Upper +} + +func (r *<%= range_type %>) ScanNull() error { + *r = <%= range_type %>{} + return nil +} + +func (r *<%= range_type %>) ScanBounds() (lowerTarget, upperTarget interface{}) { + return &r.Lower, &r.Upper +} + +func (r *<%= range_type %>) SetBoundTypes(lower, upper BoundType) error { + r.LowerType = lower + r.UpperType = upper + r.Valid = true + return nil +} + +<% end %>