Add pgconn.CheckConn
This commit is contained in:
+21
-9
@@ -65,7 +65,7 @@ type NotificationHandler func(*PgConn, *Notification)
|
||||
|
||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
type PgConn struct {
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection
|
||||
pid uint32 // backend pid
|
||||
secretKey uint32 // key to use to send a cancel query message to the server
|
||||
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||
@@ -230,22 +230,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
}
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
netConn = nbconn.NewNetConn(netConn, false)
|
||||
nbNetConn := nbconn.NewNetConn(netConn, false)
|
||||
|
||||
pgConn.conn = netConn
|
||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||
pgConn.conn = nbNetConn
|
||||
pgConn.contextWatcher = newContextWatcher(nbNetConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
||||
if fallbackConfig.TLSConfig != nil {
|
||||
tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig)
|
||||
nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig)
|
||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||
if err != nil {
|
||||
netConn.Close()
|
||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||
}
|
||||
|
||||
pgConn.conn = tlsConn
|
||||
pgConn.contextWatcher = newContextWatcher(tlsConn)
|
||||
pgConn.conn = nbTLSConn
|
||||
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
}
|
||||
|
||||
@@ -353,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
||||
)
|
||||
}
|
||||
|
||||
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) {
|
||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1596,6 +1596,18 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
|
||||
return strings.Replace(s, "'", "''", -1), nil
|
||||
}
|
||||
|
||||
// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and
|
||||
// buffering until the read would block or an error occurs. This can be used to check if the server has closed the
|
||||
// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
|
||||
// without the client knowing whether the server received it or not.
|
||||
func (pgConn *PgConn) CheckConn() error {
|
||||
err := pgConn.conn.BufferReadUntilBlock()
|
||||
if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
|
||||
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
||||
ct := make([]byte, len(buf))
|
||||
@@ -1608,7 +1620,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
||||
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
||||
// compatibility.
|
||||
type HijackedConn struct {
|
||||
Conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection
|
||||
PID uint32 // backend pid
|
||||
SecretKey uint32 // key to use to send a cancel query message to the server
|
||||
ParameterStatuses map[string]string // parameters that have been reported by the server
|
||||
|
||||
@@ -2059,6 +2059,34 @@ func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCheckConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer c1.Close(context.Background())
|
||||
|
||||
if c1.ParameterStatus("crdb_version") != "" {
|
||||
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||
}
|
||||
|
||||
err = c1.CheckConn()
|
||||
require.NoError(t, err)
|
||||
|
||||
c2, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer c2.Close(context.Background())
|
||||
|
||||
_, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give a little time for the signal to actually kill the backend.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
err = c1.CheckConn()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func Example() {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user