2
0

Properly abort CopyFrom on reader error

This commit is contained in:
Jack Christensen
2019-01-26 10:21:16 -06:00
parent 73003f86ee
commit 68d6d1c779
2 changed files with 73 additions and 21 deletions
+24 -21
View File
@@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
buf = make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
for {
n, err := r.Read(buf[5:cap(buf)])
if err == io.EOF && n == 0 {
break
}
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
var readErr error
for readErr == nil {
n, readErr = r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
_, err = pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
// close the connection.
pgConn.conn.Close()
pgConn.closed = true
_, err = pgConn.conn.Write(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
// close the connection.
pgConn.conn.Close()
pgConn.closed = true
cleanupContextDeadline()
<-pgConn.controller
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
return "", preferContextOverNetTimeoutError(ctx, err)
}
}
}
// Send copy done
buf = buf[:0]
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
if readErr == io.EOF {
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
} else {
copyFail := &pgproto3.CopyFail{Error: readErr.Error()}
buf = copyFail.Encode(buf)
}
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.conn.Close()
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/pgio"
)
type CopyFail struct {
Error string
}
func (*CopyFail) Frontend() {}
func (*CopyFail) Backend() {}
func (dst *CopyFail) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CopyFail"}
}
dst.Error = string(src[:idx])
return nil
}
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Error...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
func (src *CopyFail) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Error string
}{
Type: "CopyFail",
Error: src.Error,
})
}