From f756d9d5919fa50c65dc83821c195d1bbae850e9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 12 Apr 2019 21:31:59 -0500 Subject: [PATCH] Extract scan value to pgtype --- pgtype/pgtype.go | 54 +++++++++++++++++++++++++++++++++++++++ query.go | 66 +++++------------------------------------------- 2 files changed, 60 insertions(+), 60 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 2643314e..8f41d068 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql" "reflect" "github.com/pkg/errors" @@ -84,6 +85,12 @@ func (im InfinityModifier) String() string { } } +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error @@ -207,6 +214,53 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { return ci2 } +func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error { + if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode { + return dest.DecodeBinary(ci, buf) + } + + if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode { + return dest.DecodeText(ci, buf) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + switch formatCode { + case TextFormatCode: + if textDecoder, ok := value.(TextDecoder); ok { + err := textDecoder.DecodeText(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.TextDecoder", value) + } + case BinaryFormatCode: + if binaryDecoder, ok := value.(BinaryDecoder); ok { + err := binaryDecoder.DecodeBinary(ci, buf) + if err != nil { + return err + } + } else { + return errors.Errorf("%T is not a pgtype.BinaryDecoder", value) + } + default: + return errors.Errorf("unknown format code: %v", formatCode) + } + + if scanner, ok := dest.(sql.Scanner); ok { + sqlSrc, err := DatabaseSQLValue(ci, value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) + } else { + return value.AssignTo(dest) + } + } + return errors.Errorf("unknown oid: %v", oid) +} + var nameValues map[string]Value func init() { diff --git a/query.go b/query.go index a2402b51..5cb503ba 100644 --- a/query.go +++ b/query.go @@ -2,7 +2,6 @@ package pgx import ( "context" - "database/sql" "fmt" "reflect" "time" @@ -186,9 +185,9 @@ func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) { return buf, fd, true } -func (rows *connRows) Scan(dest ...interface{}) (err error) { +func (rows *connRows) Scan(dest ...interface{}) error { if len(rows.fields) != len(dest) { - err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) + err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) rows.fatal(err) return err } @@ -200,63 +199,10 @@ func (rows *connRows) Scan(dest ...interface{}) (err error) { continue } - if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { - err = s.DecodeBinary(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { - err = s.DecodeText(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { - value := dt.Value - switch fd.FormatCode { - case TextFormatCode: - if textDecoder, ok := value.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) - } - case BinaryFormatCode: - if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) - } - default: - rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) - } - - if rows.Err() == nil { - if scanner, ok := d.(sql.Scanner); ok { - sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.fatal(err) - } - err = scanner.Scan(sqlSrc) - if err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } else if err := value.AssignTo(d); err != nil { - rows.fatal(scanArgError{col: i, err: err}) - } - } - } else { - rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) - } - } - - if rows.Err() != nil { - return rows.Err() + err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d) + if err != nil { + rows.fatal(scanArgError{col: i, err: err}) + return err } }