2
0

Remove unneeded WriteBuf

This commit is contained in:
Jack Christensen
2017-05-02 21:26:45 -05:00
parent 6e64a0c867
commit 458dd24a9f
6 changed files with 173 additions and 211 deletions
+75 -43
View File
@@ -20,6 +20,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@@ -86,8 +87,7 @@ func (cc *ConnConfig) networkAddress() (network, address string) {
type Conn struct { type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection conn net.Conn // the underlying TCP or unix domain socket connection
lastActivityTime time.Time // the last time the connection was used lastActivityTime time.Time // the last time the connection was used
wbuf [1024]byte wbuf []byte
writeBuf WriteBuf
pid uint32 // backend pid pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server RuntimeParams map[string]string // parameters that have been reported by the server
@@ -279,6 +279,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.cancelQueryCompleted = make(chan struct{}, 1) c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{}) c.doneChan = make(chan struct{})
c.closedChan = make(chan error) c.closedChan = make(chan error)
c.wbuf = make([]byte, 0, 1024)
if tlsConfig != nil { if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) { if c.shouldLog(LogLevelDebug) {
@@ -707,32 +708,42 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
} }
// parse // parse
wbuf := newWriteBuf(c, 'P') buf := c.wbuf
wbuf.WriteCString(name) buf = append(buf, 'P')
wbuf.WriteCString(sql) sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, name...)
buf = append(buf, 0)
buf = append(buf, sql...)
buf = append(buf, 0)
if opts != nil { if opts != nil {
if len(opts.ParameterOids) > 65535 { if len(opts.ParameterOids) > 65535 {
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))
} }
wbuf.WriteInt16(int16(len(opts.ParameterOids))) buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids)))
for _, oid := range opts.ParameterOids { for _, oid := range opts.ParameterOids {
wbuf.WriteInt32(int32(oid)) buf = pgio.AppendInt32(buf, int32(oid))
} }
} else { } else {
wbuf.WriteInt16(0) buf = pgio.AppendInt16(buf, 0)
} }
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// describe // describe
wbuf.startMsg('D') buf = append(buf, 'D')
wbuf.WriteByte('S') sp = len(buf)
wbuf.WriteCString(name) buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// sync // sync
wbuf.startMsg('S') buf = append(buf, 'S')
wbuf.closeMsg() buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf) _, err = c.conn.Write(buf)
if err != nil { if err != nil {
c.die(err) c.die(err)
return nil, err return nil, err
@@ -813,15 +824,20 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
delete(c.preparedStatements, name) delete(c.preparedStatements, name)
// close // close
wbuf := newWriteBuf(c, 'C') buf := c.wbuf
wbuf.WriteByte('S') buf = append(buf, 'C')
wbuf.WriteCString(name) sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// flush // flush
wbuf.startMsg('H') buf = append(buf, 'H')
wbuf.closeMsg() buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf) _, err = c.conn.Write(buf)
if err != nil { if err != nil {
c.die(err) c.die(err)
return err return err
@@ -943,11 +959,15 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
} }
if len(args) == 0 { if len(args) == 0 {
wbuf := newWriteBuf(c, 'Q') buf := c.wbuf
wbuf.WriteCString(sql) buf = append(buf, 'Q')
wbuf.closeMsg() sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, sql...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := c.conn.Write(wbuf.buf) _, err := c.conn.Write(buf)
if err != nil { if err != nil {
c.die(err) c.die(err)
return err return err
@@ -975,37 +995,45 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
} }
// bind // bind
wbuf := newWriteBuf(c, 'B') buf := c.wbuf
wbuf.WriteByte(0) buf = append(buf, 'B')
wbuf.WriteCString(ps.Name) sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 0)
buf = append(buf, ps.Name...)
buf = append(buf, 0)
wbuf.WriteInt16(int16(len(ps.ParameterOids))) buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids)))
for i, oid := range ps.ParameterOids { for i, oid := range ps.ParameterOids {
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
} }
wbuf.WriteInt16(int16(len(arguments))) buf = pgio.AppendInt16(buf, int16(len(arguments)))
for i, oid := range ps.ParameterOids { for i, oid := range ps.ParameterOids {
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil { var err error
buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i])
if err != nil {
return err return err
} }
} }
wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions)))
for _, fd := range ps.FieldDescriptions { for _, fd := range ps.FieldDescriptions {
wbuf.WriteInt16(fd.FormatCode) buf = pgio.AppendInt16(buf, fd.FormatCode)
} }
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// execute // execute
wbuf.startMsg('E') buf = append(buf, 'E')
wbuf.WriteByte(0) buf = pgio.AppendInt32(buf, 9)
wbuf.WriteInt32(0) buf = append(buf, 0)
buf = pgio.AppendInt32(buf, 0)
// sync // sync
wbuf.startMsg('S') buf = append(buf, 'S')
wbuf.closeMsg() buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(wbuf.buf) _, err = c.conn.Write(buf)
if err != nil { if err != nil {
c.die(err) c.die(err)
} }
@@ -1180,11 +1208,15 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error {
} }
func (c *Conn) txPasswordMessage(password string) (err error) { func (c *Conn) txPasswordMessage(password string) (err error) {
wbuf := newWriteBuf(c, 'p') buf := c.wbuf
wbuf.WriteCString(password) buf = append(buf, 'p')
wbuf.closeMsg() sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, password...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = c.conn.Write(wbuf.buf) _, err = c.conn.Write(buf)
return err return err
} }
+33 -24
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
) )
@@ -89,14 +90,14 @@ func (ct *copyFrom) waitForReaderDone() error {
func (ct *copyFrom) run() (int, error) { func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize() quotedTableName := ct.tableName.Sanitize()
buf := &bytes.Buffer{} cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames { for i, cn := range ct.columnNames {
if i != 0 { if i != 0 {
buf.WriteString(", ") cbuf.WriteString(", ")
} }
buf.WriteString(quoteIdentifier(cn)) cbuf.WriteString(quoteIdentifier(cn))
} }
quotedColumnNames := buf.String() quotedColumnNames := cbuf.String()
ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
if err != nil { if err != nil {
@@ -116,11 +117,14 @@ func (ct *copyFrom) run() (int, error) {
go ct.readUntilReadyForQuery() go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone() defer ct.waitForReaderDone()
wbuf := newWriteBuf(ct.conn, copyData) buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) buf = append(buf, "PGCOPY\n\377\r\n\000"...)
wbuf.WriteInt32(0) buf = pgio.AppendInt32(buf, 0)
wbuf.WriteInt32(0) buf = pgio.AppendInt32(buf, 0)
var sentCount int var sentCount int
@@ -131,18 +135,16 @@ func (ct *copyFrom) run() (int, error) {
default: default:
} }
if len(wbuf.buf) > 65536 { if len(buf) > 65536 {
wbuf.closeMsg() pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.conn.Write(wbuf.buf) _, err = ct.conn.conn.Write(buf)
if err != nil { if err != nil {
ct.conn.die(err) ct.conn.die(err)
return 0, err return 0, err
} }
// Directly manipulate wbuf to reset to reuse the same buffer // Directly manipulate wbuf to reset to reuse the same buffer
wbuf.buf = wbuf.buf[0:5] buf = buf[0:5]
wbuf.buf[0] = copyData
wbuf.sizeIdx = 1
} }
sentCount++ sentCount++
@@ -157,9 +159,9 @@ func (ct *copyFrom) run() (int, error) {
return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
} }
wbuf.WriteInt16(int16(len(ct.columnNames))) buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values { for i, val := range values {
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val) buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil { if err != nil {
ct.cancelCopyIn() ct.cancelCopyIn()
return 0, err return 0, err
@@ -173,11 +175,13 @@ func (ct *copyFrom) run() (int, error) {
return 0, ct.rowSrc.Err() return 0, ct.rowSrc.Err()
} }
wbuf.WriteInt16(-1) // terminate the copy stream buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
wbuf.startMsg(copyDone) buf = append(buf, copyDone)
wbuf.closeMsg() buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.conn.Write(wbuf.buf)
_, err = ct.conn.conn.Write(buf)
if err != nil { if err != nil {
ct.conn.die(err) ct.conn.die(err)
return 0, err return 0, err
@@ -210,10 +214,15 @@ func (c *Conn) readUntilCopyInResponse() error {
} }
func (ct *copyFrom) cancelCopyIn() error { func (ct *copyFrom) cancelCopyIn() error {
wbuf := newWriteBuf(ct.conn, copyFail) buf := ct.conn.wbuf
wbuf.WriteCString("client error: abort") buf = append(buf, copyFail)
wbuf.closeMsg() sp := len(buf)
_, err := ct.conn.conn.Write(wbuf.buf) buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.conn.Write(buf)
if err != nil { if err != nil {
ct.conn.die(err) ct.conn.die(err)
return err return err
+17 -12
View File
@@ -3,6 +3,7 @@ package pgx
import ( import (
"encoding/binary" "encoding/binary"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@@ -55,19 +56,23 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) {
return nil, err return nil, err
} }
wbuf := newWriteBuf(f.cn, 'F') // function call buf := f.cn.wbuf
wbuf.WriteInt32(int32(oid)) // function object id buf = append(buf, 'F') // function call
wbuf.WriteInt16(1) // # of argument format codes sp := len(buf)
wbuf.WriteInt16(1) // format code: binary buf = pgio.AppendInt32(buf, -1)
wbuf.WriteInt16(int16(len(args))) // # of arguments
for _, arg := range args {
wbuf.WriteInt32(int32(len(arg))) // length of argument
wbuf.WriteBytes(arg) // argument value
}
wbuf.WriteInt16(1) // response format code (binary)
wbuf.closeMsg()
if _, err := f.cn.conn.Write(wbuf.buf); err != nil { buf = pgio.AppendInt32(buf, int32(oid)) // function object id
buf = pgio.AppendInt16(buf, 1) // # of argument format codes
buf = pgio.AppendInt16(buf, 1) // format code: binary
buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments
for _, arg := range args {
buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument
buf = append(buf, arg...) // argument value
}
buf = pgio.AppendInt16(buf, 1) // response format code (binary)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
if _, err := f.cn.conn.Write(buf); err != nil {
return nil, err return nil, err
} }
-87
View File
@@ -92,90 +92,3 @@ type PgError struct {
func (pe PgError) Error() string { func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
} }
func newWriteBuf(c *Conn, t byte) *WriteBuf {
buf := append(c.wbuf[0:0], t, 0, 0, 0, 0)
c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c}
return &c.writeBuf
}
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
// by the Encoder interface when implementing custom encoders.
type WriteBuf struct {
buf []byte
convBuf [8]byte
sizeIdx int
conn *Conn
}
func (wb *WriteBuf) startMsg(t byte) {
wb.closeMsg()
wb.buf = append(wb.buf, t, 0, 0, 0, 0)
wb.sizeIdx = len(wb.buf) - 4
}
func (wb *WriteBuf) closeMsg() {
binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx))
}
func (wb *WriteBuf) reserveSize() int {
sizePosition := len(wb.buf)
wb.buf = append(wb.buf, 0, 0, 0, 0)
return sizePosition
}
func (wb *WriteBuf) setComputedSize(sizePosition int) {
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4))
}
func (wb *WriteBuf) setSize(sizePosition int, size int32) {
binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size))
}
func (wb *WriteBuf) WriteByte(b byte) {
wb.buf = append(wb.buf, b)
}
func (wb *WriteBuf) WriteCString(s string) {
wb.buf = append(wb.buf, []byte(s)...)
wb.buf = append(wb.buf, 0)
}
func (wb *WriteBuf) WriteInt16(n int16) {
wb.WriteUint16(uint16(n))
}
func (wb *WriteBuf) WriteUint16(n uint16) (int, error) {
binary.BigEndian.PutUint16(wb.convBuf[:2], n)
wb.buf = append(wb.buf, wb.convBuf[:2]...)
return 2, nil
}
func (wb *WriteBuf) WriteInt32(n int32) {
wb.WriteUint32(uint32(n))
}
func (wb *WriteBuf) WriteUint32(n uint32) (int, error) {
binary.BigEndian.PutUint32(wb.convBuf[:4], n)
wb.buf = append(wb.buf, wb.convBuf[:4]...)
return 4, nil
}
func (wb *WriteBuf) WriteInt64(n int64) {
wb.WriteUint64(uint64(n))
}
func (wb *WriteBuf) WriteUint64(n uint64) (int, error) {
binary.BigEndian.PutUint64(wb.convBuf[:8], n)
wb.buf = append(wb.buf, wb.convBuf[:8]...)
return 8, nil
}
func (wb *WriteBuf) WriteBytes(b []byte) {
wb.buf = append(wb.buf, b...)
}
func (wb *WriteBuf) Write(b []byte) (int, error) {
wb.buf = append(wb.buf, b...)
return len(b), nil
}
+14 -9
View File
@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
) )
@@ -175,17 +176,21 @@ type ReplicationConn struct {
// message to the server, as well as carries the WAL position of the // message to the server, as well as carries the WAL position of the
// client, which then updates the server's replication slot position. // client, which then updates the server's replication slot position.
func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) {
writeBuf := newWriteBuf(rc.c, copyData) buf := rc.c.wbuf
writeBuf.WriteByte(standbyStatusUpdate) buf = append(buf, copyData)
writeBuf.WriteInt64(int64(k.WalWritePosition)) sp := len(buf)
writeBuf.WriteInt64(int64(k.WalFlushPosition)) buf = pgio.AppendInt32(buf, -1)
writeBuf.WriteInt64(int64(k.WalApplyPosition))
writeBuf.WriteInt64(int64(k.ClientTime))
writeBuf.WriteByte(k.ReplyRequested)
writeBuf.closeMsg() buf = append(buf, standbyStatusUpdate)
buf = pgio.AppendInt64(buf, int64(k.WalWritePosition))
buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition))
buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition))
buf = pgio.AppendInt64(buf, int64(k.ClientTime))
buf = append(buf, k.ReplyRequested)
_, err = rc.c.conn.Write(writeBuf.buf) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = rc.c.conn.Write(buf)
if err != nil { if err != nil {
rc.c.die(err) rc.c.die(err)
} }
+34 -36
View File
@@ -97,84 +97,82 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
} }
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) {
if arg == nil { if arg == nil {
wbuf.WriteInt32(-1) return pgio.AppendInt32(buf, -1), nil
return nil
} }
switch arg := arg.(type) { switch arg := arg.(type) {
case pgtype.BinaryEncoder: case pgtype.BinaryEncoder:
sp := len(wbuf.buf) sp := len(buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) buf = pgio.AppendInt32(buf, -1)
argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) argBuf, err := arg.EncodeBinary(ci, buf)
if err != nil { if err != nil {
return err return nil, err
} }
if argBuf != nil { if argBuf != nil {
wbuf.buf = argBuf buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
} }
return nil return buf, nil
case pgtype.TextEncoder: case pgtype.TextEncoder:
sp := len(wbuf.buf) sp := len(buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) buf = pgio.AppendInt32(buf, -1)
argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf) argBuf, err := arg.EncodeText(ci, buf)
if err != nil { if err != nil {
return err return nil, err
} }
if argBuf != nil { if argBuf != nil {
wbuf.buf = argBuf buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
} }
return nil return buf, nil
case driver.Valuer: case driver.Valuer:
v, err := arg.Value() v, err := arg.Value()
if err != nil { if err != nil {
return err return nil, err
} }
return encodePreparedStatementArgument(wbuf, oid, v) return encodePreparedStatementArgument(ci, buf, oid, v)
case string: case string:
wbuf.WriteInt32(int32(len(arg))) buf = pgio.AppendInt32(buf, int32(len(arg)))
wbuf.WriteBytes([]byte(arg)) buf = append(buf, arg...)
return nil return buf, nil
} }
refVal := reflect.ValueOf(arg) refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr { if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() { if refVal.IsNil() {
wbuf.WriteInt32(-1) return pgio.AppendInt32(buf, -1), nil
return nil
} }
arg = refVal.Elem().Interface() arg = refVal.Elem().Interface()
return encodePreparedStatementArgument(wbuf, oid, arg) return encodePreparedStatementArgument(ci, buf, oid, arg)
} }
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { if dt, ok := ci.DataTypeForOid(oid); ok {
value := dt.Value value := dt.Value
err := value.Set(arg) err := value.Set(arg)
if err != nil { if err != nil {
return err return nil, err
} }
sp := len(wbuf.buf) sp := len(buf)
wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) buf = pgio.AppendInt32(buf, -1)
argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
if err != nil { if err != nil {
return err return nil, err
} }
if argBuf != nil { if argBuf != nil {
wbuf.buf = argBuf buf = argBuf
pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
} }
return nil return buf, nil
} }
if strippedArg, ok := stripNamedType(&refVal); ok { if strippedArg, ok := stripNamedType(&refVal); ok {
return encodePreparedStatementArgument(wbuf, oid, strippedArg) return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
} }
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
} }
// chooseParameterFormatCode determines the correct format code for an // chooseParameterFormatCode determines the correct format code for an