Add pgtype.Record and prerequisite restructuring
Because reading a record type requires the decoder to be able to look up oid to type mapping and types such as hstore have types that are not fixed between different PostgreSQL servers it was necessary to restructure the pgtype system so all encoders and decodes take a *ConnInfo that includes oid/name/type information.
This commit is contained in:
@@ -5,9 +5,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
@@ -167,7 +165,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
switch arg := arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := arg.EncodeBinary(buf)
|
||||
null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -180,7 +178,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
return nil
|
||||
case pgtype.TextEncoder:
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := arg.EncodeText(buf)
|
||||
null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -214,14 +212,15 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
return Encode(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok {
|
||||
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
|
||||
value := dt.Value
|
||||
err := value.Set(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := value.(pgtype.BinaryEncoder).EncodeBinary(buf)
|
||||
null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -287,8 +286,6 @@ func Decode(vr *ValueReader, d interface{}) error {
|
||||
switch v := d.(type) {
|
||||
case *string:
|
||||
*v = decodeText(vr)
|
||||
case *[]interface{}:
|
||||
*v = decodeRecord(vr)
|
||||
default:
|
||||
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
@@ -320,232 +317,6 @@ func Decode(vr *ValueReader, d interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBool(vr *ValueReader) bool {
|
||||
if vr.Type().DataType != BoolOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType)))
|
||||
return false
|
||||
}
|
||||
|
||||
var b pgtype.Bool
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = b.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = b.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return false
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return false
|
||||
}
|
||||
|
||||
if b.Status != pgtype.Present {
|
||||
vr.Fatal(fmt.Errorf("Cannot decode null into bool"))
|
||||
return false
|
||||
}
|
||||
|
||||
return b.Bool
|
||||
}
|
||||
|
||||
func decodeInt(vr *ValueReader) int64 {
|
||||
switch vr.Type().DataType {
|
||||
case Int2Oid:
|
||||
return int64(decodeInt2(vr))
|
||||
case Int4Oid:
|
||||
return int64(decodeInt4(vr))
|
||||
case Int8Oid:
|
||||
return int64(decodeInt8(vr))
|
||||
}
|
||||
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
func decodeInt8(vr *ValueReader) int64 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int64"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int8Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
var n pgtype.Int8
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = n.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = n.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if n.Status == pgtype.Null {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return n.Int
|
||||
}
|
||||
|
||||
func decodeInt2(vr *ValueReader) int16 {
|
||||
|
||||
if vr.Type().DataType != Int2Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
var n pgtype.Int2
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = n.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = n.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if n.Status == pgtype.Null {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return n.Int
|
||||
}
|
||||
|
||||
func decodeInt4(vr *ValueReader) int32 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int32"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Int4Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
var n pgtype.Int4
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = n.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = n.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if n.Status == pgtype.Null {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||
return 0
|
||||
}
|
||||
|
||||
return n.Int
|
||||
}
|
||||
|
||||
func decodeFloat4(vr *ValueReader) float32 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into float32"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Float4Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Len() != 4 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
|
||||
return 0
|
||||
}
|
||||
|
||||
i := vr.ReadInt32()
|
||||
return math.Float32frombits(uint32(i))
|
||||
}
|
||||
|
||||
func encodeFloat32(w *WriteBuf, oid pgtype.Oid, value float32) error {
|
||||
switch oid {
|
||||
case Float4Oid:
|
||||
w.WriteInt32(4)
|
||||
w.WriteInt32(int32(math.Float32bits(value)))
|
||||
case Float8Oid:
|
||||
w.WriteInt32(8)
|
||||
w.WriteInt64(int64(math.Float64bits(float64(value))))
|
||||
default:
|
||||
return fmt.Errorf("cannot encode %s into oid %v", "float32", oid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeFloat8(vr *ValueReader) float64 {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into float64"))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().DataType != Float8Oid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Len() != 8 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len())))
|
||||
return 0
|
||||
}
|
||||
|
||||
i := vr.ReadInt64()
|
||||
return math.Float64frombits(uint64(i))
|
||||
}
|
||||
|
||||
func encodeFloat64(w *WriteBuf, oid pgtype.Oid, value float64) error {
|
||||
switch oid {
|
||||
case Float8Oid:
|
||||
w.WriteInt32(8)
|
||||
w.WriteInt64(int64(math.Float64bits(value)))
|
||||
default:
|
||||
return fmt.Errorf("cannot encode %s into oid %v", "float64", oid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeText(vr *ValueReader) string {
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into string"))
|
||||
@@ -677,215 +448,3 @@ func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeDate(vr *ValueReader) time.Time {
|
||||
if vr.Type().DataType != DateOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
var d pgtype.Date
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = d.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = d.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if d.Status == pgtype.Null {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into int16"))
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
return d.Time
|
||||
}
|
||||
|
||||
func encodeTime(w *WriteBuf, oid pgtype.Oid, value time.Time) error {
|
||||
switch oid {
|
||||
case DateOid:
|
||||
var d pgtype.Date
|
||||
err := d.Set(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := d.EncodeBinary(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if null {
|
||||
w.WriteInt32(-1)
|
||||
} else {
|
||||
w.WriteInt32(int32(buf.Len()))
|
||||
w.WriteBytes(buf.Bytes())
|
||||
}
|
||||
return nil
|
||||
|
||||
case TimestampTzOid, TimestampOid:
|
||||
var t pgtype.Timestamptz
|
||||
err := t.Set(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := t.EncodeBinary(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if null {
|
||||
w.WriteInt32(-1)
|
||||
} else {
|
||||
w.WriteInt32(int32(buf.Len()))
|
||||
w.WriteBytes(buf.Bytes())
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid)
|
||||
}
|
||||
}
|
||||
|
||||
const microsecFromUnixEpochToY2K = 946684800 * 1000000
|
||||
|
||||
func decodeTimestampTz(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TimestampTzOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
var t pgtype.Timestamptz
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = t.DecodeText(vr.bytes())
|
||||
case BinaryFormatCode:
|
||||
err = t.DecodeBinary(vr.bytes())
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
if t.Status == pgtype.Null {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
return t.Time
|
||||
}
|
||||
|
||||
func decodeTimestamp(vr *ValueReader) time.Time {
|
||||
var zeroTime time.Time
|
||||
|
||||
if vr.Len() == -1 {
|
||||
vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().DataType != TimestampOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
if vr.Len() != 8 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len())))
|
||||
return zeroTime
|
||||
}
|
||||
|
||||
microsecSinceY2K := vr.ReadInt64()
|
||||
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
|
||||
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
|
||||
}
|
||||
|
||||
func decodeRecord(vr *ValueReader) []interface{} {
|
||||
if vr.Len() == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if vr.Type().DataType != RecordOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
valueCount := vr.ReadInt32()
|
||||
record := make([]interface{}, 0, int(valueCount))
|
||||
|
||||
for i := int32(0); i < valueCount; i++ {
|
||||
fd := FieldDescription{FormatCode: BinaryFormatCode}
|
||||
fieldVR := ValueReader{mr: vr.mr, fd: &fd}
|
||||
fd.DataType = vr.ReadOid()
|
||||
fieldVR.valueBytesRemaining = vr.ReadInt32()
|
||||
vr.valueBytesRemaining -= fieldVR.valueBytesRemaining
|
||||
|
||||
switch fd.DataType {
|
||||
case BoolOid:
|
||||
record = append(record, decodeBool(&fieldVR))
|
||||
case ByteaOid:
|
||||
record = append(record, decodeBytea(&fieldVR))
|
||||
case Int8Oid:
|
||||
record = append(record, decodeInt8(&fieldVR))
|
||||
case Int2Oid:
|
||||
record = append(record, decodeInt2(&fieldVR))
|
||||
case Int4Oid:
|
||||
record = append(record, decodeInt4(&fieldVR))
|
||||
case Float4Oid:
|
||||
record = append(record, decodeFloat4(&fieldVR))
|
||||
case Float8Oid:
|
||||
record = append(record, decodeFloat8(&fieldVR))
|
||||
case DateOid:
|
||||
record = append(record, decodeDate(&fieldVR))
|
||||
case TimestampTzOid:
|
||||
record = append(record, decodeTimestampTz(&fieldVR))
|
||||
case TimestampOid:
|
||||
record = append(record, decodeTimestamp(&fieldVR))
|
||||
case TextOid, VarcharOid, UnknownOid:
|
||||
record = append(record, decodeTextAllowBinary(&fieldVR))
|
||||
default:
|
||||
vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Consume any remaining data
|
||||
if fieldVR.Len() > 0 {
|
||||
fieldVR.ReadBytes(fieldVR.Len())
|
||||
}
|
||||
|
||||
if fieldVR.Err() != nil {
|
||||
vr.Fatal(fieldVR.Err())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user