Use pgproto3.FieldDescription instead of pgx version
This allows removing a malloc and memcpy.
This commit is contained in:
@@ -56,7 +56,7 @@ type Conn struct {
|
|||||||
type PreparedStatement struct {
|
type PreparedStatement struct {
|
||||||
Name string
|
Name string
|
||||||
SQL string
|
SQL string
|
||||||
FieldDescriptions []FieldDescription
|
FieldDescriptions []pgproto3.FieldDescription
|
||||||
ParameterOIDs []pgtype.OID
|
ParameterOIDs []pgtype.OID
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,15 +213,12 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
|||||||
Name: psd.Name,
|
Name: psd.Name,
|
||||||
SQL: psd.SQL,
|
SQL: psd.SQL,
|
||||||
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
||||||
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
|
FieldDescriptions: psd.Fields,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range ps.ParameterOIDs {
|
for i := range ps.ParameterOIDs {
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
|
||||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if name != "" {
|
if name != "" {
|
||||||
c.preparedStatements[name] = ps
|
c.preparedStatements[name] = ps
|
||||||
@@ -416,7 +413,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||||||
|
|
||||||
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
resultFormats[i] = BinaryFormatCode
|
resultFormats[i] = BinaryFormatCode
|
||||||
} else {
|
} else {
|
||||||
@@ -453,15 +450,12 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||||||
Name: psd.Name,
|
Name: psd.Name,
|
||||||
SQL: psd.SQL,
|
SQL: psd.SQL,
|
||||||
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
||||||
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
|
FieldDescriptions: psd.Fields,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range ps.ParameterOIDs {
|
for i := range ps.ParameterOIDs {
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
|
||||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
arguments, err = convertDriverValuers(arguments)
|
arguments, err = convertDriverValuers(arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -481,7 +475,7 @@ func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||||||
|
|
||||||
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
resultFormats := make([]int16, len(ps.FieldDescriptions))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
resultFormats[i] = BinaryFormatCode
|
resultFormats[i] = BinaryFormatCode
|
||||||
} else {
|
} else {
|
||||||
@@ -549,22 +543,6 @@ func newencodePreparedStatementArgument(ci *pgtype.ConnInfo, oid pgtype.OID, arg
|
|||||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||||
}
|
}
|
||||||
|
|
||||||
// pgproto3FieldDescriptionToPgxFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
|
||||||
// FieldDescription.
|
|
||||||
func pgproto3FieldDescriptionToPgxFieldDescription(connInfo *pgtype.ConnInfo, src *pgproto3.FieldDescription, dst *FieldDescription) {
|
|
||||||
dst.Name = string(src.Name)
|
|
||||||
dst.Table = pgtype.OID(src.TableOID)
|
|
||||||
dst.AttributeNumber = src.TableAttributeNumber
|
|
||||||
dst.DataType = pgtype.OID(src.DataTypeOID)
|
|
||||||
dst.DataTypeSize = src.DataTypeSize
|
|
||||||
dst.Modifier = src.TypeModifier
|
|
||||||
dst.FormatCode = src.Format
|
|
||||||
|
|
||||||
if dt, ok := connInfo.DataTypeForOID(dst.DataType); ok {
|
|
||||||
dst.DataTypeName = dt.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
||||||
if len(c.preallocatedRows) == 0 {
|
if len(c.preallocatedRows) == 0 {
|
||||||
c.preallocatedRows = make([]connRows, 64)
|
c.preallocatedRows = make([]connRows, 64)
|
||||||
@@ -628,15 +606,12 @@ optionLoop:
|
|||||||
Name: psd.Name,
|
Name: psd.Name,
|
||||||
SQL: psd.SQL,
|
SQL: psd.SQL,
|
||||||
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)),
|
||||||
FieldDescriptions: make([]FieldDescription, len(psd.Fields)),
|
FieldDescriptions: psd.Fields,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range ps.ParameterOIDs {
|
for i := range ps.ParameterOIDs {
|
||||||
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i])
|
||||||
}
|
}
|
||||||
for i := range ps.FieldDescriptions {
|
|
||||||
pgproto3FieldDescriptionToPgxFieldDescription(c.ConnInfo, &psd.Fields[i], &ps.FieldDescriptions[i])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
rows.sql = ps.SQL
|
rows.sql = ps.SQL
|
||||||
|
|
||||||
@@ -658,13 +633,13 @@ optionLoop:
|
|||||||
if resultFormatsByOID != nil {
|
if resultFormatsByOID != nil {
|
||||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
resultFormats[i] = resultFormatsByOID[ps.FieldDescriptions[i].DataType]
|
resultFormats[i] = resultFormatsByOID[pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resultFormats == nil {
|
if resultFormats == nil {
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
c.eqb.AppendResultFormat(BinaryFormatCode)
|
c.eqb.AppendResultFormat(BinaryFormatCode)
|
||||||
} else {
|
} else {
|
||||||
@@ -725,7 +700,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||||||
if resultFormats == nil {
|
if resultFormats == nil {
|
||||||
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
resultFormats = make([]int16, len(ps.FieldDescriptions))
|
||||||
for i := range resultFormats {
|
for i := range resultFormats {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok {
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||||
resultFormats[i] = BinaryFormatCode
|
resultFormats[i] = BinaryFormatCode
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+2
-1
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/jackc/pgio"
|
"github.com/jackc/pgio"
|
||||||
|
"github.com/jackc/pgtype"
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
|
|||||||
|
|
||||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||||
for i, val := range values {
|
for i, val := range values {
|
||||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, pgtype.OID(ps.FieldDescriptions[i].DataTypeOID), val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, err
|
return false, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
+3
-68
@@ -2,82 +2,17 @@ package pgx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"math"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgio"
|
"github.com/jackc/pgio"
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
copyData = 'd'
|
copyData = 'd'
|
||||||
copyFail = 'f'
|
copyFail = 'f'
|
||||||
copyDone = 'c'
|
copyDone = 'c'
|
||||||
varHeaderSize = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type FieldDescription struct {
|
|
||||||
Name string
|
|
||||||
Table pgtype.OID
|
|
||||||
AttributeNumber uint16
|
|
||||||
DataType pgtype.OID
|
|
||||||
DataTypeSize int16
|
|
||||||
DataTypeName string
|
|
||||||
Modifier int32
|
|
||||||
FormatCode int16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fd FieldDescription) Length() (int64, bool) {
|
|
||||||
switch fd.DataType {
|
|
||||||
case pgtype.TextOID, pgtype.ByteaOID:
|
|
||||||
return math.MaxInt64, true
|
|
||||||
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
|
|
||||||
return int64(fd.Modifier - varHeaderSize), true
|
|
||||||
default:
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) {
|
|
||||||
switch fd.DataType {
|
|
||||||
case pgtype.NumericOID:
|
|
||||||
mod := fd.Modifier - varHeaderSize
|
|
||||||
precision = int64((mod >> 16) & 0xffff)
|
|
||||||
scale = int64(mod & 0xffff)
|
|
||||||
return precision, scale, true
|
|
||||||
default:
|
|
||||||
return 0, 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fd FieldDescription) Type() reflect.Type {
|
|
||||||
switch fd.DataType {
|
|
||||||
case pgtype.Float8OID:
|
|
||||||
return reflect.TypeOf(float64(0))
|
|
||||||
case pgtype.Float4OID:
|
|
||||||
return reflect.TypeOf(float32(0))
|
|
||||||
case pgtype.Int8OID:
|
|
||||||
return reflect.TypeOf(int64(0))
|
|
||||||
case pgtype.Int4OID:
|
|
||||||
return reflect.TypeOf(int32(0))
|
|
||||||
case pgtype.Int2OID:
|
|
||||||
return reflect.TypeOf(int16(0))
|
|
||||||
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
|
|
||||||
return reflect.TypeOf("")
|
|
||||||
case pgtype.BoolOID:
|
|
||||||
return reflect.TypeOf(false)
|
|
||||||
case pgtype.NumericOID:
|
|
||||||
return reflect.TypeOf(float64(0))
|
|
||||||
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
|
||||||
return reflect.TypeOf(time.Time{})
|
|
||||||
case pgtype.ByteaOID:
|
|
||||||
return reflect.TypeOf([]byte(nil))
|
|
||||||
default:
|
|
||||||
return reflect.TypeOf(new(interface{})).Elem()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
|
func convertDriverValuers(args []interface{}) ([]interface{}, error) {
|
||||||
for i, arg := range args {
|
for i, arg := range args {
|
||||||
switch arg := arg.(type) {
|
switch arg := arg.(type) {
|
||||||
|
|||||||
+8
-7
@@ -1,6 +1,7 @@
|
|||||||
package pool
|
package pool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -8,12 +9,12 @@ type errRows struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (errRows) Close() {}
|
func (errRows) Close() {}
|
||||||
func (e errRows) Err() error { return e.err }
|
func (e errRows) Err() error { return e.err }
|
||||||
func (errRows) FieldDescriptions() []pgx.FieldDescription { return nil }
|
func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil }
|
||||||
func (errRows) Next() bool { return false }
|
func (errRows) Next() bool { return false }
|
||||||
func (e errRows) Scan(dest ...interface{}) error { return e.err }
|
func (e errRows) Scan(dest ...interface{}) error { return e.err }
|
||||||
func (e errRows) Values() ([]interface{}, error) { return nil, e.err }
|
func (e errRows) Values() ([]interface{}, error) { return nil, e.err }
|
||||||
|
|
||||||
type errRow struct {
|
type errRow struct {
|
||||||
err error
|
err error
|
||||||
@@ -42,7 +43,7 @@ func (rows *poolRows) Err() error {
|
|||||||
return rows.r.Err()
|
return rows.r.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *poolRows) FieldDescriptions() []pgx.FieldDescription {
|
func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription {
|
||||||
return rows.r.FieldDescriptions()
|
return rows.r.FieldDescriptions()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -248,7 +248,7 @@ func TestIdentifySystem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
for _, fd := range r.FieldDescriptions() {
|
for _, fd := range r.FieldDescriptions() {
|
||||||
t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
|
t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rowCount int
|
var rowCount int
|
||||||
@@ -307,7 +307,7 @@ func TestGetTimelineHistory(t *testing.T) {
|
|||||||
defer r.Close()
|
defer r.Close()
|
||||||
|
|
||||||
for _, fd := range r.FieldDescriptions() {
|
for _, fd := range r.FieldDescriptions() {
|
||||||
t.Logf("Field: %s of type %v", fd.Name, fd.DataType)
|
t.Logf("Field: %s of type %v", fd.Name, fd.DataTypeOID)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rowCount int
|
var rowCount int
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
"github.com/jackc/pgtype"
|
"github.com/jackc/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +21,7 @@ type Rows interface {
|
|||||||
Close()
|
Close()
|
||||||
|
|
||||||
Err() error
|
Err() error
|
||||||
FieldDescriptions() []FieldDescription
|
FieldDescriptions() []pgproto3.FieldDescription
|
||||||
|
|
||||||
// Next prepares the next row for reading. It returns true if there is another
|
// Next prepares the next row for reading. It returns true if there is another
|
||||||
// row and false if no more rows are available. It automatically closes rows
|
// row and false if no more rows are available. It automatically closes rows
|
||||||
@@ -77,7 +78,6 @@ type connRows struct {
|
|||||||
logger rowLog
|
logger rowLog
|
||||||
connInfo *pgtype.ConnInfo
|
connInfo *pgtype.ConnInfo
|
||||||
values [][]byte
|
values [][]byte
|
||||||
fields []FieldDescription
|
|
||||||
rowCount int
|
rowCount int
|
||||||
columnIdx int
|
columnIdx int
|
||||||
err error
|
err error
|
||||||
@@ -89,8 +89,8 @@ type connRows struct {
|
|||||||
resultReader *pgconn.ResultReader
|
resultReader *pgconn.ResultReader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) FieldDescriptions() []FieldDescription {
|
func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
|
||||||
return rows.fields
|
return rows.resultReader.FieldDescriptions()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) Close() {
|
func (rows *connRows) Close() {
|
||||||
@@ -140,13 +140,6 @@ func (rows *connRows) Next() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rows.resultReader.NextRow() {
|
if rows.resultReader.NextRow() {
|
||||||
if rows.fields == nil {
|
|
||||||
rrFieldDescriptions := rows.resultReader.FieldDescriptions()
|
|
||||||
rows.fields = make([]FieldDescription, len(rrFieldDescriptions))
|
|
||||||
for i := range rrFieldDescriptions {
|
|
||||||
pgproto3FieldDescriptionToPgxFieldDescription(rows.connInfo, &rrFieldDescriptions[i], &rows.fields[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rows.rowCount++
|
rows.rowCount++
|
||||||
rows.columnIdx = 0
|
rows.columnIdx = 0
|
||||||
rows.values = rows.resultReader.Values()
|
rows.values = rows.resultReader.Values()
|
||||||
@@ -157,24 +150,24 @@ func (rows *connRows) Next() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) {
|
func (rows *connRows) nextColumn() ([]byte, *pgproto3.FieldDescription, bool) {
|
||||||
if rows.closed {
|
if rows.closed {
|
||||||
return nil, nil, false
|
return nil, nil, false
|
||||||
}
|
}
|
||||||
if len(rows.fields) <= rows.columnIdx {
|
if len(rows.FieldDescriptions()) <= rows.columnIdx {
|
||||||
rows.fatal(ProtocolError("No next column available"))
|
rows.fatal(ProtocolError("No next column available"))
|
||||||
return nil, nil, false
|
return nil, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := rows.values[rows.columnIdx]
|
buf := rows.values[rows.columnIdx]
|
||||||
fd := &rows.fields[rows.columnIdx]
|
fd := &rows.FieldDescriptions()[rows.columnIdx]
|
||||||
rows.columnIdx++
|
rows.columnIdx++
|
||||||
return buf, fd, true
|
return buf, fd, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) Scan(dest ...interface{}) error {
|
func (rows *connRows) Scan(dest ...interface{}) error {
|
||||||
if len(rows.fields) != len(dest) {
|
if len(rows.FieldDescriptions()) != len(dest) {
|
||||||
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
|
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.FieldDescriptions()))
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -186,7 +179,7 @@ func (rows *connRows) Scan(dest ...interface{}) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := rows.connInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
err := rows.connInfo.Scan(pgtype.OID(fd.DataTypeOID), fd.Format, buf, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
rows.fatal(scanArgError{col: i, err: err})
|
||||||
return err
|
return err
|
||||||
@@ -201,9 +194,9 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||||||
return nil, errors.New("rows is closed")
|
return nil, errors.New("rows is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
values := make([]interface{}, 0, len(rows.fields))
|
values := make([]interface{}, 0, len(rows.FieldDescriptions()))
|
||||||
|
|
||||||
for range rows.fields {
|
for range rows.FieldDescriptions() {
|
||||||
buf, fd, _ := rows.nextColumn()
|
buf, fd, _ := rows.nextColumn()
|
||||||
|
|
||||||
if buf == nil {
|
if buf == nil {
|
||||||
@@ -211,10 +204,10 @@ func (rows *connRows) Values() ([]interface{}, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if dt, ok := rows.connInfo.DataTypeForOID(fd.DataType); ok {
|
if dt, ok := rows.connInfo.DataTypeForOID(pgtype.OID(fd.DataTypeOID)); ok {
|
||||||
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value)
|
||||||
|
|
||||||
switch fd.FormatCode {
|
switch fd.Format {
|
||||||
case TextFormatCode:
|
case TextFormatCode:
|
||||||
decoder := value.(pgtype.TextDecoder)
|
decoder := value.(pgtype.TextDecoder)
|
||||||
if decoder == nil {
|
if decoder == nil {
|
||||||
|
|||||||
+59
-7
@@ -74,6 +74,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -260,7 +261,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
|
|||||||
|
|
||||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||||
more := rows.Next()
|
more := rows.Next()
|
||||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Ping(ctx context.Context) error {
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
@@ -301,6 +302,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
|
conn *Conn
|
||||||
rows pgx.Rows
|
rows pgx.Rows
|
||||||
values []interface{}
|
values []interface{}
|
||||||
skipNext bool
|
skipNext bool
|
||||||
@@ -311,32 +313,82 @@ func (r *Rows) Columns() []string {
|
|||||||
fieldDescriptions := r.rows.FieldDescriptions()
|
fieldDescriptions := r.rows.FieldDescriptions()
|
||||||
names := make([]string, 0, len(fieldDescriptions))
|
names := make([]string, 0, len(fieldDescriptions))
|
||||||
for _, fd := range fieldDescriptions {
|
for _, fd := range fieldDescriptions {
|
||||||
names = append(names, fd.Name)
|
names = append(names, string(fd.Name))
|
||||||
}
|
}
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnTypeDatabaseTypeName return the database system type name.
|
// ColumnTypeDatabaseTypeName return the database system type name.
|
||||||
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||||
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
|
if dt, ok := r.conn.conn.ConnInfo.DataTypeForOID(pgtype.OID(r.rows.FieldDescriptions()[index].DataTypeOID)); ok {
|
||||||
|
return strings.ToUpper(dt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const varHeaderSize = 4
|
||||||
|
|
||||||
// ColumnTypeLength returns the length of the column type if the column is a
|
// ColumnTypeLength returns the length of the column type if the column is a
|
||||||
// variable length type. If the column is not a variable length type ok
|
// variable length type. If the column is not a variable length type ok
|
||||||
// should return false.
|
// should return false.
|
||||||
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
||||||
return r.rows.FieldDescriptions()[index].Length()
|
fd := r.rows.FieldDescriptions()[index]
|
||||||
|
|
||||||
|
switch fd.DataTypeOID {
|
||||||
|
case pgtype.TextOID, pgtype.ByteaOID:
|
||||||
|
return math.MaxInt64, true
|
||||||
|
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
|
||||||
|
return int64(fd.TypeModifier - varHeaderSize), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
||||||
// types. If not applicable, ok should be false.
|
// types. If not applicable, ok should be false.
|
||||||
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||||
return r.rows.FieldDescriptions()[index].PrecisionScale()
|
fd := r.rows.FieldDescriptions()[index]
|
||||||
|
|
||||||
|
switch fd.DataTypeOID {
|
||||||
|
case pgtype.NumericOID:
|
||||||
|
mod := fd.TypeModifier - varHeaderSize
|
||||||
|
precision = int64((mod >> 16) & 0xffff)
|
||||||
|
scale = int64(mod & 0xffff)
|
||||||
|
return precision, scale, true
|
||||||
|
default:
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
||||||
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
||||||
return r.rows.FieldDescriptions()[index].Type()
|
fd := r.rows.FieldDescriptions()[index]
|
||||||
|
|
||||||
|
switch fd.DataTypeOID {
|
||||||
|
case pgtype.Float8OID:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case pgtype.Float4OID:
|
||||||
|
return reflect.TypeOf(float32(0))
|
||||||
|
case pgtype.Int8OID:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case pgtype.Int4OID:
|
||||||
|
return reflect.TypeOf(int32(0))
|
||||||
|
case pgtype.Int2OID:
|
||||||
|
return reflect.TypeOf(int16(0))
|
||||||
|
case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case pgtype.BoolOID:
|
||||||
|
return reflect.TypeOf(false)
|
||||||
|
case pgtype.NumericOID:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case pgtype.ByteaOID:
|
||||||
|
return reflect.TypeOf([]byte(nil))
|
||||||
|
default:
|
||||||
|
return reflect.TypeOf(new(interface{})).Elem()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rows) Close() error {
|
func (r *Rows) Close() error {
|
||||||
@@ -348,7 +400,7 @@ func (r *Rows) Next(dest []driver.Value) error {
|
|||||||
if r.values == nil {
|
if r.values == nil {
|
||||||
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
|
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
|
||||||
for i, fd := range r.rows.FieldDescriptions() {
|
for i, fd := range r.rows.FieldDescriptions() {
|
||||||
switch fd.DataType {
|
switch fd.DataTypeOID {
|
||||||
case pgtype.BoolOID:
|
case pgtype.BoolOID:
|
||||||
r.values[i] = &pgtype.Bool{}
|
r.values[i] = &pgtype.Bool{}
|
||||||
case pgtype.ByteaOID:
|
case pgtype.ByteaOID:
|
||||||
|
|||||||
Reference in New Issue
Block a user