From d364370a31359546fb19828f737073b19a56f812 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 14:11:16 -0500 Subject: [PATCH] Add SendBytes and ReceiveMessage --- auth_scram.go | 2 +- pgconn.go | 77 +++++++++++++++++++++++++++++++++++++++++++------- pgconn_test.go | 40 ++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index bdaf3e92..4409a080 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { - msg, err := c.ReceiveMessage() + msg, err := c.receiveMessage() if err != nil { return nil, err } diff --git a/pgconn.go b/pgconn.go index 63e19ed1..abbc2d10 100644 --- a/pgconn.go +++ b/pgconn.go @@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() return nil, err @@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + return linkErrors(ctx.Err(), err) + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + msg, err := pgConn.receiveMessage() + return msg, err +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ readloop: for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { defer pgConn.contextWatcher.Unwatch() for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { return linkErrors(ctx.Err(), err) } @@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var commandTag CommandTag var pgErr error for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Read results for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() + msg, err := mrr.pgConn.receiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() @@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() + msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } diff --git a/pgconn_test.go b/pgconn_test.go index 1b90b9d2..f385bc19 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" @@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {