From 0f8e1d30e2dc1a4f359761d5418126bb0e0685d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 15:53:30 -0500 Subject: [PATCH] Link context errors and underlying conn errors Using golang.org/x/xerrors type errors both errors can be exposed. --- errors.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 76 ++++++++++---------------------------------- pgconn_test.go | 8 ++--- 3 files changed, 105 insertions(+), 64 deletions(-) create mode 100644 errors.go diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..e42dae16 --- /dev/null +++ b/errors.go @@ -0,0 +1,85 @@ +package pgconn + +import ( + "context" + "net" + + errors "golang.org/x/xerrors" +) + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another +// action is attempted. +var ErrConnBusy = errors.New("conn is busy") + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// linkedError connects two errors as if err wrapped next. +type linkedError struct { + err error + next error +} + +func (le *linkedError) Error() string { + return le.err.Error() +} + +func (le *linkedError) Is(target error) bool { + return errors.Is(le.err, target) +} + +func (le *linkedError) As(target interface{}) bool { + return errors.As(le.err, target) +} + +func (le *linkedError) Unwrap() error { + return le.next +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. +func linkErrors(outer, inner error) error { + if outer == nil { + return inner + } + if inner == nil { + return outer + } + return &linkedError{err: outer, next: inner} +} diff --git a/pgconn.go b/pgconn.go index 14377beb..2911211c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -26,33 +26,6 @@ const ( connStatusBusy ) -// PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for -// detailed field description. -type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string -} - -func (pe *PgError) Error() string { - return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" -} - // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") - -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } return pgConn.conn.Close() @@ -469,15 +434,6 @@ func (ct CommandTag) String() string { return string(ct) } -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() - } - return err -} - type PreparedStatementDescription struct { Name string SQL string @@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} @@ -520,7 +476,7 @@ readloop: msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) + return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) } return nil @@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } switch msg.(type) { @@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) + multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult } @@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read until copy in response or error. @@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } } @@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { diff --git a/pgconn_test.go b/pgconn_test.go index 3fc15e7a..30e6a425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,7 +18,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) ensureConnValid(t, pgConn) } @@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) @@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.False(t, pgConn.IsAlive()) }