diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 49062f23..e8baffa2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -812,6 +813,134 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } } +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + + // Send copy to command + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read until copy in response or error. + var commandTag CommandTag + var pgErr error + pendingCopyInResponse := true + for pendingCopyInResponse { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: + pendingCopyInResponse = false + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + } + } + + // Send copy data + buf = make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + for { + n, err := r.Read(buf[5:cap(buf)]) + if err == io.EOF && n == 0 { + break + } + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + } + + // Send copy done + buf = buf[:0] + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read results + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 587acc57..47b3b3fb 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2,12 +2,15 @@ package pgconn_test import ( "bytes" + "compress/gzip" "context" "crypto/tls" "fmt" + "io/ioutil" "log" "net" "os" + "strconv" "testing" "time" @@ -791,6 +794,139 @@ func TestConnCopyToCanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + func TestConnEscapeString(t *testing.T) { t.Parallel()