Fix CopyFrom deadlock when multiple NoticeResponse received during copy
fixes #21
This commit is contained in:
@@ -1084,26 +1084,44 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send copy data
|
// Send copy data
|
||||||
buf = make([]byte, 0, 65536)
|
abortCopyChan := make(chan struct{})
|
||||||
buf = append(buf, 'd')
|
copyErrChan := make(chan error)
|
||||||
sp := len(buf)
|
|
||||||
var readErr error
|
|
||||||
signalMessageChan := pgConn.signalMessage()
|
signalMessageChan := pgConn.signalMessage()
|
||||||
for readErr == nil && pgErr == nil {
|
|
||||||
var n int
|
|
||||||
n, readErr = r.Read(buf[5:cap(buf)])
|
|
||||||
if n > 0 {
|
|
||||||
buf = buf[0 : n+5]
|
|
||||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
|
||||||
|
|
||||||
_, err = pgConn.conn.Write(buf)
|
go func() {
|
||||||
if err != nil {
|
buf := make([]byte, 0, 65536)
|
||||||
pgConn.asyncClose()
|
buf = append(buf, 'd')
|
||||||
return nil, err
|
sp := len(buf)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, readErr := r.Read(buf[5:cap(buf)])
|
||||||
|
if n > 0 {
|
||||||
|
buf = buf[0 : n+5]
|
||||||
|
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||||
|
|
||||||
|
_, writeErr := pgConn.conn.Write(buf)
|
||||||
|
if writeErr != nil {
|
||||||
|
copyErrChan <- writeErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
copyErrChan <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-abortCopyChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var copyErr error
|
||||||
|
for copyErr == nil && pgErr == nil {
|
||||||
select {
|
select {
|
||||||
|
case copyErr = <-copyErrChan:
|
||||||
case <-signalMessageChan:
|
case <-signalMessageChan:
|
||||||
msg, err := pgConn.receiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1120,13 +1138,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
close(abortCopyChan)
|
||||||
|
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
if readErr == io.EOF || pgErr != nil {
|
if copyErr == io.EOF || pgErr != nil {
|
||||||
copyDone := &pgproto3.CopyDone{}
|
copyDone := &pgproto3.CopyDone{}
|
||||||
buf = copyDone.Encode(buf)
|
buf = copyDone.Encode(buf)
|
||||||
} else {
|
} else {
|
||||||
copyFail := &pgproto3.CopyFail{Message: readErr.Error()}
|
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
|
||||||
buf = copyFail.Encode(buf)
|
buf = copyFail.Encode(buf)
|
||||||
}
|
}
|
||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
|
|||||||
@@ -1463,6 +1463,46 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgconn/issues/21
|
||||||
|
func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, `create temporary table sentences(
|
||||||
|
t text,
|
||||||
|
ts tsvector
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
|
||||||
|
begin
|
||||||
|
new.ts := to_tsvector(new.t);
|
||||||
|
return new;
|
||||||
|
end
|
||||||
|
$$ language plpgsql;`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
longString := make([]byte, 10001)
|
||||||
|
for i := range longString {
|
||||||
|
longString[i] = 'x'
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnEscapeString(t *testing.T) {
|
func TestConnEscapeString(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user