Add SendBytes and ReceiveMessage
This commit is contained in:
+1
-1
@@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
}
|
||||
|
||||
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||
msg, err := c.ReceiveMessage()
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
@@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||
return ch
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
|
||||
// error to call SendBytes while reading the result of a query.
|
||||
//
|
||||
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return linkErrors(ctx.Err(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||
// the OnNotification callback.
|
||||
//
|
||||
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
msg, err := pgConn.receiveMessage()
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// receiveMessage receives a message without setting up context cancellation
|
||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
var msg pgproto3.BackendMessage
|
||||
var err error
|
||||
if pgConn.bufferingReceive {
|
||||
@@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
||||
|
||||
readloop:
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
@@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
}
|
||||
@@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
var commandTag CommandTag
|
||||
var pgErr error
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
@@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
var pgErr error
|
||||
pendingCopyInResponse := true
|
||||
for pendingCopyInResponse {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
@@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
|
||||
select {
|
||||
case <-signalMessageChan:
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
@@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
|
||||
// Read results
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
@@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
|
||||
}
|
||||
|
||||
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
msg, err := mrr.pgConn.ReceiveMessage()
|
||||
msg, err := mrr.pgConn.receiveMessage()
|
||||
|
||||
if err != nil {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
@@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
||||
|
||||
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
||||
if rr.multiResultReader == nil {
|
||||
msg, err = rr.pgConn.ReceiveMessage()
|
||||
msg, err = rr.pgConn.receiveMessage()
|
||||
} else {
|
||||
msg, err = rr.multiResultReader.receiveMessage()
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
errors "golang.org/x/xerrors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
queryMsg := pgproto3.Query{String: "select 42"}
|
||||
buf := queryMsg.Encode(nil)
|
||||
|
||||
err = pgConn.SendBytes(ctx, buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg, err := pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok := msg.(*pgproto3.RowDescription)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.DataRow)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.CommandComplete)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.ReadyForQuery)
|
||||
require.True(t, ok)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func Example() {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user