diff --git a/backend.go b/backend.go index 6944f80d..e9ba38fc 100644 --- a/backend.go +++ b/backend.go @@ -67,7 +67,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err = b.cr.Next(msgSize) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } code := binary.BigEndian.Uint32(buf) @@ -107,7 +107,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.msgType = header[0] @@ -161,7 +161,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { msgBody, err := b.cr.Next(b.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.partialMsg = false diff --git a/backend_test.go b/backend_test.go index 3cfde003..5e9a2ac5 100644 --- a/backend_test.go +++ b/backend_test.go @@ -1,10 +1,12 @@ package pgproto3_test import ( + "io" "testing" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -35,6 +37,26 @@ func TestBackendReceiveInterrupted(t *testing.T) { } } +func TestBackendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) + + // Receive FE msg + server.push([]byte{'F', 0, 0, 0, 6}) + msg, err = backend.ReceiveStartupMessage() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +} + func TestStartupMessage(t *testing.T) { t.Parallel() @@ -93,5 +115,4 @@ func TestStartupMessage(t *testing.T) { }) } }) - }