From 10c6c50ac9638764295be9bd6e897b7386ba7614 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Wed, 30 Jun 2021 12:54:45 -0400 Subject: [PATCH] Extend handling of unexpected EOF to the backend In the original issue [1] and commit [2], support for unexpected EOF was added to the frontend to detect when a connection was closed abruptly. Additionally, this allows us to differentiate normal io.EOF errors with unexpected errors in the backend. [1] https://github.com/jackc/pgx/issues/662/ [2] https://github.com/jackc/pgproto3/commit/595780be0f9f581451a23a5151b77f782202ad72 --- backend.go | 6 +++--- backend_test.go | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/backend.go b/backend.go index 232aa11d..c9fa87ff 100644 --- a/backend.go +++ b/backend.go @@ -58,7 +58,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err = b.cr.Next(msgSize) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } code := binary.BigEndian.Uint32(buf) @@ -98,7 +98,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { if !b.partialMsg { header, err := b.cr.Next(5) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.msgType = header[0] @@ -152,7 +152,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { msgBody, err := b.cr.Next(b.bodyLen) if err != nil { - return nil, err + return nil, translateEOFtoErrUnexpectedEOF(err) } b.partialMsg = false diff --git a/backend_test.go b/backend_test.go index 43a3f76c..19970c34 100644 --- a/backend_test.go +++ b/backend_test.go @@ -1,9 +1,11 @@ package pgproto3_test import ( + "io" "testing" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -32,3 +34,23 @@ func TestBackendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestBackendReceiveUnexpectedEOF(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 0, 6}) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) + + // Receive FE msg + server.push([]byte{'F', 0, 0, 0, 6}) + msg, err = backend.ReceiveStartupMessage() + assert.Nil(t, msg) + assert.Equal(t, io.ErrUnexpectedEOF, err) +}