From f5faed65688c703f48a5712b2a41fc7db928fea9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:00:08 -0600 Subject: [PATCH] Access underlying net.Conn via method Also remove some dead code. --- pgconn.go | 57 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/pgconn.go b/pgconn.go index fef113e0..776141f9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -58,7 +58,7 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - NetConn net.Conn // the underlying TCP or unix domain socket connection + conn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -132,7 +132,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.NetConn, err = config.DialFunc(ctx, network, address) + pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } @@ -141,12 +141,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) if err != nil { return nil, err } @@ -166,8 +166,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { - pgConn.NetConn.Close() + if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.conn.Close() return nil, err } @@ -183,14 +183,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) } } @@ -198,7 +198,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, @@ -219,20 +219,20 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, errors.New("unexpected message") } } } func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + if _, err = io.ReadFull(pgConn.conn, response); err != nil { return } @@ -240,7 +240,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) return nil } @@ -262,7 +262,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.NetConn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(nil)) return err } @@ -299,6 +299,11 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -308,22 +313,22 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.closed = true - defer pgConn.NetConn.Close() + defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { return preferContextOverNetTimeoutError(ctx, err) } - _, err = pgConn.NetConn.Read(make([]byte, 1)) + _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { return preferContextOverNetTimeoutError(ctx, err) } - return pgConn.NetConn.Close() + return pgConn.conn.Close() } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -380,7 +385,7 @@ type PgResultReader struct { // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() @@ -491,14 +496,14 @@ func (rr *PgResultReader) close() { func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanup() - n, err := pgConn.NetConn.Write(pgConn.batchBuf) + n, err := pgConn.conn.Write(pgConn.batchBuf) if err != nil { if n > 0 { // Close connection because cannot recover from partially sent message. - pgConn.NetConn.Close() + pgConn.conn.Close() pgConn.closed = true } return preferContextOverNetTimeoutError(ctx, err) @@ -563,14 +568,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { pgConn.resetBatch() // Clear any existing timeout - pgConn.NetConn.SetDeadline(time.Time{}) + pgConn.conn.SetDeadline(time.Time{}) // Try to cancel any in-progress requests for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { pgConn.CancelRequest(ctx) } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() for pgConn.pendingReadyForQueryCount > 0 { @@ -683,7 +688,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. - serverAddr := pgConn.NetConn.RemoteAddr() + serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err