2
0

Fix deadlock when CopyFromSource panics

fixes #433
This commit is contained in:
Jack Christensen
2018-07-14 11:26:09 -05:00
parent 3cbe92ebb5
commit 20c02acd63
2 changed files with 101 additions and 27 deletions
+59 -27
View File
@@ -115,8 +115,15 @@ func (ct *copyFrom) run() (int, error) {
return 0, err return 0, err
} }
panicked := true
go ct.readUntilReadyForQuery() go ct.readUntilReadyForQuery()
defer ct.waitForReaderDone() defer ct.waitForReaderDone()
defer func() {
if panicked {
ct.conn.die(errors.New("panic while in copy from"))
}
}()
buf := ct.conn.wbuf buf := ct.conn.wbuf
buf = append(buf, copyData) buf = append(buf, copyData)
@@ -129,49 +136,40 @@ func (ct *copyFrom) run() (int, error) {
var sentCount int var sentCount int
for ct.rowSrc.Next() { moreRows := true
for moreRows {
select { select {
case err = <-ct.readerErrChan: case err = <-ct.readerErrChan:
panicked = false
return 0, err return 0, err
default: default:
} }
if len(buf) > 65536 { var addedRows int
pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) var err error
_, err = ct.conn.conn.Write(buf) moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
if err != nil {
ct.conn.die(err)
return 0, err
}
// Directly manipulate wbuf to reset to reuse the same buffer
buf = buf[0:5]
}
sentCount++
values, err := ct.rowSrc.Values()
if err != nil { if err != nil {
panicked = false
ct.cancelCopyIn() ct.cancelCopyIn()
return 0, err return 0, err
} }
if len(values) != len(ct.columnNames) { sentCount += addedRows
ct.cancelCopyIn() pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
_, err = ct.conn.conn.Write(buf)
if err != nil {
panicked = false
ct.conn.die(err)
return 0, err
} }
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) // Directly manipulate wbuf to reset to reuse the same buffer
for i, val := range values { buf = buf[0:5]
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err
}
}
} }
if ct.rowSrc.Err() != nil { if ct.rowSrc.Err() != nil {
panicked = false
ct.cancelCopyIn() ct.cancelCopyIn()
return 0, ct.rowSrc.Err() return 0, ct.rowSrc.Err()
} }
@@ -184,17 +182,51 @@ func (ct *copyFrom) run() (int, error) {
_, err = ct.conn.conn.Write(buf) _, err = ct.conn.conn.Write(buf)
if err != nil { if err != nil {
panicked = false
ct.conn.die(err) ct.conn.die(err)
return 0, err return 0, err
} }
err = ct.waitForReaderDone() err = ct.waitForReaderDone()
if err != nil { if err != nil {
panicked = false
return 0, err return 0, err
} }
panicked = false
return sentCount, nil return sentCount, nil
} }
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
var rowCount int
for ct.rowSrc.Next() {
values, err := ct.rowSrc.Values()
if err != nil {
return false, nil, 0, err
}
if len(values) != len(ct.columnNames) {
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
}
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
return false, nil, 0, err
}
}
rowCount++
if len(buf) > 65536 {
return true, buf, rowCount, nil
}
}
return false, buf, rowCount, nil
}
func (c *Conn) readUntilCopyInResponse() error { func (c *Conn) readUntilCopyInResponse() error {
for { for {
msg, err := c.rxMsg() msg, err := c.rxMsg()
+42
View File
@@ -426,3 +426,45 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
type nextPanicSource struct {
}
func (cfs *nextPanicSource) Next() bool {
panic("crash")
}
func (cfs *nextPanicSource) Values() ([]interface{}, error) {
return []interface{}{nil}, nil // should never get here
}
func (cfs *nextPanicSource) Err() error {
return nil // should never gets here
}
func TestConnCopyFromCopyFromSourceNextPanic(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a bytea not null
)`)
caughtPanic := false
func() {
defer func() {
if x := recover(); x != nil {
caughtPanic = true
}
}()
conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &nextPanicSource{})
}()
if conn.IsAlive() {
t.Error("panic should have killed conn")
}
}