Properly abort CopyFrom on reader error
This commit is contained in:
+24
-21
@@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
buf = make([]byte, 0, 65536)
|
buf = make([]byte, 0, 65536)
|
||||||
buf = append(buf, 'd')
|
buf = append(buf, 'd')
|
||||||
sp := len(buf)
|
sp := len(buf)
|
||||||
for {
|
var readErr error
|
||||||
n, err := r.Read(buf[5:cap(buf)])
|
for readErr == nil {
|
||||||
if err == io.EOF && n == 0 {
|
n, readErr = r.Read(buf[5:cap(buf)])
|
||||||
break
|
if n > 0 {
|
||||||
}
|
buf = buf[0 : n+5]
|
||||||
buf = buf[0 : n+5]
|
pgio.SetInt32(buf[sp:], int32(n+4))
|
||||||
pgio.SetInt32(buf[sp:], int32(n+4))
|
|
||||||
|
|
||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
|
// Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to
|
||||||
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
|
// recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to
|
||||||
// close the connection.
|
// close the connection.
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
pgConn.closed = true
|
pgConn.closed = true
|
||||||
|
|
||||||
cleanupContextDeadline()
|
cleanupContextDeadline()
|
||||||
<-pgConn.controller
|
<-pgConn.controller
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send copy done
|
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
copyDone := &pgproto3.CopyDone{}
|
if readErr == io.EOF {
|
||||||
buf = copyDone.Encode(buf)
|
copyDone := &pgproto3.CopyDone{}
|
||||||
|
buf = copyDone.Encode(buf)
|
||||||
|
} else {
|
||||||
|
copyFail := &pgproto3.CopyFail{Error: readErr.Error()}
|
||||||
|
buf = copyFail.Encode(buf)
|
||||||
|
}
|
||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyFail struct {
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*CopyFail) Frontend() {}
|
||||||
|
func (*CopyFail) Backend() {}
|
||||||
|
|
||||||
|
func (dst *CopyFail) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Error = string(src[:idx])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Error...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *CopyFail) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Error string
|
||||||
|
}{
|
||||||
|
Type: "CopyFail",
|
||||||
|
Error: src.Error,
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user