Add PgConn.CopyTo
This commit is contained in:
@@ -747,6 +747,71 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyTo executes the copy command sql and copies the results to w.
|
||||||
|
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case pgConn.controller <- pgConn:
|
||||||
|
}
|
||||||
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
|
// Send copy to command
|
||||||
|
var buf []byte
|
||||||
|
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
||||||
|
|
||||||
|
n, err := pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
// Partially sent messages are a fatal error for the connection.
|
||||||
|
if n > 0 {
|
||||||
|
// Close connection because cannot recover from partially sent message.
|
||||||
|
pgConn.conn.Close()
|
||||||
|
pgConn.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupContextDeadline()
|
||||||
|
<-pgConn.controller
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read results
|
||||||
|
var commandTag CommandTag
|
||||||
|
var pgErr error
|
||||||
|
for {
|
||||||
|
msg, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
cleanupContextDeadline()
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
go pgConn.recoverFromTimeout()
|
||||||
|
} else {
|
||||||
|
<-pgConn.controller
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.CopyDone:
|
||||||
|
case *pgproto3.CopyData:
|
||||||
|
_, err := w.Write(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
// This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup.
|
||||||
|
cleanupContextDeadline()
|
||||||
|
go pgConn.recoverFromTimeout()
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
case *pgproto3.ReadyForQuery:
|
||||||
|
<-pgConn.controller
|
||||||
|
return commandTag, pgErr
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
commandTag = CommandTag(msg.CommandTag)
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
pgErr = errorResponseToPgError(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||||
type MultiResultReader struct {
|
type MultiResultReader struct {
|
||||||
pgConn *PgConn
|
pgConn *PgConn
|
||||||
|
|||||||
+112
@@ -1,6 +1,7 @@
|
|||||||
package pgconn_test
|
package pgconn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -679,6 +680,117 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyToSmall(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int2,
|
||||||
|
b int4,
|
||||||
|
c int8,
|
||||||
|
d varchar,
|
||||||
|
e text,
|
||||||
|
f date,
|
||||||
|
g json
|
||||||
|
)`).ReadAll()
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
|
||||||
|
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
||||||
|
|
||||||
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||||
|
|
||||||
|
res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(2), res.RowsAffected())
|
||||||
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyToLarge(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int2,
|
||||||
|
b int4,
|
||||||
|
c int8,
|
||||||
|
d varchar,
|
||||||
|
e text,
|
||||||
|
f date,
|
||||||
|
g json,
|
||||||
|
h bytea
|
||||||
|
)`).ReadAll()
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
inputBytes := make([]byte, 0)
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
|
||||||
|
require.Nil(t, err)
|
||||||
|
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
||||||
|
|
||||||
|
res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(1000), res.RowsAffected())
|
||||||
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyToQueryError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
||||||
|
|
||||||
|
res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &pgconn.PgError{}, err)
|
||||||
|
assert.Equal(t, int64(0), res.RowsAffected())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyToCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
outputWriter := &bytes.Buffer{}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
||||||
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
assert.Equal(t, pgconn.CommandTag(""), res)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user