2
0

Move CopyFrom to pgconn

This commit is contained in:
Jack Christensen
2019-01-19 17:24:48 -06:00
parent fb15f44dfa
commit 73003f86ee
2 changed files with 265 additions and 0 deletions
+129
View File
@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
)
@@ -812,6 +813,134 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
}
}
// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server.
//
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
select {
case <-ctx.Done():
return "", ctx.Err()
case pgConn.controller <- pgConn:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
// Send copy to command
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
// Read until copy in response or error.
var commandTag CommandTag
var pgErr error
pendingCopyInResponse := true
for pendingCopyInResponse {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
} else {
<-pgConn.controller
}
return "", preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
pendingCopyInResponse = false
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
<-pgConn.controller
return commandTag, pgErr
}
}
// Send copy data
buf = make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
for {
n, err := r.Read(buf[5:cap(buf)])
if err == io.EOF && n == 0 {
break
}
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
_, err = pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
// close the connection.
pgConn.conn.Close()
pgConn.closed = true
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
}
// Send copy done
buf = buf[:0]
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.conn.Close()
pgConn.closed = true
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
// Read results
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
} else {
<-pgConn.controller
}
return "", preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
<-pgConn.controller
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
}
}
}
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn
+136
View File
@@ -2,12 +2,15 @@ package pgconn_test
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"strconv"
"testing"
"time"
@@ -791,6 +794,139 @@ func TestConnCopyToCanceled(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnCopyFrom(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
srcBuf := &bytes.Buffer{}
inputRows := [][][]byte{}
for i := 0; i < 1000; i++ {
a := strconv.Itoa(i)
b := "foo " + a + " bar"
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
require.NoError(t, err)
}
ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.NoError(t, err)
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
assert.Equal(t, inputRows, result.Rows)
ensureConnValid(t, pgConn)
}
func TestConnCopyFromGzipReader(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
f, err := ioutil.TempFile("", "*")
require.NoError(t, err)
gw := gzip.NewWriter(f)
inputRows := [][][]byte{}
for i := 0; i < 1000; i++ {
a := strconv.Itoa(i)
b := "foo " + a + " bar"
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
_, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
require.NoError(t, err)
}
err = gw.Close()
require.NoError(t, err)
_, err = f.Seek(0, 0)
require.NoError(t, err)
gr, err := gzip.NewReader(f)
require.NoError(t, err)
ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.NoError(t, err)
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
err = gr.Close()
require.NoError(t, err)
err = f.Close()
require.NoError(t, err)
err = os.Remove(f.Name())
require.NoError(t, err)
result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
assert.Equal(t, inputRows, result.Rows)
ensureConnValid(t, pgConn)
}
func TestConnCopyFromQuerySyntaxError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
srcBuf := &bytes.Buffer{}
res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
require.Error(t, err)
assert.IsType(t, &pgconn.PgError{}, err)
assert.Equal(t, int64(0), res.RowsAffected())
ensureConnValid(t, pgConn)
}
func TestConnCopyFromQueryNoTableError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
srcBuf := &bytes.Buffer{}
res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
require.Error(t, err)
assert.IsType(t, &pgconn.PgError{}, err)
assert.Equal(t, int64(0), res.RowsAffected())
ensureConnValid(t, pgConn)
}
func TestConnEscapeString(t *testing.T) {
t.Parallel()