diff --git a/pgtype.go b/pgtype.go index af8d8661..32c6da5a 100644 --- a/pgtype.go +++ b/pgtype.go @@ -2,6 +2,8 @@ package pgtype import ( "database/sql" + "encoding/binary" + "math" "net" "reflect" "time" @@ -472,76 +474,93 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } -func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interface{}) error { +// ScanPlan is a precompiled plan to scan into a particular destination. This requires care to use as it always scans +// to the same destination. +// +// This is a very low-level optimization. It should only be used to implement a PostgreSQL driver or custom type. +type ScanPlan interface { + // Scan scans src into dst. All parameters except src MUST be the same as were passed to PlanScan when this was + // created. + Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error +} + +type scanPlanDstBinaryDecoder struct { + d BinaryDecoder +} + +func (plan scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.d.DecodeBinary(ci, src) +} + +type scanPlanDstTextDecoder struct { + d TextDecoder +} + +func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + return plan.d.DecodeText(ci, src) +} + +type scanPlanDataTypeSQLScanner DataType + +func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error switch formatCode { case BinaryFormatCode: - if dest, ok := dest.(BinaryDecoder); ok { - return dest.DecodeBinary(ci, buf) - } + err = dt.binaryDecoder.DecodeBinary(ci, src) case TextFormatCode: - if dest, ok := dest.(TextDecoder); ok { - return dest.DecodeText(ci, buf) - } - default: - return errors.Errorf("unknown format code: %v", formatCode) + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err } - var dt *DataType + scanner := dst.(sql.Scanner) + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) +} - if oid == 0 { - if dataType, ok := ci.DataTypeForValue(dest); ok { - dt = dataType - } +type scanPlanDataTypeAssignTo DataType + +func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err + } + + return dt.Value.AssignTo(dst) +} + +type scanPlanSQLScanner struct{} + +func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := dst.(sql.Scanner) + if formatCode == BinaryFormatCode { + return scanner.Scan(src) } else { - if dataType, ok := ci.DataTypeForOID(oid); ok { - dt = dataType - } + return scanner.Scan(string(src)) } +} - if dt != nil { - switch formatCode { - case BinaryFormatCode: - if dt.binaryDecoder != nil { - err := dt.binaryDecoder.DecodeBinary(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.BinaryDecoder", dt.Value) - } - case TextFormatCode: - if dt.textDecoder != nil { - err := dt.textDecoder.DecodeText(ci, buf) - if err != nil { - return err - } - } else { - return errors.Errorf("%T is not a pgtype.TextDecoder", dt.Value) - } - } - - assignToErr := dt.Value.AssignTo(dest) - if assignToErr == nil { - return nil - } - - if scanner, ok := dest.(sql.Scanner); ok { - sqlSrc, err := DatabaseSQLValue(ci, dt.Value) - if err != nil { - return err - } - return scanner.Scan(sqlSrc) - } - - return assignToErr - } +type scanPlanReflection struct{} +func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { // We might be given a pointer to something that implements the decoder interface(s), // even though the pointer itself doesn't. - refVal := reflect.ValueOf(dest) + refVal := reflect.ValueOf(dst) if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { // If the database returned NULL, then we set dest as nil to indicate that. - if buf == nil { + if src == nil { nilPtr := reflect.Zero(refVal.Type().Elem()) refVal.Elem().Set(nilPtr) return nil @@ -551,10 +570,185 @@ func (ci *ConnInfo) Scan(oid uint32, formatCode int16, buf []byte, dest interfac // Then we can retry as that element. elemPtr := reflect.New(refVal.Type().Elem().Elem()) refVal.Elem().Set(elemPtr) - return ci.Scan(oid, formatCode, buf, elemPtr.Interface()) + + plan := ci.PlanScan(oid, formatCode, src, elemPtr.Interface()) + return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) } - return scanUnknownType(oid, formatCode, buf, dest) + return scanUnknownType(oid, formatCode, src, dst) +} + +type scanPlanBinaryInt16 int16 + +func (plan *scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return errors.Errorf("invalid length for int2: %v", len(src)) + } + + *plan = scanPlanBinaryInt16(binary.BigEndian.Uint16(src)) + return nil +} + +type scanPlanBinaryInt32 int32 + +func (plan *scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return errors.Errorf("invalid length for int4: %v", len(src)) + } + + *plan = scanPlanBinaryInt32(binary.BigEndian.Uint32(src)) + return nil +} + +type scanPlanBinaryInt64 int64 + +func (plan *scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return errors.Errorf("invalid length for int8: %v", len(src)) + } + + *plan = scanPlanBinaryInt64(binary.BigEndian.Uint64(src)) + return nil +} + +type scanPlanBinaryFloat32 float32 + +func (plan *scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return errors.Errorf("invalid length for int4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + *plan = scanPlanBinaryFloat32(math.Float32frombits(uint32(n))) + return nil +} + +type scanPlanBinaryFloat64 float64 + +func (plan *scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return errors.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + *plan = scanPlanBinaryFloat64(math.Float64frombits(uint64(n))) + return nil +} + +type scanPlanBinaryBytes []byte + +func (plan *scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + *plan = scanPlanBinaryBytes(src) + return nil +} + +type scanPlanString string + +func (plan *scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return errors.Errorf("cannot scan null into %T", dst) + } + + *plan = scanPlanString(src) + return nil +} + +// PlanScan prepares a plan to scan a value into dst. +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, buf []byte, dst interface{}) ScanPlan { + switch formatCode { + case BinaryFormatCode: + switch d := dst.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return (*scanPlanString)(d) + } + case *int16: + if oid == Int2OID { + return (*scanPlanBinaryInt16)(d) + } + case *int32: + if oid == Int4OID { + return (*scanPlanBinaryInt32)(d) + } + case *int64: + if oid == Int8OID { + return (*scanPlanBinaryInt64)(d) + } + case *float32: + if oid == Float4OID { + return (*scanPlanBinaryFloat32)(d) + } + case *float64: + if oid == Float8OID { + return (*scanPlanBinaryFloat64)(d) + } + case *[]byte: + switch oid { + case ByteaOID, TextOID, VarcharOID: + return (*scanPlanBinaryBytes)(d) + } + case BinaryDecoder: + return scanPlanDstBinaryDecoder{d: d} + } + case TextFormatCode: + switch d := dst.(type) { + case *string: + return (*scanPlanString)(d) + case TextDecoder: + return scanPlanDstTextDecoder{d: d} + } + } + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dst); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { + if _, ok := dst.(sql.Scanner); ok { + return (*scanPlanDataTypeSQLScanner)(dt) + } + return (*scanPlanDataTypeAssignTo)(dt) + } + + if _, ok := dst.(sql.Scanner); ok { + return scanPlanSQLScanner{} + } + + return scanPlanReflection{} +} + +func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { + plan := ci.PlanScan(oid, formatCode, src, dst) + return plan.Scan(ci, oid, formatCode, src, dst) } func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { diff --git a/pgtype_test.go b/pgtype_test.go index 664c5394..45b1b64d 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -170,3 +170,41 @@ func BenchmarkConnInfoScanInt4IntoGoInt32(b *testing.B) { } } } + +func BenchmarkScanPlanScanInt4IntoBinaryDecoder(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v pgtype.Int4 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + + for i := 0; i < b.N; i++ { + v = pgtype.Int4{} + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != (pgtype.Int4{Int: 42, Status: pgtype.Present}) { + b.Fatal("scan failed due to bad value") + } + } +} + +func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { + ci := pgtype.NewConnInfo() + src := []byte{0, 0, 0, 42} + var v int32 + + plan := ci.PlanScan(pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + + for i := 0; i < b.N; i++ { + v = 0 + err := plan.Scan(ci, pgtype.Int4OID, pgtype.BinaryFormatCode, src, &v) + if err != nil { + b.Fatal(err) + } + if v != 42 { + b.Fatal("scan failed due to bad value") + } + } +}