Partial conversion of pgx to use pgconn
This commit is contained in:
+41
-218
@@ -2,12 +2,11 @@ package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -58,39 +57,6 @@ type copyFrom struct {
|
||||
readerErrChan chan error
|
||||
}
|
||||
|
||||
func (ct *copyFrom) readUntilReadyForQuery() {
|
||||
for {
|
||||
msg, err := ct.conn.rxMsg()
|
||||
if err != nil {
|
||||
ct.readerErrChan <- err
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
ct.conn.rxReadyForQuery(msg)
|
||||
close(ct.readerErrChan)
|
||||
return
|
||||
case *pgproto3.CommandComplete:
|
||||
case *pgproto3.ErrorResponse:
|
||||
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = ct.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) waitForReaderDone() error {
|
||||
var err error
|
||||
for err = range ct.readerErrChan {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run() (int, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
@@ -107,163 +73,74 @@ func (ct *copyFrom) run() (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r, w := io.Pipe()
|
||||
|
||||
err = ct.conn.readUntilCopyInResponse()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
go func() {
|
||||
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
||||
buf := ct.conn.wbuf
|
||||
|
||||
panicked := true
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
go ct.readUntilReadyForQuery()
|
||||
defer ct.waitForReaderDone()
|
||||
defer func() {
|
||||
if panicked {
|
||||
ct.conn.die(errors.New("panic while in copy from"))
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
var err error
|
||||
moreRows, buf, err = ct.buildCopyBuf(buf, ps)
|
||||
if err != nil {
|
||||
w.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
w.CloseWithError(ct.rowSrc.Err())
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
_, err = w.Write(buf)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
}
|
||||
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
var sentCount int
|
||||
|
||||
moreRows := true
|
||||
for moreRows {
|
||||
select {
|
||||
case err = <-ct.readerErrChan:
|
||||
panicked = false
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
|
||||
var addedRows int
|
||||
var err error
|
||||
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
}
|
||||
sentCount += addedRows
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err = ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Directly manipulate wbuf to reset to reuse the same buffer
|
||||
buf = buf[0:5]
|
||||
|
||||
}
|
||||
|
||||
if ct.rowSrc.Err() != nil {
|
||||
panicked = false
|
||||
ct.cancelCopyIn()
|
||||
return 0, ct.rowSrc.Err()
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
_, err = ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
panicked = false
|
||||
ct.conn.die(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = ct.waitForReaderDone()
|
||||
if err != nil {
|
||||
panicked = false
|
||||
return 0, err
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return sentCount, nil
|
||||
return int(commandTag.RowsAffected()), err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
|
||||
var rowCount int
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
||||
|
||||
for ct.rowSrc.Next() {
|
||||
values, err := ct.rowSrc.Values()
|
||||
if err != nil {
|
||||
return false, nil, 0, err
|
||||
return false, nil, err
|
||||
}
|
||||
if len(values) != len(ct.columnNames) {
|
||||
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
||||
}
|
||||
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
return false, nil, 0, err
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rowCount++
|
||||
|
||||
if len(buf) > 65536 {
|
||||
return true, buf, rowCount, nil
|
||||
return true, buf, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, buf, rowCount, nil
|
||||
}
|
||||
|
||||
func (c *Conn) readUntilCopyInResponse() error {
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
return nil
|
||||
default:
|
||||
err = c.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *copyFrom) cancelCopyIn() error {
|
||||
buf := ct.conn.wbuf
|
||||
buf = append(buf, copyFail)
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
buf = append(buf, "client error: abort"...)
|
||||
buf = append(buf, 0)
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
||||
|
||||
_, err := ct.conn.pgConn.Conn().Write(buf)
|
||||
if err != nil {
|
||||
ct.conn.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, buf, nil
|
||||
}
|
||||
|
||||
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
||||
@@ -283,57 +160,3 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
|
||||
|
||||
return ct.run()
|
||||
}
|
||||
|
||||
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
|
||||
func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
|
||||
if err := c.sendSimpleQuery(sql); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.readUntilCopyInResponse(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf := c.wbuf
|
||||
|
||||
buf = append(buf, copyData)
|
||||
sp := len(buf)
|
||||
for {
|
||||
n, err := r.Read(buf[5:cap(buf)])
|
||||
if err == io.EOF && n == 0 {
|
||||
break
|
||||
}
|
||||
buf = buf[0 : n+5]
|
||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
buf = buf[:0]
|
||||
buf = append(buf, copyDone)
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
|
||||
if _, err := c.pgConn.Conn().Write(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
return "", err
|
||||
case *pgproto3.CommandComplete:
|
||||
return pgconn.CommandTag(msg.CommandTag), nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return "", c.rxErrorResponse(msg)
|
||||
default:
|
||||
return "", c.processContextFreeMsg(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user