diff --git a/pgconn.go b/pgconn.go index 89d7bb45..44a08cc8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -27,6 +27,8 @@ const ( connStatusBusy ) +const wbufLen = 1024 + // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -192,7 +194,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config - pgConn.wbuf = make([]byte, 0, 1024) + pgConn.wbuf = make([]byte, 0, wbufLen) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -1481,3 +1483,68 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + wbuf: make([]byte, 0, wbufLen), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + return pgConn, nil +} diff --git a/pgconn_test.go b/pgconn_test.go index 2c303d81..34982bb7 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1600,6 +1600,32 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { ensureConnValid(t, pgConn) } +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {