2
0

Fix CopyFrom deadlock when multiple NoticeResponse received during copy

fixes #21
This commit is contained in:
Jack Christensen
2020-01-25 20:32:42 -06:00
parent 6124b07bb1
commit 139342081e
2 changed files with 75 additions and 16 deletions
+35 -16
View File
@@ -1084,26 +1084,44 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
// Send copy data
buf = make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
var readErr error
abortCopyChan := make(chan struct{})
copyErrChan := make(chan error)
signalMessageChan := pgConn.signalMessage()
for readErr == nil && pgErr == nil {
var n int
n, readErr = r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
return nil, err
go func() {
buf := make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
for {
n, readErr := r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
_, writeErr := pgConn.conn.Write(buf)
if writeErr != nil {
copyErrChan <- writeErr
return
}
}
if readErr != nil {
copyErrChan <- readErr
return
}
select {
case <-abortCopyChan:
return
default:
}
}
}()
var copyErr error
for copyErr == nil && pgErr == nil {
select {
case copyErr = <-copyErrChan:
case <-signalMessageChan:
msg, err := pgConn.receiveMessage()
if err != nil {
@@ -1120,13 +1138,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
default:
}
}
close(abortCopyChan)
buf = buf[:0]
if readErr == io.EOF || pgErr != nil {
if copyErr == io.EOF || pgErr != nil {
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
} else {
copyFail := &pgproto3.CopyFail{Message: readErr.Error()}
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
buf = copyFail.Encode(buf)
}
_, err = pgConn.conn.Write(buf)
+40
View File
@@ -1463,6 +1463,46 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) {
ensureConnValid(t, pgConn)
}
// https://github.com/jackc/pgconn/issues/21
func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
t.Parallel()
ctx := context.Background()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
_, err = pgConn.Exec(ctx, `create temporary table sentences(
t text,
ts tsvector
)`).ReadAll()
require.NoError(t, err)
_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
begin
new.ts := to_tsvector(new.t);
return new;
end
$$ language plpgsql;`).ReadAll()
require.NoError(t, err)
_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
require.NoError(t, err)
longString := make([]byte, 10001)
for i := range longString {
longString[i] = 'x'
}
buf := &bytes.Buffer{}
for i := 0; i < 1000; i++ {
buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
}
_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
require.NoError(t, err)
}
func TestConnEscapeString(t *testing.T) {
t.Parallel()