diff --git a/messages.go b/messages.go index fec34dbb..8eb4b8ef 100644 --- a/messages.go +++ b/messages.go @@ -136,6 +136,12 @@ func (wb *WriteBuf) WriteInt16(n int16) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint16(n uint16) { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + wb.buf = append(wb.buf, b...) +} + func (wb *WriteBuf) WriteInt32(n int32) { b := make([]byte, 4) binary.BigEndian.PutUint32(b, uint32(n)) diff --git a/msg_reader.go b/msg_reader.go index 2bcd2d51..59617b73 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -137,6 +137,34 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint16() uint16 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 2 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(2) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint16(binary.BigEndian.Uint16(b)) + + r.reader.Discard(2) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readUint32() uint32 { if r.err != nil { return 0 diff --git a/value_reader.go b/value_reader.go index 6e552ea8..a4897543 100644 --- a/value_reader.go +++ b/value_reader.go @@ -60,6 +60,20 @@ func (r *ValueReader) ReadInt16() int16 { return r.mr.readInt16() } +func (r *ValueReader) ReadUint16() uint16 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 2 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint16() +} + func (r *ValueReader) ReadInt32() int32 { if r.err != nil { return 0 diff --git a/values.go b/values.go index 4d542bf1..db96a007 100644 --- a/values.go +++ b/values.go @@ -8,6 +8,7 @@ import ( "math" "net" "reflect" + "regexp" "strconv" "strings" "time" @@ -22,6 +23,7 @@ const ( Int4Oid = 23 TextOid = 25 OidOid = 26 + TidOid = 27 XidOid = 28 CidOid = 29 JsonOid = 114 @@ -96,6 +98,7 @@ func init() { "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, + "tid": BinaryFormatCode, "xid": BinaryFormatCode, "cid": BinaryFormatCode, "record": BinaryFormatCode, @@ -439,6 +442,61 @@ func (n NullCid) Encode(w *WriteBuf, oid Oid) error { return encodeCid(w, oid, n.Cid) } +// Tid is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type Tid struct { + BlockNumber uint32 + OffsetNumber uint16 +} + +// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the +// Scanner and Encoder interfaces so it may be used both as an argument to +// Query[Row] and a destination for Scan. +// +// If Valid is false then the value is NULL. +type NullTid struct { + Tid Tid + Valid bool // Valid is true if Int32 is not NULL +} + +func (n *NullTid) Scan(vr *ValueReader) error { + if vr.Type().DataType != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) + } + + if vr.Len() == -1 { + n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false + return nil + } + n.Valid = true + n.Tid = decodeTid(vr) + return vr.Err() +} + +func (n NullTid) FormatCode() int16 { return BinaryFormatCode } + +func (n NullTid) Encode(w *WriteBuf, oid Oid) error { + if oid != TidOid { + return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) + } + + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTid(w, oid, n.Tid) +} + // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -933,6 +991,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeOid(vr) case *Xid: *v = decodeXid(vr) + case *Tid: + *v = decodeTid(vr) case *Cid: *v = decodeCid(vr) case *string: @@ -1545,6 +1605,66 @@ func encodeCid(w *WriteBuf, oid Oid, value Cid) error { return nil } +// Note that we do not match negative numbers, because neither the +// BlockNumber nor OffsetNumber of a Tid can be negative. +var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) + +func decodeTid(vr *ValueReader) Tid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Tid")) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + if vr.Type().DataType != TidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + // Unlikely Tid will ever go over the wire as text format, but who knows? + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + + match := tidRegexp.FindStringSubmatch(s) + if match == nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + + blockNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) + } + + offsetNumber, err := strconv.ParseUint(s, 10, 16) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) + } + return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} + case BinaryFormatCode: + if vr.Len() != 6 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } + return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Tid{BlockNumber: 0, OffsetNumber: 0} + } +} + +func encodeTid(w *WriteBuf, oid Oid, value Tid) error { + if oid != TidOid { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) + } + + w.WriteInt32(6) + w.WriteUint32(value.BlockNumber) + w.WriteUint16(value.OffsetNumber) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) diff --git a/values_test.go b/values_test.go index 2325b6f1..cea70b9c 100644 --- a/values_test.go +++ b/values_test.go @@ -596,6 +596,7 @@ func TestNullX(t *testing.T) { i32 pgx.NullInt32 xid pgx.NullXid cid pgx.NullCid + tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -623,6 +624,9 @@ func TestNullX(t *testing.T) { {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, + {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},