diff --git a/pgconn.go b/pgconn.go index 5644904a..50607095 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1435,12 +1435,17 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + if rr.commandConcluded { return } rr.commandTag = commandTag - rr.err = err rr.rowValues = nil rr.commandConcluded = true } diff --git a/pgconn_test.go b/pgconn_test.go index 6362c51b..379aa266 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1752,6 +1752,71 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { pgConn.Close(closeCtx) } +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {