Access underlying net.Conn via method
Also remove some dead code.
This commit is contained in:
@@ -58,7 +58,7 @@ var ErrTLSRefused = errors.New("server refused TLS connection")
|
|||||||
|
|
||||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
type PgConn struct {
|
type PgConn struct {
|
||||||
NetConn net.Conn // the underlying TCP or unix domain socket connection
|
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
PID uint32 // backend pid
|
PID uint32 // backend pid
|
||||||
SecretKey uint32 // key to use to send a cancel query message to the server
|
SecretKey uint32 // key to use to send a cancel query message to the server
|
||||||
parameterStatuses map[string]string // parameters that have been reported by the server
|
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||||
@@ -132,7 +132,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||||
pgConn.NetConn, err = config.DialFunc(ctx, network, address)
|
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -141,12 +141,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
|
|
||||||
if config.TLSConfig != nil {
|
if config.TLSConfig != nil {
|
||||||
if err := pgConn.startTLS(config.TLSConfig); err != nil {
|
if err := pgConn.startTLS(config.TLSConfig); err != nil {
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn)
|
pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -166,8 +166,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
startupMsg.Parameters["database"] = config.Database
|
startupMsg.Parameters["database"] = config.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil {
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,14 +183,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
pgConn.SecretKey = msg.SecretKey
|
pgConn.SecretKey = msg.SecretKey
|
||||||
case *pgproto3.Authentication:
|
case *pgproto3.Authentication:
|
||||||
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
if config.AfterConnectFunc != nil {
|
if config.AfterConnectFunc != nil {
|
||||||
err := config.AfterConnectFunc(ctx, pgConn)
|
err := config.AfterConnectFunc(ctx, pgConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, fmt.Errorf("AfterConnectFunc: %v", err)
|
return nil, fmt.Errorf("AfterConnectFunc: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,7 +198,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
case *pgproto3.ParameterStatus:
|
case *pgproto3.ParameterStatus:
|
||||||
// handled by ReceiveMessage
|
// handled by ReceiveMessage
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, &PgError{
|
return nil, &PgError{
|
||||||
Severity: msg.Severity,
|
Severity: msg.Severity,
|
||||||
Code: msg.Code,
|
Code: msg.Code,
|
||||||
@@ -219,20 +219,20 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
Routine: msg.Routine,
|
Routine: msg.Routine,
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, errors.New("unexpected message")
|
return nil, errors.New("unexpected message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||||
err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103})
|
err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, 1)
|
response := make([]byte, 1)
|
||||||
if _, err = io.ReadFull(pgConn.NetConn, response); err != nil {
|
if _, err = io.ReadFull(pgConn.conn, response); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
|||||||
return ErrTLSRefused
|
return ErrTLSRefused
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig)
|
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -262,7 +262,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
|||||||
|
|
||||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||||
msg := &pgproto3.PasswordMessage{Password: password}
|
msg := &pgproto3.PasswordMessage{Password: password}
|
||||||
_, err = pgConn.NetConn.Write(msg.Encode(nil))
|
_, err = pgConn.conn.Write(msg.Encode(nil))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,6 +299,11 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conn returns the underlying net.Conn.
|
||||||
|
func (pgConn *PgConn) Conn() net.Conn {
|
||||||
|
return pgConn.conn
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
|
// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
|
||||||
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
|
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
|
||||||
// underlying net.Conn.Close() will always be called regardless of any other errors.
|
// underlying net.Conn.Close() will always be called regardless of any other errors.
|
||||||
@@ -308,22 +313,22 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
pgConn.closed = true
|
pgConn.closed = true
|
||||||
|
|
||||||
defer pgConn.NetConn.Close()
|
defer pgConn.conn.Close()
|
||||||
|
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
|
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContext()
|
defer cleanupContext()
|
||||||
|
|
||||||
_, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4})
|
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pgConn.NetConn.Read(make([]byte, 1))
|
_, err = pgConn.conn.Read(make([]byte, 1))
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return pgConn.NetConn.Close()
|
return pgConn.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
||||||
@@ -380,7 +385,7 @@ type PgResultReader struct {
|
|||||||
// consumed it returns nil. If an error occurs it will be reported on the
|
// consumed it returns nil. If an error occurs it will be reported on the
|
||||||
// returned PgResultReader.
|
// returned PgResultReader.
|
||||||
func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
|
func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader {
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
|
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
for pgConn.pendingReadyForQueryCount > 0 {
|
for pgConn.pendingReadyForQueryCount > 0 {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
@@ -491,14 +496,14 @@ func (rr *PgResultReader) close() {
|
|||||||
func (pgConn *PgConn) Flush(ctx context.Context) error {
|
func (pgConn *PgConn) Flush(ctx context.Context) error {
|
||||||
defer pgConn.resetBatch()
|
defer pgConn.resetBatch()
|
||||||
|
|
||||||
cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn)
|
cleanup := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
n, err := pgConn.NetConn.Write(pgConn.batchBuf)
|
n, err := pgConn.conn.Write(pgConn.batchBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
// Close connection because cannot recover from partially sent message.
|
// Close connection because cannot recover from partially sent message.
|
||||||
pgConn.NetConn.Close()
|
pgConn.conn.Close()
|
||||||
pgConn.closed = true
|
pgConn.closed = true
|
||||||
}
|
}
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return preferContextOverNetTimeoutError(ctx, err)
|
||||||
@@ -563,14 +568,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
|
|||||||
pgConn.resetBatch()
|
pgConn.resetBatch()
|
||||||
|
|
||||||
// Clear any existing timeout
|
// Clear any existing timeout
|
||||||
pgConn.NetConn.SetDeadline(time.Time{})
|
pgConn.conn.SetDeadline(time.Time{})
|
||||||
|
|
||||||
// Try to cancel any in-progress requests
|
// Try to cancel any in-progress requests
|
||||||
for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ {
|
for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ {
|
||||||
pgConn.CancelRequest(ctx)
|
pgConn.CancelRequest(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
|
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContext()
|
defer cleanupContext()
|
||||||
|
|
||||||
for pgConn.pendingReadyForQueryCount > 0 {
|
for pgConn.pendingReadyForQueryCount > 0 {
|
||||||
@@ -683,7 +688,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
|||||||
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
|
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
|
||||||
// the connection config. This is important in high availability configurations where fallback connections may be
|
// the connection config. This is important in high availability configurations where fallback connections may be
|
||||||
// specified or DNS may be used to load balance.
|
// specified or DNS may be used to load balance.
|
||||||
serverAddr := pgConn.NetConn.RemoteAddr()
|
serverAddr := pgConn.conn.RemoteAddr()
|
||||||
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
|
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user