2
0

Add SendBytes and ReceiveMessage

This commit is contained in:
Jack Christensen
2019-08-20 14:11:16 -05:00
parent 0a2ed72cf7
commit d364370a31
3 changed files with 108 additions and 11 deletions
+1 -1
View File
@@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
}
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
msg, err := c.ReceiveMessage()
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
+67 -10
View File
@@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.conn.Close()
return nil, err
@@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch
}
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
// error to call SendBytes while reading the result of a query.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return linkErrors(err, ErrNoBytesSent)
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return linkErrors(ctx.Err(), err)
}
return nil
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
// the OnNotification callback.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
msg, err := pgConn.receiveMessage()
return msg, err
}
// receiveMessage receives a message without setting up context cancellation
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
@@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
readloop:
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
@@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
defer pgConn.contextWatcher.Unwatch()
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
return linkErrors(ctx.Err(), err)
}
@@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
var commandTag CommandTag
var pgErr error
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
@@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var pgErr error
pendingCopyInResponse := true
for pendingCopyInResponse {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
@@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
select {
case <-signalMessageChan:
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
@@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// Read results
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
@@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
}
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := mrr.pgConn.ReceiveMessage()
msg, err := mrr.pgConn.receiveMessage()
if err != nil {
mrr.pgConn.contextWatcher.Unwatch()
@@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
if rr.multiResultReader == nil {
msg, err = rr.pgConn.ReceiveMessage()
msg, err = rr.pgConn.receiveMessage()
} else {
msg, err = rr.multiResultReader.receiveMessage()
}
+40
View File
@@ -18,6 +18,7 @@ import (
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
"github.com/stretchr/testify/assert"
@@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
msg, err := pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok := msg.(*pgproto3.RowDescription)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.DataRow)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.CommandComplete)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.ReadyForQuery)
require.True(t, ok)
ensureConnValid(t, pgConn)
}
func Example() {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
if err != nil {