Link context errors and underlying conn errors
Using golang.org/x/xerrors type errors both errors can be exposed.
This commit is contained in:
@@ -0,0 +1,85 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
errors "golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||||
|
// PostgreSQL server refuses to use TLS
|
||||||
|
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||||
|
|
||||||
|
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
|
||||||
|
// action is attempted.
|
||||||
|
var ErrConnBusy = errors.New("conn is busy")
|
||||||
|
|
||||||
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
|
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||||
|
// detailed field description.
|
||||||
|
type PgError struct {
|
||||||
|
Severity string
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pe *PgError) Error() string {
|
||||||
|
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// linkedError connects two errors as if err wrapped next.
|
||||||
|
type linkedError struct {
|
||||||
|
err error
|
||||||
|
next error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (le *linkedError) Error() string {
|
||||||
|
return le.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (le *linkedError) Is(target error) bool {
|
||||||
|
return errors.Is(le.err, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (le *linkedError) As(target interface{}) bool {
|
||||||
|
return errors.As(le.err, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (le *linkedError) Unwrap() error {
|
||||||
|
return le.next
|
||||||
|
}
|
||||||
|
|
||||||
|
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||||
|
// true. Otherwise returns err.
|
||||||
|
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
|
||||||
|
func linkErrors(outer, inner error) error {
|
||||||
|
if outer == nil {
|
||||||
|
return inner
|
||||||
|
}
|
||||||
|
if inner == nil {
|
||||||
|
return outer
|
||||||
|
}
|
||||||
|
return &linkedError{err: outer, next: inner}
|
||||||
|
}
|
||||||
@@ -26,33 +26,6 @@ const (
|
|||||||
connStatusBusy
|
connStatusBusy
|
||||||
)
|
)
|
||||||
|
|
||||||
// PgError represents an error reported by the PostgreSQL server. See
|
|
||||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
|
||||||
// detailed field description.
|
|
||||||
type PgError struct {
|
|
||||||
Severity string
|
|
||||||
Code string
|
|
||||||
Message string
|
|
||||||
Detail string
|
|
||||||
Hint string
|
|
||||||
Position int32
|
|
||||||
InternalPosition int32
|
|
||||||
InternalQuery string
|
|
||||||
Where string
|
|
||||||
SchemaName string
|
|
||||||
TableName string
|
|
||||||
ColumnName string
|
|
||||||
DataTypeName string
|
|
||||||
ConstraintName string
|
|
||||||
File string
|
|
||||||
Line int32
|
|
||||||
Routine string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pe *PgError) Error() string {
|
|
||||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
|
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
|
||||||
// LISTEN/NOTIFY notification.
|
// LISTEN/NOTIFY notification.
|
||||||
type Notice PgError
|
type Notice PgError
|
||||||
@@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice)
|
|||||||
// notice event.
|
// notice event.
|
||||||
type NotificationHandler func(*PgConn, *Notification)
|
type NotificationHandler func(*PgConn, *Notification)
|
||||||
|
|
||||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
|
||||||
// PostgreSQL server refuses to use TLS
|
|
||||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
|
||||||
|
|
||||||
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
|
|
||||||
// action is attempted.
|
|
||||||
var ErrConnBusy = errors.New("conn is busy")
|
|
||||||
|
|
||||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
type PgConn struct {
|
type PgConn struct {
|
||||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
@@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
|
|||||||
|
|
||||||
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
|
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pgConn.conn.Read(make([]byte, 1))
|
_, err = pgConn.conn.Read(make([]byte, 1))
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return pgConn.conn.Close()
|
return pgConn.conn.Close()
|
||||||
@@ -469,15 +434,6 @@ func (ct CommandTag) String() string {
|
|||||||
return string(ct)
|
return string(ct)
|
||||||
}
|
}
|
||||||
|
|
||||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
|
||||||
// true. Otherwise returns err.
|
|
||||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
|
||||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type PreparedStatementDescription struct {
|
type PreparedStatementDescription struct {
|
||||||
Name string
|
Name string
|
||||||
SQL string
|
SQL string
|
||||||
@@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||||||
_, err := pgConn.conn.Write(buf)
|
_, err := pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
psd := &PreparedStatementDescription{Name: name, SQL: sql}
|
psd := &PreparedStatementDescription{Name: name, SQL: sql}
|
||||||
@@ -520,7 +476,7 @@ readloop:
|
|||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
@@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
|||||||
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
|
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
|
||||||
_, err = cancelConn.Write(buf)
|
_, err = cancelConn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = cancelConn.Read(buf)
|
_, err = cancelConn.Read(buf)
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err))
|
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg.(type) {
|
switch msg.(type) {
|
||||||
@@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
pgConn.doneChanToDeadline.cleanup()
|
pgConn.doneChanToDeadline.cleanup()
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
multiResult.err = linkErrors(ctx.Err(), err)
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
@@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
|
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read results
|
// Read results
|
||||||
@@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
@@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err := pgConn.conn.Write(buf)
|
_, err := pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read until copy in response or error.
|
// Read until copy in response or error.
|
||||||
@@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
@@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
@@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read results
|
// Read results
|
||||||
@@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
|
|||||||
+4
-4
@@ -18,7 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/pkg/errors"
|
errors "golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
err = pgConn.WaitForNotification(ctx)
|
err = pgConn.WaitForNotification(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
require.Equal(t, context.DeadlineExceeded, err)
|
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||||
|
|
||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
@@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
||||||
assert.Equal(t, context.DeadlineExceeded, err)
|
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||||
|
|
||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
@@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
|
|||||||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||||
cancel()
|
cancel()
|
||||||
assert.Equal(t, int64(0), ct.RowsAffected())
|
assert.Equal(t, int64(0), ct.RowsAffected())
|
||||||
require.Equal(t, context.DeadlineExceeded, err)
|
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||||
|
|
||||||
assert.False(t, pgConn.IsAlive())
|
assert.False(t, pgConn.IsAlive())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user