2
0
Files
pgx/values.go
T
Jack Christensen 19c6689752 Add pgtype.Record and prerequisite restructuring
Because reading a record type requires the decoder to be able to look up the
oid-to-type mapping, and types such as hstore have oids that are not fixed between
different PostgreSQL servers, it was necessary to restructure the pgtype system
so all encoders and decoders take a *ConnInfo that includes oid/name/type
information.
2017-03-18 12:01:16 -05:00

451 lines
12 KiB
Go

package pgx
import (
"bytes"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"github.com/jackc/pgx/pgtype"
)
// PostgreSQL oids for common types, listed in ascending numeric order.
const (
	BoolOid             = 16
	ByteaOid            = 17
	CharOid             = 18
	NameOid             = 19
	Int8Oid             = 20
	Int2Oid             = 21
	Int4Oid             = 23
	TextOid             = 25
	OidOid              = 26
	TidOid              = 27
	XidOid              = 28
	CidOid              = 29
	JsonOid             = 114
	CidrOid             = 650
	CidrArrayOid        = 651
	Float4Oid           = 700
	Float8Oid           = 701
	UnknownOid          = 705
	InetOid             = 869
	BoolArrayOid        = 1000
	ByteaArrayOid       = 1001
	Int2ArrayOid        = 1005
	Int4ArrayOid        = 1007
	TextArrayOid        = 1009
	VarcharArrayOid     = 1015
	Int8ArrayOid        = 1016
	Float4ArrayOid      = 1021
	Float8ArrayOid      = 1022
	AclitemOid          = 1033
	AclitemArrayOid     = 1034
	InetArrayOid        = 1041
	VarcharOid          = 1043
	DateOid             = 1082
	TimestampOid        = 1114
	TimestampArrayOid   = 1115
	DateArrayOid        = 1182
	TimestampTzOid      = 1184
	TimestampTzArrayOid = 1185
	RecordOid           = 2249
	UuidOid             = 2950
	JsonbOid            = 3802
)
// PostgreSQL format codes. The numeric values are kept explicit rather than
// generated because they are compared against codes carried in field
// descriptions.
const (
	// TextFormatCode identifies the text format.
	TextFormatCode = 0

	// BinaryFormatCode identifies the binary format.
	BinaryFormatCode = 1
)
// Limits of the platform-sized integer types, derived without assuming a
// particular word size.
const (
	// maxUint has every bit set: the largest uint.
	maxUint = ^uint(0)
	// maxInt drops the top bit of maxUint: the largest int.
	maxInt = int(maxUint >> 1)
	// minInt is the smallest int in two's complement.
	minInt = -maxInt - 1
)
var (
	// DefaultTypeFormats maps a type name to the format (text or binary) that
	// is requested for it by default. Ideally the Scanner would pick the
	// format of the values it receives, but the query has already run by the
	// time Scan is called, so Scan cannot influence it. Types that must always
	// come back in binary therefore have to be registered here.
	DefaultTypeFormats map[string]int16

	// internalNativeGoTypeFormats maps an oid to the format used when encoding
	// native Go values, i.e. values not handled through the Encoder interface.
	internalNativeGoTypeFormats map[pgtype.Oid]int16
)
// init populates DefaultTypeFormats and internalNativeGoTypeFormats. Nearly
// every entry is binary, so the tables are built from plain lists and the few
// text-only types are listed separately.
func init() {
	binaryTypeNames := []string{
		"_bool", "_bytea", "_cidr", "_float4", "_float8", "_inet",
		"_int2", "_int4", "_int8", "_text", "_timestamp", "_timestamptz",
		"_varchar",
		"bool", "bytea", "char", "cid", "cidr", "date", "float4", "float8",
		"inet", "int2", "int4", "int8", "oid", "record", "tid", "timestamp",
		"timestamptz", "xid",
	}
	// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin)
	textTypeNames := []string{"_aclitem", "aclitem"}

	DefaultTypeFormats = make(map[string]int16, len(binaryTypeNames)+len(textTypeNames))
	for _, name := range binaryTypeNames {
		DefaultTypeFormats[name] = BinaryFormatCode
	}
	for _, name := range textTypeNames {
		DefaultTypeFormats[name] = TextFormatCode
	}

	binaryOids := []pgtype.Oid{
		BoolArrayOid, BoolOid,
		ByteaArrayOid, ByteaOid,
		CidrArrayOid, CidrOid,
		DateOid,
		Float4ArrayOid, Float4Oid,
		Float8ArrayOid, Float8Oid,
		InetArrayOid, InetOid,
		Int2ArrayOid, Int2Oid,
		Int4ArrayOid, Int4Oid,
		Int8ArrayOid, Int8Oid,
		JsonbOid, JsonOid,
		OidOid,
		RecordOid,
		TextArrayOid,
		TimestampArrayOid, TimestampOid,
		TimestampTzArrayOid, TimestampTzOid,
		VarcharArrayOid,
	}

	internalNativeGoTypeFormats = make(map[pgtype.Oid]int16, len(binaryOids))
	for _, oid := range binaryOids {
		internalNativeGoTypeFormats[oid] = BinaryFormatCode
	}
}
// SerializationError reports a failure to encode or decode a value. The
// string itself is the error message.
type SerializationError string

