Add SendBytes and ReceiveMessage
This commit is contained in:
+1
-1
@@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||||
msg, err := c.ReceiveMessage()
|
msg, err := c.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
|
||||||
|
// error to call SendBytes while reading the result of a query.
|
||||||
|
//
|
||||||
|
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||||
|
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||||
|
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return linkErrors(err, ErrNoBytesSent)
|
||||||
|
}
|
||||||
|
defer pgConn.unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
|
n, err := pgConn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.hardClose()
|
||||||
|
if n == 0 {
|
||||||
|
err = linkErrors(err, ErrNoBytesSent)
|
||||||
|
}
|
||||||
|
return linkErrors(ctx.Err(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||||
|
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||||
|
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||||
|
// the OnNotification callback.
|
||||||
|
//
|
||||||
|
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||||
|
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||||
|
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return nil, linkErrors(err, ErrNoBytesSent)
|
||||||
|
}
|
||||||
|
defer pgConn.unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
|
msg, err := pgConn.receiveMessage()
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveMessage receives a message without setting up context cancellation
|
||||||
|
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
var msg pgproto3.BackendMessage
|
var msg pgproto3.BackendMessage
|
||||||
var err error
|
var err error
|
||||||
if pgConn.bufferingReceive {
|
if pgConn.bufferingReceive {
|
||||||
@@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||||||
|
|
||||||
readloop:
|
readloop:
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, linkErrors(ctx.Err(), err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
@@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
|||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return linkErrors(ctx.Err(), err)
|
return linkErrors(ctx.Err(), err)
|
||||||
}
|
}
|
||||||
@@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
var commandTag CommandTag
|
var commandTag CommandTag
|
||||||
var pgErr error
|
var pgErr error
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, linkErrors(ctx.Err(), err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
@@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
var pgErr error
|
var pgErr error
|
||||||
pendingCopyInResponse := true
|
pendingCopyInResponse := true
|
||||||
for pendingCopyInResponse {
|
for pendingCopyInResponse {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, linkErrors(ctx.Err(), err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
@@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-signalMessageChan:
|
case <-signalMessageChan:
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, linkErrors(ctx.Err(), err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
@@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
|
|
||||||
// Read results
|
// Read results
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.receiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
return nil, linkErrors(ctx.Err(), err)
|
return nil, linkErrors(ctx.Err(), err)
|
||||||
@@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
|
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
msg, err := mrr.pgConn.ReceiveMessage()
|
msg, err := mrr.pgConn.receiveMessage()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mrr.pgConn.contextWatcher.Unwatch()
|
mrr.pgConn.contextWatcher.Unwatch()
|
||||||
@@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
|||||||
|
|
||||||
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
||||||
if rr.multiResultReader == nil {
|
if rr.multiResultReader == nil {
|
||||||
msg, err = rr.pgConn.ReceiveMessage()
|
msg, err = rr.pgConn.receiveMessage()
|
||||||
} else {
|
} else {
|
||||||
msg, err = rr.multiResultReader.receiveMessage()
|
msg, err = rr.multiResultReader.receiveMessage()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
queryMsg := pgproto3.Query{String: "select 42"}
|
||||||
|
buf := queryMsg.Encode(nil)
|
||||||
|
|
||||||
|
err = pgConn.SendBytes(ctx, buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg, err := pgConn.ReceiveMessage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok := msg.(*pgproto3.RowDescription)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
msg, err = pgConn.ReceiveMessage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = msg.(*pgproto3.DataRow)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
msg, err = pgConn.ReceiveMessage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = msg.(*pgproto3.CommandComplete)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
msg, err = pgConn.ReceiveMessage(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = msg.(*pgproto3.ReadyForQuery)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user