2
0

Link context errors and underlying conn errors

Using golang.org/x/xerrors type errors both errors can be exposed.
This commit is contained in:
Jack Christensen
2019-04-20 15:53:30 -05:00
parent f3b5f6b275
commit 0f8e1d30e2
3 changed files with 105 additions and 64 deletions
+85
View File
@@ -0,0 +1,85 @@
package pgconn
import (
"context"
"net"
errors "golang.org/x/xerrors"
)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// linkedError connects two errors as if err wrapped next.
type linkedError struct {
err error
next error
}
func (le *linkedError) Error() string {
return le.err.Error()
}
func (le *linkedError) Is(target error) bool {
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
func linkErrors(outer, inner error) error {
if outer == nil {
return inner
}
if inner == nil {
return outer
}
return &linkedError{err: outer, next: inner}
}
+16 -60
View File
@@ -26,33 +26,6 @@ const (
connStatusBusy connStatusBusy
) )
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification. // LISTEN/NOTIFY notification.
type Notice PgError type Notice PgError
@@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event. // notice event.
type NotificationHandler func(*PgConn, *Notification) type NotificationHandler func(*PgConn, *Notification)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct { type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection conn net.Conn // the underlying TCP or unix domain socket connection
@@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil { if err != nil {
return preferContextOverNetTimeoutError(ctx, err) return linkErrors(ctx.Err(), err)
} }
_, err = pgConn.conn.Read(make([]byte, 1)) _, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF { if err != io.EOF {
return preferContextOverNetTimeoutError(ctx, err) return linkErrors(ctx.Err(), err)
} }
return pgConn.conn.Close() return pgConn.conn.Close()
@@ -469,15 +434,6 @@ func (ct CommandTag) String() string {
return string(ct) return string(ct)
} }
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
}
return err
}
type PreparedStatementDescription struct { type PreparedStatementDescription struct {
Name string Name string
SQL string SQL string
@@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
_, err := pgConn.conn.Write(buf) _, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
psd := &PreparedStatementDescription{Name: name, SQL: sql} psd := &PreparedStatementDescription{Name: name, SQL: sql}
@@ -520,7 +476,7 @@ readloop:
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf) _, err = cancelConn.Write(buf)
if err != nil { if err != nil {
return preferContextOverNetTimeoutError(ctx, err) return linkErrors(ctx.Err(), err)
} }
_, err = cancelConn.Read(buf) _, err = cancelConn.Read(buf)
if err != io.EOF { if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
} }
return nil return nil
@@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
return preferContextOverNetTimeoutError(ctx, err) return linkErrors(ctx.Err(), err)
} }
switch msg.(type) { switch msg.(type) {
@@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose() pgConn.hardClose()
pgConn.doneChanToDeadline.cleanup() pgConn.doneChanToDeadline.cleanup()
multiResult.closed = true multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err) multiResult.err = linkErrors(ctx.Err(), err)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
} }
@@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
pgConn.hardClose() pgConn.hardClose()
pgConn.unlock() pgConn.unlock()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
// Read results // Read results
@@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err := pgConn.conn.Write(buf) _, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
// Read until copy in response or error. // Read until copy in response or error.
@@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
} }
@@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
// Read results // Read results
@@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.ReceiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, linkErrors(ctx.Err(), err)
} }
switch msg := msg.(type) { switch msg := msg.(type) {
+4 -4
View File
@@ -18,7 +18,7 @@ import (
"time" "time"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/pkg/errors" errors "golang.org/x/xerrors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx) err = pgConn.WaitForNotification(ctx)
cancel() cancel()
require.Equal(t, context.DeadlineExceeded, err) assert.True(t, errors.Is(err, context.DeadlineExceeded))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag(nil), res)
assert.False(t, pgConn.IsAlive()) assert.False(t, pgConn.IsAlive())
@@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel() cancel()
assert.Equal(t, int64(0), ct.RowsAffected()) assert.Equal(t, int64(0), ct.RowsAffected())
require.Equal(t, context.DeadlineExceeded, err) assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.False(t, pgConn.IsAlive()) assert.False(t, pgConn.IsAlive())
} }