From 68d6d1c77950d01d5b693a997c048c15386c7032 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 10:21:16 -0600 Subject: [PATCH] Properly abort CopyFrom on reader error --- pgconn/pgconn.go | 45 ++++++++++++++++++++------------------- pgproto3/copy_fail.go | 49 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 21 deletions(-) create mode 100644 pgproto3/copy_fail.go diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index e8baffa2..d8ec6b07 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) - for { - n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF && n == 0 { - break - } - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) + var readErr error + for readErr == nil { + n, readErr = r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true - cleanupContextDeadline() - <-pgConn.controller + cleanupContextDeadline() + <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) + return "", preferContextOverNetTimeoutError(ctx, err) + } } } - // Send copy done buf = buf[:0] - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) - + if readErr == io.EOF { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + buf = copyFail.Encode(buf) + } _, err = pgConn.conn.Write(buf) if err != nil { pgConn.conn.Close() diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go new file mode 100644 index 00000000..432a311b --- /dev/null +++ b/pgproto3/copy_fail.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyFail struct { + Error string +} + +func (*CopyFail) Frontend() {} +func (*CopyFail) Backend() {} + +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Error = string(src[:idx]) + + return nil +} + +func (src *CopyFail) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Error...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Error string + }{ + Type: "CopyFail", + Error: src.Error, + }) +}