Add pgconn.CheckConn
This commit is contained in:
+10
-1
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/internal/nbconn"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -1236,7 +1237,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type queryRecorder struct {
|
type queryRecorder struct {
|
||||||
conn net.Conn
|
conn nbconn.Conn
|
||||||
writeBuf []byte
|
writeBuf []byte
|
||||||
readCount int
|
readCount int
|
||||||
}
|
}
|
||||||
@@ -1252,6 +1253,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
|
|||||||
return qr.conn.Write(b)
|
return qr.conn.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (qr *queryRecorder) BufferReadUntilBlock() error {
|
||||||
|
return qr.conn.BufferReadUntilBlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qr *queryRecorder) Flush() error {
|
||||||
|
return qr.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
func (qr *queryRecorder) Close() error {
|
func (qr *queryRecorder) Close() error {
|
||||||
return qr.conn.Close()
|
return qr.conn.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
+36
-13
@@ -13,6 +13,7 @@ package nbconn
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -46,11 +47,16 @@ func (*wouldBlockError) Error() string {
|
|||||||
func (*wouldBlockError) Timeout() bool { return true }
|
func (*wouldBlockError) Timeout() bool { return true }
|
||||||
func (*wouldBlockError) Temporary() bool { return true }
|
func (*wouldBlockError) Temporary() bool { return true }
|
||||||
|
|
||||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the
|
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
|
||||||
// underlying connection.
|
// the underlying connection.
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
|
// Flush flushes any buffered writes.
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
|
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
|
||||||
|
BufferReadUntilBlock() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||||
@@ -303,24 +309,35 @@ func (c *NetConn) flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *NetConn) BufferReadUntilBlock() error {
|
||||||
|
for {
|
||||||
|
buf := iobufpool.Get(8 * 1024)
|
||||||
|
n, err := c.nonblockingRead(buf)
|
||||||
|
if n > 0 {
|
||||||
|
buf = buf[:n]
|
||||||
|
c.readQueue.pushBack(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrWouldBlock) {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||||
stopChan = make(chan struct{})
|
stopChan = make(chan struct{})
|
||||||
errChan = make(chan error, 1)
|
errChan = make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
buf := iobufpool.Get(8 * 1024)
|
err := c.BufferReadUntilBlock()
|
||||||
n, err := c.nonblockingRead(buf)
|
|
||||||
if n > 0 {
|
|
||||||
buf = buf[:n]
|
|
||||||
c.readQueue.pushBack(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, ErrWouldBlock) {
|
errChan <- err
|
||||||
errChan <- err
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -456,6 +473,11 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// syscall read did not return an error and 0 bytes were read means EOF.
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,6 +516,7 @@ type TLSConn struct {
|
|||||||
|
|
||||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||||
|
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
|
||||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package nbconn_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -345,6 +346,34 @@ func TestNonBlockingRead(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBufferNonBlockingRead(t *testing.T) {
|
||||||
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
err := conn.BufferReadUntilBlock()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := remote.Write([]byte("okay"))
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
err = conn.BufferReadUntilBlock()
|
||||||
|
if !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 4, n)
|
||||||
|
require.Equal(t, []byte("okay"), buf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||||
|
|
||||||
|
|||||||
+21
-9
@@ -65,7 +65,7 @@ type NotificationHandler func(*PgConn, *Notification)
|
|||||||
|
|
||||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
type PgConn struct {
|
type PgConn struct {
|
||||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection
|
||||||
pid uint32 // backend pid
|
pid uint32 // backend pid
|
||||||
secretKey uint32 // key to use to send a cancel query message to the server
|
secretKey uint32 // key to use to send a cancel query message to the server
|
||||||
parameterStatuses map[string]string // parameters that have been reported by the server
|
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||||
@@ -230,22 +230,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
}
|
}
|
||||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||||
}
|
}
|
||||||
netConn = nbconn.NewNetConn(netConn, false)
|
nbNetConn := nbconn.NewNetConn(netConn, false)
|
||||||
|
|
||||||
pgConn.conn = netConn
|
pgConn.conn = nbNetConn
|
||||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
pgConn.contextWatcher = newContextWatcher(nbNetConn)
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
|
|
||||||
if fallbackConfig.TLSConfig != nil {
|
if fallbackConfig.TLSConfig != nil {
|
||||||
tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig)
|
nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig)
|
||||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
netConn.Close()
|
netConn.Close()
|
||||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = tlsConn
|
pgConn.conn = nbTLSConn
|
||||||
pgConn.contextWatcher = newContextWatcher(tlsConn)
|
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) {
|
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) {
|
||||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1596,6 +1596,18 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
|
|||||||
return strings.Replace(s, "'", "''", -1), nil
|
return strings.Replace(s, "'", "''", -1), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and
|
||||||
|
// buffering until the read would block or an error occurs. This can be used to check if the server has closed the
|
||||||
|
// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
|
||||||
|
// without the client knowing whether the server received it or not.
|
||||||
|
func (pgConn *PgConn) CheckConn() error {
|
||||||
|
err := pgConn.conn.BufferReadUntilBlock()
|
||||||
|
if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
|
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
|
||||||
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
||||||
ct := make([]byte, len(buf))
|
ct := make([]byte, len(buf))
|
||||||
@@ -1608,7 +1620,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
|||||||
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
||||||
// compatibility.
|
// compatibility.
|
||||||
type HijackedConn struct {
|
type HijackedConn struct {
|
||||||
Conn net.Conn // the underlying TCP or unix domain socket connection
|
Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection
|
||||||
PID uint32 // backend pid
|
PID uint32 // backend pid
|
||||||
SecretKey uint32 // key to use to send a cancel query message to the server
|
SecretKey uint32 // key to use to send a cancel query message to the server
|
||||||
ParameterStatuses map[string]string // parameters that have been reported by the server
|
ParameterStatuses map[string]string // parameters that have been reported by the server
|
||||||
|
|||||||
@@ -2059,6 +2059,34 @@ func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCheckConn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
c1, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer c1.Close(context.Background())
|
||||||
|
|
||||||
|
if c1.ParameterStatus("crdb_version") != "" {
|
||||||
|
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c1.CheckConn()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
c2, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer c2.Close(context.Background())
|
||||||
|
|
||||||
|
_, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Give a little time for the signal to actually kill the backend.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
err = c1.CheckConn()
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user