diff --git a/pgconn.go b/pgconn.go index e8baffa2..d8ec6b07 100644 --- a/pgconn.go +++ b/pgconn.go @@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) - for { - n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF && n == 0 { - break - } - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) + var readErr error + for readErr == nil { + n, readErr = r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true - cleanupContextDeadline() - <-pgConn.controller + cleanupContextDeadline() + <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) + return "", preferContextOverNetTimeoutError(ctx, err) + } } } - // Send copy done buf = buf[:0] - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) - + if readErr == io.EOF { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + buf = copyFail.Encode(buf) + } _, err = pgConn.conn.Write(buf) if err != nil { pgConn.conn.Close()