diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index 8b672e4e..90597c3e 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -2,7 +2,6 @@ package nbconn_test import ( "crypto/tls" - "errors" "io" "net" "strings" @@ -10,6 +9,7 @@ import ( "time" "github.com/jackc/pgx/v5/internal/nbconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -354,24 +354,34 @@ func TestBufferNonBlockingRead(t *testing.T) { errChan := make(chan error, 1) go func() { - _, err := remote.Write([]byte("okay")) + err := remote.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + errChan <- err + return + } + + _, err = remote.Write([]byte("okay")) errChan <- err }() + readLoop: for i := 0; i < 1000; i++ { - err = conn.BufferReadUntilBlock() - if !errors.Is(err, nbconn.ErrWouldBlock) { - break + err := conn.BufferReadUntilBlock() + require.NoError(t, err) + select { + case err := <-errChan: + require.NoError(t, err) + break readLoop + default: + time.Sleep(time.Millisecond) } - time.Sleep(time.Millisecond) } - require.NoError(t, err) buf := make([]byte, 4) n, err := conn.Read(buf) require.NoError(t, err) - require.EqualValues(t, 4, n) - require.Equal(t, []byte("okay"), buf) + assert.EqualValues(t, 4, n) + assert.Equal(t, []byte("okay"), buf) }) }