// SerializationError is used as an error value; keep that checked at compile time.
var _ error = SerializationError("")

// Error returns the message carried by the error.
func (se SerializationError) Error() string {
	msg := string(se)
	return msg
}
// Encode encodes arg into wbuf as the type oid. This allows implementations
// of the Encoder interface to delegate the actual work of encoding to the
// built-in functionality.
//
// The value is written as a 32-bit length followed by that many bytes; a
// length of -1 denotes NULL. Strategies are tried in this order:
//
//  1. an untyped nil is written as NULL;
//  2. pgtype.BinaryEncoder, then pgtype.TextEncoder, encode themselves;
//  3. a driver.Valuer is asked for its value, which is then encoded in turn;
//  4. string and []byte are written as-is;
//  5. a nil pointer is written as NULL, any other pointer is dereferenced
//     and encoded in turn;
//  6. if the connection has a data type registered for oid, arg is assigned
//     to that type's value and encoded in binary;
//  7. a named type whose underlying kind is an integer or string is converted
//     to the corresponding basic type and encoded in turn.
//
// A SerializationError is returned when none of the strategies apply, or when
// the value registered for oid does not implement pgtype.BinaryEncoder.
func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
	if arg == nil {
		wbuf.WriteInt32(-1)
		return nil
	}

	switch arg := arg.(type) {
	case pgtype.BinaryEncoder:
		buf := &bytes.Buffer{}
		null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncodedBuf(wbuf, null, buf)
		return nil
	case pgtype.TextEncoder:
		buf := &bytes.Buffer{}
		null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncodedBuf(wbuf, null, buf)
		return nil
	case driver.Valuer:
		v, err := arg.Value()
		if err != nil {
			return err
		}
		return Encode(wbuf, oid, v)
	case string:
		return encodeString(wbuf, oid, arg)
	case []byte:
		return encodeByteSlice(wbuf, oid, arg)
	}

	refVal := reflect.ValueOf(arg)

	if refVal.Kind() == reflect.Ptr {
		if refVal.IsNil() {
			wbuf.WriteInt32(-1)
			return nil
		}
		// Strip one level of indirection and start over with the pointee.
		arg = refVal.Elem().Interface()
		return Encode(wbuf, oid, arg)
	}

	if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
		// NOTE(review): dt.Value looks like an instance owned by ConnInfo, so
		// Set would overwrite whatever it held before - confirm.
		value := dt.Value
		if err := value.Set(arg); err != nil {
			return err
		}

		// Not every registered value is guaranteed to support the binary
		// format; report that as an error instead of panicking on a failed
		// type assertion.
		encoder, ok := value.(pgtype.BinaryEncoder)
		if !ok {
			return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T does not implement pgtype.BinaryEncoder", arg, oid, value))
		}

		buf := &bytes.Buffer{}
		null, err := encoder.EncodeBinary(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncodedBuf(wbuf, null, buf)
		return nil
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return Encode(wbuf, oid, strippedArg)
	}

	return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}

