2
0

Make ErrorResponseToPgError public

This commit is contained in:
Jack Christensen
2019-08-20 15:49:57 -05:00
parent d364370a31
commit 11255efe7a
+13 -12
View File
@@ -233,7 +233,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
// handled by ReceiveMessage // handled by ReceiveMessage
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgConn.conn.Close() pgConn.conn.Close()
return nil, errorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
default: default:
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.New("unexpected message") return nil, errors.New("unexpected message")
@@ -400,7 +400,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" { if msg.Severity == "FATAL" {
pgConn.hardClose() pgConn.hardClose()
return nil, errorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
} }
case *pgproto3.NoticeResponse: case *pgproto3.NoticeResponse:
if pgConn.Config.OnNotice != nil { if pgConn.Config.OnNotice != nil {
@@ -577,7 +577,7 @@ readloop:
psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
copy(psd.Fields, msg.Fields) copy(psd.Fields, msg.Fields)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
parseErr = errorResponseToPgError(msg) parseErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
break readloop break readloop
} }
@@ -589,7 +589,8 @@ readloop:
return psd, nil return psd, nil
} }
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { // ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{ return &PgError{
Severity: msg.Severity, Severity: msg.Severity,
Code: string(msg.Code), Code: string(msg.Code),
@@ -612,7 +613,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
} }
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
return (*Notice)(pgerr) return (*Notice)(pgerr)
} }
@@ -898,7 +899,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
} }
} }
@@ -949,7 +950,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case *pgproto3.CopyInResponse: case *pgproto3.CopyInResponse:
pendingCopyInResponse = false pendingCopyInResponse = false
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
return commandTag, pgErr return commandTag, pgErr
} }
@@ -985,7 +986,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
default: default:
} }
@@ -1019,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
} }
} }
@@ -1064,7 +1065,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.closed = true mrr.closed = true
mrr.pgConn.unlock() mrr.pgConn.unlock()
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
mrr.err = errorResponseToPgError(msg) mrr.err = ErrorResponseToPgError(msg)
} }
return msg, nil return msg, nil
@@ -1219,7 +1220,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
switch msg := msg.(type) { switch msg := msg.(type) {
// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.err = errorResponseToPgError(msg) rr.err = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
rr.pgConn.contextWatcher.Unwatch() rr.pgConn.contextWatcher.Unwatch()
rr.pgConn.unlock() rr.pgConn.unlock()
@@ -1255,7 +1256,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil) rr.concludeCommand(CommandTag(msg.CommandTag), nil)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.concludeCommand(nil, errorResponseToPgError(msg)) rr.concludeCommand(nil, ErrorResponseToPgError(msg))
} }
return msg, nil return msg, nil