ed2b3b3b49
But not to NullTime because of text vs binary encoding difficulties. You really should never use timestamp anyway.
1009 lines
26 KiB
Go
1009 lines
26 KiB
Go
package pgx
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"fmt"
|
|
"math"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unsafe"
|
|
)
|
|
|
|
const (
|
|
BoolOid = 16
|
|
ByteaOid = 17
|
|
Int8Oid = 20
|
|
Int2Oid = 21
|
|
Int4Oid = 23
|
|
TextOid = 25
|
|
Float4Oid = 700
|
|
Float8Oid = 701
|
|
VarcharOid = 1043
|
|
DateOid = 1082
|
|
TimestampOid = 1114
|
|
TimestampTzOid = 1184
|
|
)
|
|
|
|
const (
|
|
TextFormatCode = 0
|
|
BinaryFormatCode = 1
|
|
)
|
|
|
|
const (
|
|
NullText = iota
|
|
SafeText = iota
|
|
UnsafeText = iota
|
|
)
|
|
|
|
type SerializationError string
|
|
|
|
func (e SerializationError) Error() string {
|
|
return string(e)
|
|
}
|
|
|
|
// Scanner is an interface used to decode values from the PostgreSQL server.
|
|
type Scanner interface {
|
|
// Scan MUST check fd's DataType and FormatCode before decoding. It should
|
|
// not assume that it was called on the type of value.
|
|
Scan(rows *Rows, fd *FieldDescription, size int32) error
|
|
}
|
|
|
|
// TextEncoder is an interface used to encode values in text format for
|
|
// transmission to the PostgreSQL server. It is used by unprepared
|
|
// queries and for prepared queries when the type does not implement
|
|
// BinaryEncoder
|
|
type TextEncoder interface {
|
|
// EncodeText returns the value encoded into a string. status must be
|
|
// NullText if the value is NULL, UnsafeText if the value should be quoted
|
|
// and escaped, or SafeText if the value should not be quoted.
|
|
EncodeText() (val string, status byte, err error)
|
|
}
|
|
|
|
// BinaryEncoder is an interface used to encode values in binary format for
|
|
// transmission to the PostgreSQL server. It is used by prepared queries.
|
|
type BinaryEncoder interface {
|
|
// EncodeBinary writes the binary value to w.
|
|
//
|
|
// EncodeBinary MUST check fd.DataType to see if the parameter data type is
|
|
// compatible. If this is not done, the PostgreSQL server may detect the
|
|
// error if the expected data size or format of the encoded data does not
|
|
// match. But if the encoded data is a valid representation of the data type
|
|
// PostgreSQL expects such as date and int4, incorrect data may be stored.
|
|
EncodeBinary(w *WriteBuf, fd *FieldDescription) error
|
|
}
|
|
|
|
// NullFloat32 represents an float4 that may be null.
|
|
// NullFloat32 implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullFloat32 struct {
|
|
Float32 float32
|
|
Valid bool // Valid is true if Float32 is not NULL
|
|
}
|
|
|
|
func (n *NullFloat32) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Float32, n.Valid = 0, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Float32 = decodeFloat4(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullFloat32) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatFloat(float64(n.Float32), 'f', -1, 32), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullFloat32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != Float4Oid {
|
|
return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeFloat4(w, n.Float32)
|
|
}
|
|
|
|
// NullFloat64 represents an float8 that may be null.
|
|
// NullFloat64 implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullFloat64 struct {
|
|
Float64 float64
|
|
Valid bool // Valid is true if Float64 is not NULL
|
|
}
|
|
|
|
func (n *NullFloat64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Float64, n.Valid = 0, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Float64 = decodeFloat8(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullFloat64) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatFloat(n.Float64, 'f', -1, 64), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullFloat64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != Float8Oid {
|
|
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeFloat8(w, n.Float64)
|
|
}
|
|
|
|
// NullString represents an string that may be null. NullString implements
|
|
// the Scanner and TextEncoder interfaces so it may be used both as an
|
|
// argument to Query[Row] and a destination for Scan for prepared and
|
|
// unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullString struct {
|
|
String string
|
|
Valid bool // Valid is true if Int64 is not NULL
|
|
}
|
|
|
|
func (s *NullString) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
s.String, s.Valid = "", false
|
|
return nil
|
|
}
|
|
s.Valid = true
|
|
s.String = decodeText(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (s NullString) EncodeText() (string, byte, error) {
|
|
if s.Valid {
|
|
return s.String, UnsafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
// NullInt16 represents an smallint that may be null.
|
|
// NullInt16 implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullInt16 struct {
|
|
Int16 int16
|
|
Valid bool // Valid is true if Int16 is not NULL
|
|
}
|
|
|
|
func (n *NullInt16) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Int16, n.Valid = 0, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Int16 = decodeInt2(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullInt16) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatInt(int64(n.Int16), 10), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullInt16) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != Int2Oid {
|
|
return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeInt2(w, n.Int16)
|
|
}
|
|
|
|
// NullInt32 represents an integer that may be null.
|
|
// NullInt32 implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullInt32 struct {
|
|
Int32 int32
|
|
Valid bool // Valid is true if Int64 is not NULL
|
|
}
|
|
|
|
func (n *NullInt32) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Int32, n.Valid = 0, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Int32 = decodeInt4(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullInt32) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatInt(int64(n.Int32), 10), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullInt32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != Int4Oid {
|
|
return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeInt4(w, n.Int32)
|
|
}
|
|
|
|
// NullInt64 represents an bigint that may be null.
|
|
// NullInt64 implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullInt64 struct {
|
|
Int64 int64
|
|
Valid bool // Valid is true if Int64 is not NULL
|
|
}
|
|
|
|
func (n *NullInt64) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Int64, n.Valid = 0, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Int64 = decodeInt8(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullInt64) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatInt(int64(n.Int64), 10), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullInt64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != Int8Oid {
|
|
return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeInt8(w, n.Int64)
|
|
}
|
|
|
|
// NullBool represents an bool that may be null.
|
|
// NullBool implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullBool struct {
|
|
Bool bool
|
|
Valid bool // Valid is true if Bool is not NULL
|
|
}
|
|
|
|
func (n *NullBool) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if size == -1 {
|
|
n.Bool, n.Valid = false, false
|
|
return nil
|
|
}
|
|
n.Valid = true
|
|
n.Bool = decodeBool(rows, fd, size)
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullBool) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return strconv.FormatBool(n.Bool), SafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullBool) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != BoolOid {
|
|
return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeBool(w, n.Bool)
|
|
}
|
|
|
|
// NullTime represents an bigint that may be null.
|
|
// NullTime implements the Scanner, TextEncoder, and BinaryEncoder interfaces
|
|
// so it may be used both as an argument to Query[Row] and a destination for
|
|
// Scan for prepared and unprepared queries.
|
|
//
|
|
// If Valid is false then the value is NULL.
|
|
type NullTime struct {
|
|
Time time.Time
|
|
Valid bool // Valid is true if Time is not NULL
|
|
}
|
|
|
|
func (n *NullTime) Scan(rows *Rows, fd *FieldDescription, size int32) error {
|
|
if fd.DataType != TimestampTzOid {
|
|
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if size == -1 {
|
|
n.Time, n.Valid = time.Time{}, false
|
|
return nil
|
|
}
|
|
|
|
n.Valid = true
|
|
n.Time = decodeTimestampTz(rows, fd, size)
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
func (n NullTime) EncodeText() (string, byte, error) {
|
|
if n.Valid {
|
|
return n.Time.Format("2006-01-02 15:04:05.999999 -0700"), UnsafeText, nil
|
|
} else {
|
|
return "", NullText, nil
|
|
}
|
|
}
|
|
|
|
func (n NullTime) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
|
if fd.DataType != TimestampTzOid {
|
|
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType))
|
|
}
|
|
|
|
if !n.Valid {
|
|
w.WriteInt32(-1)
|
|
return nil
|
|
}
|
|
|
|
return encodeTimestampTz(w, n.Time)
|
|
}
|
|
|
|
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
|
|
|
// QuoteString escapes and quotes a string making it safe for interpolation
|
|
// into an SQL string.
|
|
func QuoteString(input string) (output string) {
|
|
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
|
return
|
|
}
|
|
|
|
// QuoteIdentifier escapes and quotes an identifier making it safe for
|
|
// interpolation into an SQL string
|
|
func QuoteIdentifier(input string) (output string) {
|
|
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
|
return
|
|
}
|
|
|
|
// SanitizeSql substitutely args positionaly into sql. Placeholder values are
|
|
// $ prefixed integers like $1, $2, $3, etc. args are sanitized and quoted as
|
|
// appropriate.
|
|
func SanitizeSql(sql string, args ...interface{}) (output string, err error) {
|
|
replacer := func(match string) (replacement string) {
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
|
if int(n-1) >= len(args) {
|
|
err = fmt.Errorf("Cannot interpolate %v, only %d arguments provided", match, len(args))
|
|
return
|
|
}
|
|
|
|
var s string
|
|
s, err = sanitizeArg(args[n-1])
|
|
return s
|
|
}
|
|
|
|
output = literalPattern.ReplaceAllStringFunc(sql, replacer)
|
|
return
|
|
}
|
|
|
|
func sanitizeArg(arg interface{}) (string, error) {
|
|
switch arg := arg.(type) {
|
|
case string:
|
|
return QuoteString(arg), nil
|
|
case int:
|
|
return strconv.FormatInt(int64(arg), 10), nil
|
|
case int8:
|
|
return strconv.FormatInt(int64(arg), 10), nil
|
|
case int16:
|
|
return strconv.FormatInt(int64(arg), 10), nil
|
|
case int32:
|
|
return strconv.FormatInt(int64(arg), 10), nil
|
|
case int64:
|
|
return strconv.FormatInt(int64(arg), 10), nil
|
|
case time.Time:
|
|
return QuoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
|
|
case uint:
|
|
return strconv.FormatUint(uint64(arg), 10), nil
|
|
case uint8:
|
|
return strconv.FormatUint(uint64(arg), 10), nil
|
|
case uint16:
|
|
return strconv.FormatUint(uint64(arg), 10), nil
|
|
case uint32:
|
|
return strconv.FormatUint(uint64(arg), 10), nil
|
|
case uint64:
|
|
return strconv.FormatUint(uint64(arg), 10), nil
|
|
case float32:
|
|
return strconv.FormatFloat(float64(arg), 'f', -1, 32), nil
|
|
case float64:
|
|
return strconv.FormatFloat(arg, 'f', -1, 64), nil
|
|
case bool:
|
|
return strconv.FormatBool(arg), nil
|
|
case []byte:
|
|
return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
|
|
case nil:
|
|
return "null", nil
|
|
case TextEncoder:
|
|
s, status, err := arg.EncodeText()
|
|
switch status {
|
|
case NullText:
|
|
return "null", err
|
|
case UnsafeText:
|
|
return QuoteString(s), err
|
|
case SafeText:
|
|
return s, err
|
|
default:
|
|
return "", SerializationError("Received invalid status from EncodeText")
|
|
}
|
|
default:
|
|
return "", SerializationError(fmt.Sprintf("%T is not a core type and it does not implement TextEncoder", arg))
|
|
}
|
|
}
|
|
|
|
func decodeBool(rows *Rows, fd *FieldDescription, size int32) bool {
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
switch s {
|
|
case "t":
|
|
return true
|
|
case "f":
|
|
return false
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s)))
|
|
return false
|
|
}
|
|
case BinaryFormatCode:
|
|
if size != 1 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size)))
|
|
return false
|
|
}
|
|
b := rows.mr.ReadByte()
|
|
return b != 0
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return false
|
|
}
|
|
}
|
|
|
|
func encodeBool(w *WriteBuf, value interface{}) error {
|
|
v, ok := value.(bool)
|
|
if !ok {
|
|
return fmt.Errorf("Expected bool, received %T", value)
|
|
}
|
|
|
|
w.WriteInt32(1)
|
|
|
|
var n byte
|
|
if v {
|
|
n = 1
|
|
}
|
|
|
|
w.WriteByte(n)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeInt8(rows *Rows, fd *FieldDescription, size int32) int64 {
|
|
if fd.DataType != Int8Oid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, fd.DataType)))
|
|
return 0
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
n, err := strconv.ParseInt(s, 10, 64)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s)))
|
|
return 0
|
|
}
|
|
return n
|
|
case BinaryFormatCode:
|
|
if size != 8 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size)))
|
|
return 0
|
|
}
|
|
return rows.mr.ReadInt64()
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func encodeInt8(w *WriteBuf, value interface{}) error {
|
|
var v int64
|
|
switch value := value.(type) {
|
|
case int8:
|
|
v = int64(value)
|
|
case uint8:
|
|
v = int64(value)
|
|
case int16:
|
|
v = int64(value)
|
|
case uint16:
|
|
v = int64(value)
|
|
case int32:
|
|
v = int64(value)
|
|
case uint32:
|
|
v = int64(value)
|
|
case int64:
|
|
v = int64(value)
|
|
case uint64:
|
|
if value > math.MaxInt64 {
|
|
return fmt.Errorf("uint64 %d is larger than max int64 %d", value, math.MaxInt64)
|
|
}
|
|
v = int64(value)
|
|
case int:
|
|
v = int64(value)
|
|
default:
|
|
return fmt.Errorf("Expected integer representable in int64, received %T %v", value, value)
|
|
}
|
|
|
|
w.WriteInt32(8)
|
|
w.WriteInt64(v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeInt2(rows *Rows, fd *FieldDescription, size int32) int16 {
|
|
if fd.DataType != Int2Oid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, fd.DataType)))
|
|
return 0
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
n, err := strconv.ParseInt(s, 10, 16)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s)))
|
|
return 0
|
|
}
|
|
return int16(n)
|
|
case BinaryFormatCode:
|
|
if size != 2 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size)))
|
|
return 0
|
|
}
|
|
return rows.mr.ReadInt16()
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func encodeInt2(w *WriteBuf, value interface{}) error {
|
|
var v int16
|
|
switch value := value.(type) {
|
|
case int8:
|
|
v = int16(value)
|
|
case uint8:
|
|
v = int16(value)
|
|
case int16:
|
|
v = int16(value)
|
|
case uint16:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
case int32:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
case uint32:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
case int64:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
case uint64:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
case int:
|
|
if value > math.MaxInt16 {
|
|
return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
|
|
}
|
|
v = int16(value)
|
|
default:
|
|
return fmt.Errorf("Expected integer representable in int16, received %T %v", value, value)
|
|
}
|
|
|
|
w.WriteInt32(2)
|
|
w.WriteInt16(v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeInt4(rows *Rows, fd *FieldDescription, size int32) int32 {
|
|
if fd.DataType != Int4Oid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, fd.DataType)))
|
|
return 0
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
n, err := strconv.ParseInt(s, 10, 32)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s)))
|
|
}
|
|
return int32(n)
|
|
case BinaryFormatCode:
|
|
if size != 4 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size)))
|
|
return 0
|
|
}
|
|
return rows.mr.ReadInt32()
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func encodeInt4(w *WriteBuf, value interface{}) error {
|
|
var v int32
|
|
switch value := value.(type) {
|
|
case int8:
|
|
v = int32(value)
|
|
case uint8:
|
|
v = int32(value)
|
|
case int16:
|
|
v = int32(value)
|
|
case uint16:
|
|
v = int32(value)
|
|
case int32:
|
|
v = int32(value)
|
|
case uint32:
|
|
if value > math.MaxInt32 {
|
|
return fmt.Errorf("%T %d is larger than max int64 %d", value, value, math.MaxInt32)
|
|
}
|
|
v = int32(value)
|
|
case int64:
|
|
if value > math.MaxInt32 {
|
|
return fmt.Errorf("%T %d is larger than max int64 %d", value, value, math.MaxInt32)
|
|
}
|
|
v = int32(value)
|
|
case uint64:
|
|
if value > math.MaxInt32 {
|
|
return fmt.Errorf("%T %d is larger than max int64 %d", value, value, math.MaxInt32)
|
|
}
|
|
v = int32(value)
|
|
case int:
|
|
if value > math.MaxInt32 {
|
|
return fmt.Errorf("%T %d is larger than max int64 %d", value, value, math.MaxInt32)
|
|
}
|
|
v = int32(value)
|
|
default:
|
|
return fmt.Errorf("Expected integer representable in int32, received %T %v", value, value)
|
|
}
|
|
|
|
w.WriteInt32(4)
|
|
w.WriteInt32(v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeFloat4(rows *Rows, fd *FieldDescription, size int32) float32 {
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
n, err := strconv.ParseFloat(s, 32)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s)))
|
|
return 0
|
|
}
|
|
return float32(n)
|
|
case BinaryFormatCode:
|
|
if size != 4 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size)))
|
|
return 0
|
|
}
|
|
|
|
i := rows.mr.ReadInt32()
|
|
p := unsafe.Pointer(&i)
|
|
return *(*float32)(p)
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func encodeFloat4(w *WriteBuf, value interface{}) error {
|
|
var v float32
|
|
switch value := value.(type) {
|
|
case float32:
|
|
v = float32(value)
|
|
case float64:
|
|
if value > math.MaxFloat32 {
|
|
return fmt.Errorf("%T %f is larger than max float32 %f", value, math.MaxFloat32)
|
|
}
|
|
v = float32(value)
|
|
default:
|
|
return fmt.Errorf("Expected float representable in float32, received %T %v", value, value)
|
|
}
|
|
|
|
w.WriteInt32(4)
|
|
|
|
p := unsafe.Pointer(&v)
|
|
w.WriteInt32(*(*int32)(p))
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeFloat8(rows *Rows, fd *FieldDescription, size int32) float64 {
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
v, err := strconv.ParseFloat(s, 64)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s)))
|
|
return 0
|
|
}
|
|
return v
|
|
case BinaryFormatCode:
|
|
if size != 8 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size)))
|
|
return 0
|
|
}
|
|
|
|
i := rows.mr.ReadInt64()
|
|
p := unsafe.Pointer(&i)
|
|
return *(*float64)(p)
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func encodeFloat8(w *WriteBuf, value interface{}) error {
|
|
var v float64
|
|
switch value := value.(type) {
|
|
case float32:
|
|
v = float64(value)
|
|
case float64:
|
|
v = float64(value)
|
|
default:
|
|
return fmt.Errorf("Expected float representable in float64, received %T %v", value, value)
|
|
}
|
|
|
|
w.WriteInt32(8)
|
|
|
|
p := unsafe.Pointer(&v)
|
|
w.WriteInt64(*(*int64)(p))
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeText(rows *Rows, fd *FieldDescription, size int32) string {
|
|
return rows.mr.ReadString(size)
|
|
}
|
|
|
|
func encodeText(w *WriteBuf, value interface{}) error {
|
|
s, ok := value.(string)
|
|
if !ok {
|
|
return fmt.Errorf("Expected string, received %T", value)
|
|
}
|
|
|
|
w.WriteInt32(int32(len(s)))
|
|
w.WriteBytes([]byte(s))
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeBytea(rows *Rows, fd *FieldDescription, size int32) []byte {
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
b, err := hex.DecodeString(s[2:])
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s)))
|
|
return nil
|
|
}
|
|
return b
|
|
case BinaryFormatCode:
|
|
return rows.mr.ReadBytes(size)
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func encodeBytea(w *WriteBuf, value interface{}) error {
|
|
b, ok := value.([]byte)
|
|
if !ok {
|
|
return fmt.Errorf("Expected []byte, received %T", value)
|
|
}
|
|
|
|
w.WriteInt32(int32(len(b)))
|
|
w.WriteBytes(b)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeDate(rows *Rows, fd *FieldDescription, size int32) time.Time {
|
|
var zeroTime time.Time
|
|
|
|
if fd.DataType != DateOid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, fd.DataType)))
|
|
return zeroTime
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
t, err := time.ParseInLocation("2006-01-02", s, time.Local)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s)))
|
|
return zeroTime
|
|
}
|
|
return t
|
|
case BinaryFormatCode:
|
|
if size != 4 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", size)))
|
|
}
|
|
dayOffset := rows.mr.ReadInt32()
|
|
return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return zeroTime
|
|
}
|
|
}
|
|
|
|
func encodeDate(w *WriteBuf, value interface{}) error {
|
|
t, ok := value.(time.Time)
|
|
if !ok {
|
|
return fmt.Errorf("Expected time.Time, received %T", value)
|
|
}
|
|
|
|
s := t.Format("2006-01-02")
|
|
return encodeText(w, s)
|
|
}
|
|
|
|
const microsecFromUnixEpochToY2K = 946684800 * 1000000
|
|
|
|
func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time {
|
|
var zeroTime time.Time
|
|
|
|
if fd.DataType != TimestampTzOid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, fd.DataType)))
|
|
return zeroTime
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
t, err := time.Parse("2006-01-02 15:04:05.999999-07", s)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s)))
|
|
return zeroTime
|
|
}
|
|
return t
|
|
case BinaryFormatCode:
|
|
if size != 8 {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size)))
|
|
}
|
|
microsecSinceY2K := rows.mr.ReadInt64()
|
|
microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
|
|
return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return zeroTime
|
|
}
|
|
}
|
|
|
|
func encodeTimestampTz(w *WriteBuf, value interface{}) error {
|
|
t, ok := value.(time.Time)
|
|
if !ok {
|
|
return fmt.Errorf("Expected time.Time, received %T", value)
|
|
}
|
|
|
|
microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
|
|
microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
|
|
|
|
w.WriteInt32(8)
|
|
w.WriteInt64(microsecSinceY2K)
|
|
|
|
return nil
|
|
}
|
|
|
|
func decodeTimestamp(rows *Rows, fd *FieldDescription, size int32) time.Time {
|
|
var zeroTime time.Time
|
|
|
|
if fd.DataType != TimestampOid {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, fd.DataType)))
|
|
return zeroTime
|
|
}
|
|
|
|
switch fd.FormatCode {
|
|
case TextFormatCode:
|
|
s := rows.mr.ReadString(size)
|
|
t, err := time.ParseInLocation("2006-01-02 15:04:05.999999", s, time.Local)
|
|
if err != nil {
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s)))
|
|
return zeroTime
|
|
}
|
|
return t
|
|
case BinaryFormatCode:
|
|
rows.Fatal(ProtocolError("Can't decode binary timestamp"))
|
|
return zeroTime
|
|
default:
|
|
rows.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode)))
|
|
return zeroTime
|
|
}
|
|
}
|
|
|
|
func encodeTimestamp(w *WriteBuf, value interface{}) error {
|
|
t, ok := value.(time.Time)
|
|
if !ok {
|
|
return fmt.Errorf("Expected time.Time, received %T", value)
|
|
}
|
|
|
|
s := t.Format("2006-01-02 15:04:05.999999")
|
|
return encodeText(w, s)
|
|
|
|
return nil
|
|
}
|