// writeEncodedBuf writes an already encoded value to wbuf: a length of -1 when
// null is true, otherwise the length of buf followed by its contents.
func writeEncodedBuf(wbuf *WriteBuf, null bool, buf *bytes.Buffer) {
	if null {
		wbuf.WriteInt32(-1)
		return
	}
	wbuf.WriteInt32(int32(buf.Len()))
	wbuf.WriteBytes(buf.Bytes())
}
// stripNamedType converts a value whose kind is an integer (signed or
// unsigned, any width) or a string into the matching basic Go type.
//
// The first result is the converted value, or nil for any other kind. The
// second result is true only when the conversion changed the type, i.e. val
// was of a named type rather than already being the basic type.
func stripNamedType(val *reflect.Value) (interface{}, bool) {
	var basic interface{}

	switch val.Kind() {
	case reflect.Int:
		basic = int(val.Int())
	case reflect.Int8:
		basic = int8(val.Int())
	case reflect.Int16:
		basic = int16(val.Int())
	case reflect.Int32:
		basic = int32(val.Int())
	case reflect.Int64:
		basic = int64(val.Int())
	case reflect.Uint:
		basic = uint(val.Uint())
	case reflect.Uint8:
		basic = uint8(val.Uint())
	case reflect.Uint16:
		basic = uint16(val.Uint())
	case reflect.Uint32:
		basic = uint32(val.Uint())
	case reflect.Uint64:
		basic = uint64(val.Uint())
	case reflect.String:
		basic = val.String()
	default:
		// Kinds without a handled basic counterpart are left alone.
		return nil, false
	}

	// If the types already match there was nothing to strip.
	return basic, reflect.TypeOf(basic) != val.Type()
}
// Decode decodes from vr into d. d must be a pointer. This allows
// implementations of the Decoder interface to delegate the actual work of
// decoding to the built-in functionality.
//
// Supported destinations are *string, pointers to types whose kind is string,
// and pointers to pointers of a supported destination. Anything else yields a
// "Scan cannot decode into" error.
func Decode(vr *ValueReader, d interface{}) error {
	if s, ok := d.(*string); ok {
		*s = decodeText(vr)
		return nil
	}

	ptr := reflect.ValueOf(d)
	if ptr.Kind() != reflect.Ptr {
		return fmt.Errorf("Scan cannot decode into %T", d)
	}

	target := ptr.Elem()
	switch target.Kind() {
	case reflect.Ptr:
		// d is a pointer to a pointer: handle NULL here, otherwise strip one
		// level of indirection and try again.
		if vr.Len() == -1 {
			// -1 is a null value; make sure the destination pointer is nil.
			if !target.IsNil() {
				target.Set(reflect.Zero(target.Type()))
			}
			return nil
		}
		if target.IsNil() {
			// Allocate the destination before decoding into it.
			target.Set(reflect.New(target.Type().Elem()))
		}
		return Decode(vr, target.Interface())
	case reflect.String:
		target.SetString(decodeText(vr))
		return nil
	}

	return fmt.Errorf("Scan cannot decode into %T", d)
}
// decodeText reads the current value of vr as a string. A NULL value or a
// value in binary format is reported through vr.Fatal and "" is returned.
func decodeText(vr *ValueReader) string {
	switch {
	case vr.Len() == -1:
		vr.Fatal(ProtocolError("Cannot decode null into string"))
	case vr.Type().FormatCode == BinaryFormatCode:
		vr.Fatal(ProtocolError("cannot decode binary value into string"))
	default:
		return vr.ReadString(vr.Len())
	}
	return ""
}
// decodeTextAllowBinary reads the current value of vr as a string without
// checking its format code. A NULL value is reported through vr.Fatal and ""
// is returned.
func decodeTextAllowBinary(vr *ValueReader) string {
	if vr.Len() != -1 {
		return vr.ReadString(vr.Len())
	}

	vr.Fatal(ProtocolError("Cannot decode null into string"))
	return ""
}
// encodeString writes value to w as a 32-bit byte length followed by the
// bytes of the string. oid is not consulted and the result is always nil.
func encodeString(w *WriteBuf, oid pgtype.Oid, value string) error {
	data := []byte(value)
	w.WriteInt32(int32(len(data)))
	w.WriteBytes(data)
	return nil
}
// decodeBytea reads the current value of vr as a byte slice. NULL yields nil.
// A value that is not of type bytea, or that is not in binary format, is
// reported through vr.Fatal and nil is returned.
func decodeBytea(vr *ValueReader) []byte {
	switch {
	case vr.Len() == -1:
		return nil
	case vr.Type().DataType != ByteaOid:
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType)))
		return nil
	case vr.Type().FormatCode != BinaryFormatCode:
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	return vr.ReadBytes(vr.Len())
}
// encodeByteSlice writes value to w as a 32-bit length followed by the bytes
// themselves. oid is not consulted and the result is always nil.
func encodeByteSlice(w *WriteBuf, oid pgtype.Oid, value []byte) error {
	size := int32(len(value))
	w.WriteInt32(size)
	w.WriteBytes(value)
	return nil
}
// decodeJSON unmarshals the current json value of vr into d.
//
// A NULL value leaves d untouched and returns nil. If the value is not of
// type json, or if it cannot be unmarshalled into d, the error is recorded on
// vr with Fatal and also returned.
func decodeJSON(vr *ValueReader, d interface{}) error {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != JsonOid {
		// Stop here: the payload is not json, so it must not be read and
		// handed to the JSON decoder.
		err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))
		vr.Fatal(err)
		return err
	}

	data := vr.ReadBytes(vr.Len())
	err := json.Unmarshal(data, d)
	if err != nil {
		vr.Fatal(err)
	}
	return err
}
// encodeJSON marshals value with encoding/json and writes the result to w as
// a 32-bit length followed by the JSON text.
//
// An error is returned, and nothing is written, when oid is not JsonOid or
// when value cannot be marshalled; in the latter case the message includes
// the cause reported by json.Marshal.
func encodeJSON(w *WriteBuf, oid pgtype.Oid, value interface{}) error {
	if oid != JsonOid {
		return fmt.Errorf("cannot encode JSON into oid %v", oid)
	}

	s, err := json.Marshal(value)
	if err != nil {
		// Keep the underlying reason; the type alone rarely explains the failure.
		return fmt.Errorf("Failed to encode json from type: %T: %v", value, err)
	}

	w.WriteInt32(int32(len(s)))
	w.WriteBytes(s)
	return nil
}
// decodeJSONB unmarshals the current jsonb value of vr into d.
//
// A NULL value leaves d untouched and returns nil. In binary format the
// payload must start with the format byte 1, which is stripped before the
// remainder is unmarshalled as JSON. A wrong oid, a missing or unknown format
// byte, or an unmarshalling failure is recorded on vr with Fatal and also
// returned.
func decodeJSONB(vr *ValueReader, d interface{}) error {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != JsonbOid {
		err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))
		vr.Fatal(err)
		return err
	}

	data := vr.ReadBytes(vr.Len())
	if vr.Type().FormatCode == BinaryFormatCode {
		// Guard the index below: an empty payload has no format byte to inspect.
		if len(data) == 0 {
			err := ProtocolError("Cannot decode empty binary jsonb value: missing format byte")
			vr.Fatal(err)
			return err
		}
		if data[0] != 1 {
			err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", data[0]))
			vr.Fatal(err)
			return err
		}
		data = data[1:]
	}

	err := json.Unmarshal(data, d)
	if err != nil {
		vr.Fatal(err)
	}
	return err
}
// encodeJSONB marshals value with encoding/json and writes it to w as a jsonb
// value: a 32-bit length, the format byte 1, then the JSON text. The length
// covers the format byte as well as the text.
//
// An error is returned, and nothing is written, when oid is not JsonbOid or
// when value cannot be marshalled; in the latter case the message includes
// the cause reported by json.Marshal.
func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error {
	if oid != JsonbOid {
		return fmt.Errorf("cannot encode JSON into oid %v", oid)
	}

	s, err := json.Marshal(value)
	if err != nil {
		// Keep the underlying reason; the type alone rarely explains the failure.
		return fmt.Errorf("Failed to encode json from type: %T: %v", value, err)
	}

	w.WriteInt32(int32(len(s) + 1))
	w.WriteByte(1) // JSONB format header
	w.WriteBytes(s)
	return nil
}