diff --git a/pgconn.go b/pgconn.go index 44a08cc8..271e6628 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1084,26 +1084,44 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - var readErr error + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error) signalMessageChan := pgConn.signalMessage() - for readErr == nil && pgErr == nil { - var n int - n, readErr = r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - return nil, err + go func() { + buf := make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + + for { + n, readErr := r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, writeErr := pgConn.conn.Write(buf) + if writeErr != nil { + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: } } + }() + var copyErr error + for copyErr == nil && pgErr == nil { select { + case copyErr = <-copyErrChan: case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { @@ -1120,13 +1138,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: } } + close(abortCopyChan) buf = buf[:0] - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Message: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index c37a2fb2..19ad3a0a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1463,6 +1463,46 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + func TestConnEscapeString(t *testing.T) { t.Parallel()