Restructure sending messages
Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead of directly writing to the underlying net.Conn. This will allow tracing messages as well as simplify pipeline mode.
This commit is contained in:
@@ -97,7 +97,8 @@ type sendMessageStep struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
|
||||||
return backend.Send(e.msg)
|
backend.Send(e.msg)
|
||||||
|
return backend.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendMessage(msg pgproto3.BackendMessage) Step {
|
func SendMessage(msg pgproto3.BackendMessage) Step {
|
||||||
|
|||||||
+1
-1
@@ -222,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
User: settings["user"],
|
User: settings["user"],
|
||||||
Password: settings["password"],
|
Password: settings["password"],
|
||||||
RuntimeParams: make(map[string]string),
|
RuntimeParams: make(map[string]string),
|
||||||
BuildFrontend: func(r io.Reader, w io.Writer) Frontend {
|
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||||
return pgproto3.NewFrontend(r, w)
|
return pgproto3.NewFrontend(r, w)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -178,23 +178,6 @@ func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
|||||||
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeError struct {
|
|
||||||
err error
|
|
||||||
safeToRetry bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *writeError) Error() string {
|
|
||||||
return fmt.Sprintf("write failed: %s", e.err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *writeError) SafeToRetry() bool {
|
|
||||||
return e.safeToRetry
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *writeError) Unwrap() error {
|
|
||||||
return e.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func redactPW(connString string) string {
|
func redactPW(connString string) string {
|
||||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
if u, err := url.Parse(connString); err == nil {
|
if u, err := url.Parse(connString); err == nil {
|
||||||
|
|||||||
@@ -1,70 +0,0 @@
|
|||||||
package pgconn_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// frontendWrapper allows to hijack a regular frontend, and inject a specific response
|
|
||||||
type frontendWrapper struct {
|
|
||||||
front pgconn.Frontend
|
|
||||||
|
|
||||||
msg pgproto3.BackendMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
// frontendWrapper implements the pgconn.Frontend interface
|
|
||||||
var _ pgconn.Frontend = (*frontendWrapper)(nil)
|
|
||||||
|
|
||||||
func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) {
|
|
||||||
if f.msg != nil {
|
|
||||||
return f.msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.front.Receive()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFrontendFatalErrExec(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
buildFrontend := config.BuildFrontend
|
|
||||||
var front *frontendWrapper
|
|
||||||
|
|
||||||
config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend {
|
|
||||||
wrapped := buildFrontend(r, w)
|
|
||||||
front = &frontendWrapper{wrapped, nil}
|
|
||||||
|
|
||||||
return front
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, conn)
|
|
||||||
require.NotNil(t, front)
|
|
||||||
|
|
||||||
// set frontend to return a "FATAL" message on next call
|
|
||||||
front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"}
|
|
||||||
|
|
||||||
_, err = conn.Exec(context.Background(), "SELECT 1").ReadAll()
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
err = conn.Close(context.Background())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-conn.CleanupDone():
|
|
||||||
t.Log("ok, CleanupDone() is not blocking")
|
|
||||||
|
|
||||||
default:
|
|
||||||
assert.Fail(t, "connection closed but CleanupDone() still blocking")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+35
-86
@@ -29,8 +29,6 @@ const (
|
|||||||
connStatusBusy
|
connStatusBusy
|
||||||
)
|
)
|
||||||
|
|
||||||
const wbufLen = 1024
|
|
||||||
|
|
||||||
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
|
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
|
||||||
// LISTEN/NOTIFY notification.
|
// LISTEN/NOTIFY notification.
|
||||||
type Notice PgError
|
type Notice PgError
|
||||||
@@ -50,7 +48,7 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
|||||||
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
|
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
|
||||||
|
|
||||||
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
||||||
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
|
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
|
||||||
|
|
||||||
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
|
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
|
||||||
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
|
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
|
||||||
@@ -64,11 +62,6 @@ type NoticeHandler func(*PgConn, *Notice)
|
|||||||
// notice event.
|
// notice event.
|
||||||
type NotificationHandler func(*PgConn, *Notification)
|
type NotificationHandler func(*PgConn, *Notification)
|
||||||
|
|
||||||
// Frontend used to receive messages from backend.
|
|
||||||
type Frontend interface {
|
|
||||||
Receive() (pgproto3.BackendMessage, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
type PgConn struct {
|
type PgConn struct {
|
||||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
@@ -76,7 +69,7 @@ type PgConn struct {
|
|||||||
secretKey uint32 // key to use to send a cancel query message to the server
|
secretKey uint32 // key to use to send a cancel query message to the server
|
||||||
parameterStatuses map[string]string // parameters that have been reported by the server
|
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||||
txStatus byte
|
txStatus byte
|
||||||
frontend Frontend
|
frontend *pgproto3.Frontend
|
||||||
|
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
@@ -90,7 +83,6 @@ type PgConn struct {
|
|||||||
peekedMsg pgproto3.BackendMessage
|
peekedMsg pgproto3.BackendMessage
|
||||||
|
|
||||||
// Reusable / preallocated resources
|
// Reusable / preallocated resources
|
||||||
wbuf []byte // write buffer
|
|
||||||
resultReader ResultReader
|
resultReader ResultReader
|
||||||
multiResultReader MultiResultReader
|
multiResultReader MultiResultReader
|
||||||
contextWatcher *ctxwatch.ContextWatcher
|
contextWatcher *ctxwatch.ContextWatcher
|
||||||
@@ -230,7 +222,6 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
|
|||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
pgConn := new(PgConn)
|
pgConn := new(PgConn)
|
||||||
pgConn.config = config
|
pgConn.config = config
|
||||||
pgConn.wbuf = make([]byte, 0, wbufLen)
|
|
||||||
pgConn.cleanupDone = make(chan struct{})
|
pgConn.cleanupDone = make(chan struct{})
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@@ -282,7 +273,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
startupMsg.Parameters["database"] = config.Database
|
startupMsg.Parameters["database"] = config.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
|
pgConn.frontend.Send(&startupMsg)
|
||||||
|
if err := pgConn.frontend.Flush(); err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
|
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
|
||||||
}
|
}
|
||||||
@@ -383,9 +375,8 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||||
msg := &pgproto3.PasswordMessage{Password: password}
|
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
|
||||||
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
|
return pgConn.frontend.Flush()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func hexMD5(s string) string {
|
func hexMD5(s string) string {
|
||||||
@@ -412,36 +403,6 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
|
|
||||||
// error to call SendBytes while reading the result of a query.
|
|
||||||
//
|
|
||||||
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
|
||||||
// See https://www.postgresql.org/docs/current/protocol.html.
|
|
||||||
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
|
||||||
if err := pgConn.lock(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer pgConn.unlock()
|
|
||||||
|
|
||||||
if ctx != context.Background() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return newContextAlreadyDoneError(ctx)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
|
||||||
if err != nil {
|
|
||||||
pgConn.asyncClose()
|
|
||||||
return &writeError{err: err, safeToRetry: n == 0}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||||
@@ -797,15 +758,13 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||||
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||||
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Sync{})
|
||||||
buf = (&pgproto3.Sync{}).Encode(buf)
|
err := pgConn.frontend.Flush()
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return nil, &writeError{err: err, safeToRetry: n == 0}
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
psd := &StatementDescription{Name: name, SQL: sql}
|
psd := &StatementDescription{Name: name, SQL: sql}
|
||||||
@@ -971,15 +930,13 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Query{String: sql})
|
||||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
err := pgConn.frontend.Flush()
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
pgConn.contextWatcher.Unwatch()
|
pgConn.contextWatcher.Unwatch()
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
|
multiResult.err = err
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
@@ -1045,11 +1002,10 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
||||||
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||||
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
|
|
||||||
|
|
||||||
pgConn.execExtendedSuffix(buf, result)
|
pgConn.execExtendedSuffix(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
@@ -1072,10 +1028,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||||
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
|
|
||||||
|
|
||||||
pgConn.execExtendedSuffix(buf, result)
|
pgConn.execExtendedSuffix(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
@@ -1115,15 +1070,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
|
func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
|
||||||
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'})
|
||||||
buf = (&pgproto3.Execute{}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Execute{})
|
||||||
buf = (&pgproto3.Sync{}).Encode(buf)
|
pgConn.frontend.Send(&pgproto3.Sync{})
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
err := pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0})
|
result.concludeCommand(CommandTag{}, err)
|
||||||
pgConn.contextWatcher.Unwatch()
|
pgConn.contextWatcher.Unwatch()
|
||||||
result.closed = true
|
result.closed = true
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
@@ -1151,14 +1106,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send copy to command
|
// Send copy to command
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Query{String: sql})
|
||||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
err := pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
pgConn.unlock()
|
pgConn.unlock()
|
||||||
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
|
return CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read results
|
// Read results
|
||||||
@@ -1211,13 +1165,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send copy to command
|
// Send copy to command
|
||||||
buf := pgConn.wbuf
|
pgConn.frontend.Send(&pgproto3.Query{String: sql})
|
||||||
buf = (&pgproto3.Query{String: sql}).Encode(buf)
|
|
||||||
|
|
||||||
n, err := pgConn.conn.Write(buf)
|
err := pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
|
return CommandTag{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send copy data
|
// Send copy data
|
||||||
@@ -1280,15 +1233,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
}
|
}
|
||||||
close(abortCopyChan)
|
close(abortCopyChan)
|
||||||
|
|
||||||
buf = buf[:0]
|
|
||||||
if copyErr == io.EOF || pgErr != nil {
|
if copyErr == io.EOF || pgErr != nil {
|
||||||
copyDone := &pgproto3.CopyDone{}
|
pgConn.frontend.Send(&pgproto3.CopyDone{})
|
||||||
buf = copyDone.Encode(buf)
|
|
||||||
} else {
|
} else {
|
||||||
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
|
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
||||||
buf = copyFail.Encode(buf)
|
|
||||||
}
|
}
|
||||||
_, err = pgConn.conn.Write(buf)
|
err = pgConn.frontend.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.asyncClose()
|
pgConn.asyncClose()
|
||||||
return CommandTag{}, err
|
return CommandTag{}, err
|
||||||
@@ -1692,7 +1642,7 @@ type HijackedConn struct {
|
|||||||
SecretKey uint32 // key to use to send a cancel query message to the server
|
SecretKey uint32 // key to use to send a cancel query message to the server
|
||||||
ParameterStatuses map[string]string // parameters that have been reported by the server
|
ParameterStatuses map[string]string // parameters that have been reported by the server
|
||||||
TxStatus byte
|
TxStatus byte
|
||||||
Frontend Frontend
|
Frontend *pgproto3.Frontend
|
||||||
Config *Config
|
Config *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1736,7 +1686,6 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
|||||||
|
|
||||||
status: connStatusIdle,
|
status: connStatusIdle,
|
||||||
|
|
||||||
wbuf: make([]byte, 0, wbufLen),
|
|
||||||
cleanupDone: make(chan struct{}),
|
cleanupDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1915,49 +1915,6 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect.
|
|
||||||
|
|
||||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer closeConn(t, pgConn)
|
|
||||||
|
|
||||||
queryMsg := pgproto3.Query{String: "select 42"}
|
|
||||||
buf := queryMsg.Encode(nil)
|
|
||||||
|
|
||||||
err = pgConn.SendBytes(ctx, buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
msg, err := pgConn.ReceiveMessage(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok := msg.(*pgproto3.RowDescription)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
msg, err = pgConn.ReceiveMessage(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok = msg.(*pgproto3.DataRow)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
msg, err = pgConn.ReceiveMessage(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok = msg.(*pgproto3.CommandComplete)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
msg, err = pgConn.ReceiveMessage(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok = msg.(*pgproto3.ReadyForQuery)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
ensureConnValid(t, pgConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHijackAndConstruct(t *testing.T) {
|
func TestHijackAndConstruct(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
+24
-4
@@ -11,6 +11,8 @@ type Backend struct {
|
|||||||
cr *chunkReader
|
cr *chunkReader
|
||||||
w io.Writer
|
w io.Writer
|
||||||
|
|
||||||
|
wbuf []byte
|
||||||
|
|
||||||
// Frontend message flyweights
|
// Frontend message flyweights
|
||||||
bind Bind
|
bind Bind
|
||||||
cancelRequest CancelRequest
|
cancelRequest CancelRequest
|
||||||
@@ -47,10 +49,28 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
|
|||||||
return &Backend{cr: cr, w: w}
|
return &Backend{cr: cr, w: w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a message to the frontend.
|
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
|
||||||
func (b *Backend) Send(msg BackendMessage) error {
|
// called.
|
||||||
_, err := b.w.Write(msg.Encode(nil))
|
func (b *Backend) Send(msg BackendMessage) {
|
||||||
return err
|
b.wbuf = msg.Encode(b.wbuf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes any pending messages to the frontend (i.e. the client).
|
||||||
|
func (b *Backend) Flush() error {
|
||||||
|
n, err := b.w.Write(b.wbuf)
|
||||||
|
|
||||||
|
const maxLen = 1024
|
||||||
|
if len(b.wbuf) > maxLen {
|
||||||
|
b.wbuf = make([]byte, 0, maxLen)
|
||||||
|
} else {
|
||||||
|
b.wbuf = b.wbuf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &writeError{err: err, safeToRetry: n == 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||||
|
|||||||
+24
-4
@@ -12,6 +12,8 @@ type Frontend struct {
|
|||||||
cr *chunkReader
|
cr *chunkReader
|
||||||
w io.Writer
|
w io.Writer
|
||||||
|
|
||||||
|
wbuf []byte
|
||||||
|
|
||||||
// Backend message flyweights
|
// Backend message flyweights
|
||||||
authenticationOk AuthenticationOk
|
authenticationOk AuthenticationOk
|
||||||
authenticationCleartextPassword AuthenticationCleartextPassword
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||||
@@ -56,10 +58,28 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
|
|||||||
return &Frontend{cr: cr, w: w}
|
return &Frontend{cr: cr, w: w}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a message to the backend.
|
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
|
||||||
func (f *Frontend) Send(msg FrontendMessage) error {
|
// called.
|
||||||
_, err := f.w.Write(msg.Encode(nil))
|
func (f *Frontend) Send(msg FrontendMessage) {
|
||||||
return err
|
f.wbuf = msg.Encode(f.wbuf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes any pending messages to the backend (i.e. the server).
|
||||||
|
func (f *Frontend) Flush() error {
|
||||||
|
n, err := f.w.Write(f.wbuf)
|
||||||
|
|
||||||
|
const maxLen = 1024
|
||||||
|
if len(f.wbuf) > maxLen {
|
||||||
|
f.wbuf = make([]byte, 0, maxLen)
|
||||||
|
} else {
|
||||||
|
f.wbuf = f.wbuf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return &writeError{err: err, safeToRetry: n == 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func translateEOFtoErrUnexpectedEOF(err error) error {
|
func translateEOFtoErrUnexpectedEOF(err error) error {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ type Message interface {
|
|||||||
Encode(dst []byte) []byte
|
Encode(dst []byte) []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FrontendMessage is a message sent by the frontend (i.e. the client).
|
||||||
type FrontendMessage interface {
|
type FrontendMessage interface {
|
||||||
Message
|
Message
|
||||||
Frontend() // no-op method to distinguish frontend from backend methods
|
Frontend() // no-op method to distinguish frontend from backend methods
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BackendMessage is a message sent by the backend (i.e. the server).
|
||||||
type BackendMessage interface {
|
type BackendMessage interface {
|
||||||
Message
|
Message
|
||||||
Backend() // no-op method to distinguish frontend from backend methods
|
Backend() // no-op method to distinguish frontend from backend methods
|
||||||
@@ -50,6 +52,23 @@ func (e *invalidMessageFormatErr) Error() string {
|
|||||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type writeError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Error() string {
|
||||||
|
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
||||||
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user