2
0

Restructure sending messages

Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead
of directly writing to the underlying net.Conn. This will allow tracing
messages as well as simplify pipeline mode.
This commit is contained in:
Jack Christensen
2022-05-21 11:06:44 -05:00
parent 989a4835de
commit 5714896b10
9 changed files with 105 additions and 226 deletions
+2 -1
View File
@@ -97,7 +97,8 @@ type sendMessageStep struct {
} }
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
return backend.Send(e.msg) backend.Send(e.msg)
return backend.Flush()
} }
func SendMessage(msg pgproto3.BackendMessage) Step { func SendMessage(msg pgproto3.BackendMessage) Step {
+1 -1
View File
@@ -222,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) {
User: settings["user"], User: settings["user"],
Password: settings["password"], Password: settings["password"],
RuntimeParams: make(map[string]string), RuntimeParams: make(map[string]string),
BuildFrontend: func(r io.Reader, w io.Writer) Frontend { BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w) return pgproto3.NewFrontend(r, w)
}, },
} }
-17
View File
@@ -178,23 +178,6 @@ func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
} }
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
func redactPW(connString string) string { func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil { if u, err := url.Parse(connString); err == nil {
-70
View File
@@ -1,70 +0,0 @@
package pgconn_test
import (
"context"
"io"
"os"
"testing"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// frontendWrapper allows to hijack a regular frontend, and inject a specific response
type frontendWrapper struct {
front pgconn.Frontend
msg pgproto3.BackendMessage
}
// frontendWrapper implements the pgconn.Frontend interface
var _ pgconn.Frontend = (*frontendWrapper)(nil)
func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) {
if f.msg != nil {
return f.msg, nil
}
return f.front.Receive()
}
func TestFrontendFatalErrExec(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
buildFrontend := config.BuildFrontend
var front *frontendWrapper
config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend {
wrapped := buildFrontend(r, w)
front = &frontendWrapper{wrapped, nil}
return front
}
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, conn)
require.NotNil(t, front)
// set frontend to return a "FATAL" message on next call
front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"}
_, err = conn.Exec(context.Background(), "SELECT 1").ReadAll()
assert.Error(t, err)
err = conn.Close(context.Background())
assert.NoError(t, err)
select {
case <-conn.CleanupDone():
t.Log("ok, CleanupDone() is not blocking")
default:
assert.Fail(t, "connection closed but CleanupDone() still blocking")
}
}
+35 -86
View File
@@ -29,8 +29,6 @@ const (
connStatusBusy connStatusBusy
) )
const wbufLen = 1024
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification. // LISTEN/NOTIFY notification.
type Notice PgError type Notice PgError
@@ -50,7 +48,7 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
@@ -64,11 +62,6 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event. // notice event.
type NotificationHandler func(*PgConn, *Notification) type NotificationHandler func(*PgConn, *Notification)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)
}
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct { type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection conn net.Conn // the underlying TCP or unix domain socket connection
@@ -76,7 +69,7 @@ type PgConn struct {
secretKey uint32 // key to use to send a cancel query message to the server secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte txStatus byte
frontend Frontend frontend *pgproto3.Frontend
config *Config config *Config
@@ -90,7 +83,6 @@ type PgConn struct {
peekedMsg pgproto3.BackendMessage peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources // Reusable / preallocated resources
wbuf []byte // write buffer
resultReader ResultReader resultReader ResultReader
multiResultReader MultiResultReader multiResultReader MultiResultReader
contextWatcher *ctxwatch.ContextWatcher contextWatcher *ctxwatch.ContextWatcher
@@ -230,7 +222,6 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn) pgConn := new(PgConn)
pgConn.config = config pgConn.config = config
pgConn.wbuf = make([]byte, 0, wbufLen)
pgConn.cleanupDone = make(chan struct{}) pgConn.cleanupDone = make(chan struct{})
var err error var err error
@@ -282,7 +273,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
startupMsg.Parameters["database"] = config.Database startupMsg.Parameters["database"] = config.Database
} }
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.frontend.Send(&startupMsg)
if err := pgConn.frontend.Flush(); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: err} return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
} }
@@ -383,9 +375,8 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
} }
func (pgConn *PgConn) txPasswordMessage(password string) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password} pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) return pgConn.frontend.Flush()
return err
} }
func hexMD5(s string) string { func hexMD5(s string) string {
@@ -412,36 +403,6 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch return ch
} }
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
// error to call SendBytes while reading the result of a query.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
if ctx != context.Background() {
select {
case <-ctx.Done():
return newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@@ -797,15 +758,13 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
} }
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name})
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) pgConn.frontend.Send(&pgproto3.Sync{})
buf = (&pgproto3.Sync{}).Encode(buf) err := pgConn.frontend.Flush()
n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, &writeError{err: err, safeToRetry: n == 0} return nil, err
} }
psd := &StatementDescription{Name: name, SQL: sql} psd := &StatementDescription{Name: name, SQL: sql}
@@ -971,15 +930,13 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
} }
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Query{String: sql})
buf = (&pgproto3.Query{String: sql}).Encode(buf) err := pgConn.frontend.Flush()
n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
multiResult.closed = true multiResult.closed = true
multiResult.err = &writeError{err: err, safeToRetry: n == 0} multiResult.err = err
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
} }
@@ -1045,11 +1002,10 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
return result return result
} }
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(buf, result) pgConn.execExtendedSuffix(result)
return result return result
} }
@@ -1072,10 +1028,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
return result return result
} }
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(buf, result) pgConn.execExtendedSuffix(result)
return result return result
} }
@@ -1115,15 +1070,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result return result
} }
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'})
buf = (&pgproto3.Execute{}).Encode(buf) pgConn.frontend.Send(&pgproto3.Execute{})
buf = (&pgproto3.Sync{}).Encode(buf) pgConn.frontend.Send(&pgproto3.Sync{})
n, err := pgConn.conn.Write(buf) err := pgConn.frontend.Flush()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0}) result.concludeCommand(CommandTag{}, err)
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
@@ -1151,14 +1106,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
} }
// Send copy to command // Send copy to command
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Query{String: sql})
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.frontend.Flush()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
pgConn.unlock() pgConn.unlock()
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} return CommandTag{}, err
} }
// Read results // Read results
@@ -1211,13 +1165,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
// Send copy to command // Send copy to command
buf := pgConn.wbuf pgConn.frontend.Send(&pgproto3.Query{String: sql})
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.frontend.Flush()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} return CommandTag{}, err
} }
// Send copy data // Send copy data
@@ -1280,15 +1233,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
close(abortCopyChan) close(abortCopyChan)
buf = buf[:0]
if copyErr == io.EOF || pgErr != nil { if copyErr == io.EOF || pgErr != nil {
copyDone := &pgproto3.CopyDone{} pgConn.frontend.Send(&pgproto3.CopyDone{})
buf = copyDone.Encode(buf)
} else { } else {
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
buf = copyFail.Encode(buf)
} }
_, err = pgConn.conn.Write(buf) err = pgConn.frontend.Flush()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return CommandTag{}, err return CommandTag{}, err
@@ -1692,7 +1642,7 @@ type HijackedConn struct {
SecretKey uint32 // key to use to send a cancel query message to the server SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server ParameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte TxStatus byte
Frontend Frontend Frontend *pgproto3.Frontend
Config *Config Config *Config
} }
@@ -1736,7 +1686,6 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
status: connStatusIdle, status: connStatusIdle,
wbuf: make([]byte, 0, wbufLen),
cleanupDone: make(chan struct{}), cleanupDone: make(chan struct{}),
} }
-43
View File
@@ -1915,49 +1915,6 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
} }
} }
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect.
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)
queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
msg, err := pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok := msg.(*pgproto3.RowDescription)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.DataRow)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.CommandComplete)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.ReadyForQuery)
require.True(t, ok)
ensureConnValid(t, pgConn)
}
func TestHijackAndConstruct(t *testing.T) { func TestHijackAndConstruct(t *testing.T) {
t.Parallel() t.Parallel()
+24 -4
View File
@@ -11,6 +11,8 @@ type Backend struct {
cr *chunkReader cr *chunkReader
w io.Writer w io.Writer
wbuf []byte
// Frontend message flyweights // Frontend message flyweights
bind Bind bind Bind
cancelRequest CancelRequest cancelRequest CancelRequest
@@ -47,10 +49,28 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w} return &Backend{cr: cr, w: w}
} }
// Send sends a message to the frontend. // Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
func (b *Backend) Send(msg BackendMessage) error { // called.
_, err := b.w.Write(msg.Encode(nil)) func (b *Backend) Send(msg BackendMessage) {
return err b.wbuf = msg.Encode(b.wbuf)
}
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
if len(b.wbuf) > maxLen {
b.wbuf = make([]byte, 0, maxLen)
} else {
b.wbuf = b.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
} }
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
+24 -4
View File
@@ -12,6 +12,8 @@ type Frontend struct {
cr *chunkReader cr *chunkReader
w io.Writer w io.Writer
wbuf []byte
// Backend message flyweights // Backend message flyweights
authenticationOk AuthenticationOk authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword authenticationCleartextPassword AuthenticationCleartextPassword
@@ -56,10 +58,28 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w} return &Frontend{cr: cr, w: w}
} }
// Send sends a message to the backend. // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
func (f *Frontend) Send(msg FrontendMessage) error { // called.
_, err := f.w.Write(msg.Encode(nil)) func (f *Frontend) Send(msg FrontendMessage) {
return err f.wbuf = msg.Encode(f.wbuf)
}
// Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error {
n, err := f.w.Write(f.wbuf)
const maxLen = 1024
if len(f.wbuf) > maxLen {
f.wbuf = make([]byte, 0, maxLen)
} else {
f.wbuf = f.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
} }
func translateEOFtoErrUnexpectedEOF(err error) error { func translateEOFtoErrUnexpectedEOF(err error) error {
+19
View File
@@ -17,11 +17,13 @@ type Message interface {
Encode(dst []byte) []byte Encode(dst []byte) []byte
} }
// FrontendMessage is a message sent by the frontend (i.e. the client).
type FrontendMessage interface { type FrontendMessage interface {
Message Message
Frontend() // no-op method to distinguish frontend from backend methods Frontend() // no-op method to distinguish frontend from backend methods
} }
// BackendMessage is a message sent by the backend (i.e. the server).
type BackendMessage interface { type BackendMessage interface {
Message Message
Backend() // no-op method to distinguish frontend from backend methods Backend() // no-op method to distinguish frontend from backend methods
@@ -50,6 +52,23 @@ func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType) return fmt.Sprintf("%s body is invalid", e.messageType)
} }
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
// getValueFromJSON gets the value from a protocol message representation in JSON. // getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) { func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil { if v == nil {