2
0

Partial conversion of pgx to use pgconn

This commit is contained in:
Jack Christensen
2019-01-26 16:46:30 -06:00
parent e3d431d0df
commit d3a2c1c107
17 changed files with 877 additions and 1830 deletions
+41 -218
View File
@@ -2,12 +2,11 @@ package pgx
import (
"bytes"
"context"
"fmt"
"io"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors"
)
@@ -58,39 +57,6 @@ type copyFrom struct {
readerErrChan chan error
}
func (ct *copyFrom) readUntilReadyForQuery() {
for {
msg, err := ct.conn.rxMsg()
if err != nil {
ct.readerErrChan <- err
close(ct.readerErrChan)
return
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
ct.conn.rxReadyForQuery(msg)
close(ct.readerErrChan)
return
case *pgproto3.CommandComplete:
case *pgproto3.ErrorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
default:
err = ct.conn.processContextFreeMsg(msg)
if err != nil {
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
}
}
}
}
func (ct *copyFrom) waitForReaderDone() error {
var err error
for err = range ct.readerErrChan {
}
return err
}
func (ct *copyFrom) run() (int, error) {
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
@@ -107,163 +73,74 @@ func (ct *copyFrom) run() (int, error) {
return 0, err
}
err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
if err != nil {
return 0, err
}
r, w := io.Pipe()
err = ct.conn.readUntilCopyInResponse()
if err != nil {
return 0, err
}
go func() {
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf
panicked := true
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone()
defer func() {
if panicked {
ct.conn.die(errors.New("panic while in copy from"))
moreRows := true
for moreRows {
var err error
moreRows, buf, err = ct.buildCopyBuf(buf, ps)
if err != nil {
w.CloseWithError(err)
return
}
if ct.rowSrc.Err() != nil {
w.CloseWithError(ct.rowSrc.Err())
return
}
if len(buf) > 0 {
_, err = w.Write(buf)
if err != nil {
w.Close()
return
}
}
buf = buf[:0]
}
w.Close()
}()
buf := ct.conn.wbuf
buf = append(buf, copyData)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
buf = pgio.AppendInt32(buf, 0)
buf = pgio.AppendInt32(buf, 0)
var sentCount int
moreRows := true
for moreRows {
select {
case err = <-ct.readerErrChan:
panicked = false
return 0, err
default:
}
var addedRows int
var err error
moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
if err != nil {
panicked = false
ct.cancelCopyIn()
return 0, err
}
sentCount += addedRows
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err = ct.conn.pgConn.Conn().Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
buf = buf[0:5]
}
if ct.rowSrc.Err() != nil {
panicked = false
ct.cancelCopyIn()
return 0, ct.rowSrc.Err()
}
buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
_, err = ct.conn.pgConn.Conn().Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
}
err = ct.waitForReaderDone()
if err != nil {
panicked = false
return 0, err
}
panicked = false
return sentCount, nil
return int(commandTag.RowsAffected()), err
}
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
var rowCount int
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
for ct.rowSrc.Next() {
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, 0, err
return false, nil, err
}
if len(values) != len(ct.columnNames) {
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
return false, nil, 0, err
return false, nil, err
}
}
rowCount++
if len(buf) > 65536 {
return true, buf, rowCount, nil
return true, buf, nil
}
}
return false, buf, rowCount, nil
}
func (c *Conn) readUntilCopyInResponse() error {
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
func (ct *copyFrom) cancelCopyIn() error {
buf := ct.conn.wbuf
buf = append(buf, copyFail)
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, "client error: abort"...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
_, err := ct.conn.pgConn.Conn().Write(buf)
if err != nil {
ct.conn.die(err)
return err
}
return nil
return false, buf, nil
}
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
@@ -283,57 +160,3 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
return ct.run()
}
// CopyFromReader uses the PostgreSQL textual format of the copy protocol
func (c *Conn) CopyFromReader(r io.Reader, sql string) (pgconn.CommandTag, error) {
if err := c.sendSimpleQuery(sql); err != nil {
return "", err
}
if err := c.readUntilCopyInResponse(); err != nil {
return "", err
}
buf := c.wbuf
buf = append(buf, copyData)
sp := len(buf)
for {
n, err := r.Read(buf[5:cap(buf)])
if err == io.EOF && n == 0 {
break
}
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err
}
}
buf = buf[:0]
buf = append(buf, copyDone)
buf = pgio.AppendInt32(buf, 4)
if _, err := c.pgConn.Conn().Write(buf); err != nil {
return "", err
}
for {
msg, err := c.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
return "", err
case *pgproto3.CommandComplete:
return pgconn.CommandTag(msg.CommandTag), nil
case *pgproto3.ErrorResponse:
return "", c.rxErrorResponse(msg)
default:
return "", c.processContextFreeMsg(msg)
}
}
}