From f5aecdd4992504d8344ea0730800e38d48b32f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:33:51 -0600 Subject: [PATCH] Extract writeAll --- pgconn.go | 87 ++++++++++++++----------------------------------------- 1 file changed, 21 insertions(+), 66 deletions(-) diff --git a/pgconn.go b/pgconn.go index e34853a0..461ff1c0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,6 +398,15 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } +// writeAll writes the entire buffer successfully or it hard closes the connection. +func (pgConn *PgConn) writeAll(buf []byte) error { + n, err := pgConn.conn.Write(buf) + if err != nil && n > 0 { + pgConn.hardClose() + } + return err +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -482,15 +491,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -654,15 +656,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -718,15 +713,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -770,15 +758,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -801,15 +782,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -869,15 +843,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -913,25 +880,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) + buf = make([]byte, 0, 20000) + // buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error signalMessageChan := pgConn.signalMessage() for readErr == nil && pgErr == nil { + var n int n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - n, err = pgConn.conn.Write(buf) + err = pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } cleanupContextDeadline() if err, ok := err.(net.Error); ok && err.Timeout() { go pgConn.recoverFromTimeoutDuringCopyFrom() @@ -975,8 +938,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.conn.Close() - pgConn.closed = true + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -1414,15 +1376,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - n, err := pgConn.conn.Write(batch.buf) + err := pgConn.writeAll(batch.buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err)