Move CopyFrom to pgconn
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgio"
|
||||||
"github.com/jackc/pgx/pgproto3"
|
"github.com/jackc/pgx/pgproto3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -812,6 +813,134 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server.
|
||||||
|
//
|
||||||
|
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
|
||||||
|
// could still block.
|
||||||
|
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case pgConn.controller <- pgConn:
|
||||||
|
}
|
||||||
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
|
// Send copy to command
|
||||||
|
var buf []byte
|
||||||
|
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
||||||
|
|
||||||
|
n, err := pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
// Partially sent messages are a fatal error for the connection.
|
||||||
|
if n > 0 {
|
||||||
|
// Close connection because cannot recover from partially sent message.
|
||||||
|
pgConn.conn.Close()
|
||||||
|
pgConn.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupContextDeadline()
|
||||||
|
<-pgConn.controller
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read until copy in response or error.
|
||||||
|
var commandTag CommandTag
|
||||||
|
var pgErr error
|
||||||
|
pendingCopyInResponse := true
|
||||||
|
for pendingCopyInResponse {
|
||||||
|
msg, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
cleanupContextDeadline()
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
go pgConn.recoverFromTimeout()
|
||||||
|
} else {
|
||||||
|
<-pgConn.controller
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.CopyInResponse:
|
||||||
|
pendingCopyInResponse = false
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
pgErr = errorResponseToPgError(msg)
|
||||||
|
case *pgproto3.ReadyForQuery:
|
||||||
|
<-pgConn.controller
|
||||||
|
return commandTag, pgErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send copy data
|
||||||
|
buf = make([]byte, 0, 65536)
|
||||||
|
buf = append(buf, 'd')
|
||||||
|
sp := len(buf)
|
||||||
|
for {
|
||||||
|
n, err := r.Read(buf[5:cap(buf)])
|
||||||
|
if err == io.EOF && n == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
buf = buf[0 : n+5]
|
||||||
|
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||||
|
|
||||||
|
_, err = pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
|
||||||
|
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
|
||||||
|
// close the connection.
|
||||||
|
pgConn.conn.Close()
|
||||||
|
pgConn.closed = true
|
||||||
|
|
||||||
|
cleanupContextDeadline()
|
||||||
|
<-pgConn.controller
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send copy done
|
||||||
|
buf = buf[:0]
|
||||||
|
copyDone := &pgproto3.CopyDone{}
|
||||||
|
buf = copyDone.Encode(buf)
|
||||||
|
|
||||||
|
_, err = pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.conn.Close()
|
||||||
|
pgConn.closed = true
|
||||||
|
|
||||||
|
cleanupContextDeadline()
|
||||||
|
<-pgConn.controller
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read results
|
||||||
|
for {
|
||||||
|
msg, err := pgConn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
cleanupContextDeadline()
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
go pgConn.recoverFromTimeout()
|
||||||
|
} else {
|
||||||
|
<-pgConn.controller
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.ReadyForQuery:
|
||||||
|
<-pgConn.controller
|
||||||
|
return commandTag, pgErr
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
commandTag = CommandTag(msg.CommandTag)
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
pgErr = errorResponseToPgError(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||||
type MultiResultReader struct {
|
type MultiResultReader struct {
|
||||||
pgConn *PgConn
|
pgConn *PgConn
|
||||||
|
|||||||
+136
@@ -2,12 +2,15 @@ package pgconn_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -791,6 +794,139 @@ func TestConnCopyToCanceled(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFrom(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcBuf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
inputRows := [][][]byte{}
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
a := strconv.Itoa(i)
|
||||||
|
b := "foo " + a + " bar"
|
||||||
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
||||||
|
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
|
||||||
|
assert.Equal(t, inputRows, result.Rows)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
f, err := ioutil.TempFile("", "*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
gw := gzip.NewWriter(f)
|
||||||
|
|
||||||
|
inputRows := [][][]byte{}
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
a := strconv.Itoa(i)
|
||||||
|
b := "foo " + a + " bar"
|
||||||
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
||||||
|
_, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = gw.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = f.Seek(0, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
gr, err := gzip.NewReader(f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
||||||
|
|
||||||
|
err = gr.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = f.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = os.Remove(f.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
|
||||||
|
require.NoError(t, result.Err)
|
||||||
|
|
||||||
|
assert.Equal(t, inputRows, result.Rows)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromQuerySyntaxError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcBuf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &pgconn.PgError{}, err)
|
||||||
|
assert.Equal(t, int64(0), res.RowsAffected())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFromQueryNoTableError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
srcBuf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.IsType(t, &pgconn.PgError{}, err)
|
||||||
|
assert.Equal(t, int64(0), res.RowsAffected())
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnEscapeString(t *testing.T) {
|
func TestConnEscapeString(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user