diff --git a/messages.go b/messages.go index db0258de..7e5c3b54 100644 --- a/messages.go +++ b/messages.go @@ -53,7 +53,7 @@ func (s *startupMessage) Bytes() (buf []byte) { return buf } -type Oid int32 +type Oid uint32 type FieldDescription struct { Name string diff --git a/msg_reader.go b/msg_reader.go index c8869bdd..2bcd2d51 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -137,6 +137,34 @@ func (r *msgReader) readInt32() int32 { return n } +func (r *msgReader) readUint32() uint32 { + if r.err != nil { + return 0 + } + + r.msgBytesRemaining -= 4 + if r.msgBytesRemaining < 0 { + r.fatal(errors.New("read past end of message")) + return 0 + } + + b, err := r.reader.Peek(4) + if err != nil { + r.fatal(err) + return 0 + } + + n := uint32(binary.BigEndian.Uint32(b)) + + r.reader.Discard(4) + + if r.shouldLog(LogLevelTrace) { + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + } + + return n +} + func (r *msgReader) readInt64() int64 { if r.err != nil { return 0 diff --git a/value_reader.go b/value_reader.go index 4936b887..6e552ea8 100644 --- a/value_reader.go +++ b/value_reader.go @@ -74,6 +74,20 @@ func (r *ValueReader) ReadInt32() int32 { return r.mr.readInt32() } +func (r *ValueReader) ReadUint32() uint32 { + if r.err != nil { + return 0 + } + + r.valueBytesRemaining -= 4 + if r.valueBytesRemaining < 0 { + r.Fatal(errors.New("read past end of value")) + return 0 + } + + return r.mr.readUint32() +} + func (r *ValueReader) ReadInt64() int64 { if r.err != nil { return 0 @@ -89,7 +103,7 @@ func (r *ValueReader) ReadInt64() int64 { } func (r *ValueReader) ReadOid() Oid { - return Oid(r.ReadInt32()) + return Oid(r.ReadUint32()) } // ReadString reads count bytes and returns as string diff --git a/values_test.go b/values_test.go index 7a690055..3e650b61 100644 --- a/values_test.go +++ b/values_test.go @@ -551,6 +551,39 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } +func TestOid(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value pgx.Oid + }{ + {"select $1::oid", 0}, + {"select $1::oid", 1}, + {"select $1::oid", 4294967295}, + } + + for i, tt := range tests { + expected := tt.value + var actual pgx.Oid + + err := conn.QueryRow(tt.sql, expected).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, expected) + continue + } + + if actual != expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + func TestNullX(t *testing.T) { t.Parallel()