diff --git a/bench_test.go b/bench_test.go index dd391c71..47404114 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" @@ -1120,7 +1119,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } type queryRecorder struct { - conn nbconn.Conn + conn net.Conn writeBuf []byte readCount int } @@ -1136,14 +1135,6 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) { return qr.conn.Write(b) } -func (qr *queryRecorder) BufferReadUntilBlock() error { - return qr.conn.BufferReadUntilBlock() -} - -func (qr *queryRecorder) Flush() error { - return qr.conn.Flush() -} - func (qr *queryRecorder) Close() error { return qr.conn.Close() } diff --git a/internal/nbconn/bufferqueue.go b/internal/nbconn/bufferqueue.go deleted file mode 100644 index 4bf25481..00000000 --- a/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( - "sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { - lock sync.Mutex - queue []*[]byte - r, w int -} - -func (bq *bufferQueue) pushBack(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - bq.queue[bq.w] = buf - bq.w++ -} - -func (bq *bufferQueue) pushFront(buf *[]byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) - bq.queue[bq.r] = buf - bq.w++ -} - -func (bq *bufferQueue) popFront() *[]byte { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.r == bq.w { - return nil - } - - buf := bq.queue[bq.r] - bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. - bq.r++ - - if bq.r == bq.w { - bq.r = 0 - bq.w = 0 - if len(bq.queue) > minBufferQueueLen { - bq.queue = make([]*[]byte, minBufferQueueLen) - } - } - - return buf -} - -func (bq *bufferQueue) growQueue() { - desiredLen := (len(bq.queue) + 1) * 3 / 2 - if desiredLen < minBufferQueueLen { - desiredLen = minBufferQueueLen - } - - newQueue := make([]*[]byte, desiredLen) - copy(newQueue, bq.queue) - bq.queue = newQueue -} diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go deleted file mode 100644 index a9fcd62e..00000000 --- a/internal/nbconn/nbconn.go +++ /dev/null @@ -1,550 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( - "crypto/tls" - "errors" - "net" - "os" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWriteWaitDuration = 100 * time.Millisecond -const minNonblockingReadWaitDuration = time.Microsecond -const maxNonblockingReadWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { - return "would block" -} - -func (*wouldBlockError) Timeout() bool { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { - net.Conn - - // Flush flushes any buffered writes. - Flush() error - - // BufferReadUntilBlock reads and buffers any successfully read bytes until the read would block. - BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { - // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit - // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and - // https://github.com/jackc/pgx/issues/1307. Only access with atomics - closed int64 // 0 = not closed, 1 = closed - - conn net.Conn - rawConn syscall.RawConn - - readQueue bufferQueue - writeQueue bufferQueue - - readFlushLock sync.Mutex - // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockWriteFunc func(fd uintptr) (done bool) - nonblockWriteBuf []byte - nonblockWriteErr error - nonblockWriteN int - - // non-blocking reads with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockReadFunc func(fd uintptr) (done bool) - nonblockReadBuf []byte - nonblockReadErr error - nonblockReadN int - - readDeadlineLock sync.Mutex - readDeadline time.Time - readNonblocking bool - fakeNonBlockingShortReadCount int - fakeNonblockingReadWaitDuration time.Duration - - writeDeadlineLock sync.Mutex - writeDeadline time.Time - // The following fields are used in nbconn_real_non_block_windows - - // nbOperMu Used to prevent concurrent SetBlockingMode calls - nbOperMu sync.Mutex - // nbOperCnt Tracks how many operations performing simultaneously - nbOperCnt int -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { - nc := &NetConn{ - conn: conn, - fakeNonblockingReadWaitDuration: maxNonblockingReadWaitDuration, - } - - if !fakeNonBlockingIO { - if sc, ok := conn.(syscall.Conn); ok { - if rawConn, err := sc.SyscallConn(); err == nil { - nc.rawConn = rawConn - } - } - } - - return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - - err = c.flush() - if err != nil { - return 0, err - } - - for n < len(b) { - buf := c.readQueue.popFront() - if buf == nil { - break - } - copiedN := copy(b[n:], *buf) - if copiedN < len(*buf) { - *buf = (*buf)[copiedN:] - c.readQueue.pushFront(buf) - } else { - iobufpool.Put(buf) - } - n += copiedN - } - - // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to - // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. - if n > 0 { - return n, nil - } - - var readNonblocking bool - c.readDeadlineLock.Lock() - readNonblocking = c.readNonblocking - c.readDeadlineLock.Unlock() - - var readN int - if readNonblocking { - if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil { - return n, setSockModeErr - } - - defer func() { - _ = c.SetBlockingMode(true) - }() - - readN, err = c.nonblockingRead(b[n:]) - } else { - readN, err = c.conn.Read(b[n:]) - } - n += readN - return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - buf := iobufpool.Get(len(b)) - copy(*buf, b) - c.writeQueue.pushBack(buf) - return len(b), nil -} - -func (c *NetConn) Close() (err error) { - swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) - if !swapped { - return errClosed - } - - defer func() { - closeErr := c.conn.Close() - if err == nil { - err = closeErr - } - }() - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - err = c.flush() - if err != nil { - return err - } - - return nil -} - -func (c *NetConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { - err := c.SetReadDeadline(t) - if err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *NetConn) SetReadDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - if c.readDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.readDeadline = t - return nil - } - - if t == NonBlockingDeadline { - c.readNonblocking = true - t = time.Time{} - } else { - c.readNonblocking = false - } - - c.readDeadline = t - - return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - if c.writeDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.writeDeadline = t - return nil - } - - c.writeDeadline = t - - return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { - if c.isClosed() { - return errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { - var stopChan chan struct{} - var errChan chan error - - if err := c.SetBlockingMode(false); err != nil { - return err - } - - defer func() { - _ = c.SetBlockingMode(true) - }() - - defer func() { - if stopChan != nil { - select { - case stopChan <- struct{}{}: - case <-errChan: - } - } - }() - - for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { - remainingBuf := *buf - for len(remainingBuf) > 0 { - n, err := c.nonblockingWrite(remainingBuf) - remainingBuf = remainingBuf[n:] - if err != nil { - if !errors.Is(err, ErrWouldBlock) { - *buf = (*buf)[:len(remainingBuf)] - copy(*buf, remainingBuf) - c.writeQueue.pushFront(buf) - return err - } - - // Writing was blocked. Reading might unblock it. - if stopChan == nil { - stopChan, errChan = c.bufferNonblockingRead() - } - - select { - case err := <-errChan: - stopChan = nil - return err - default: - } - - } - } - iobufpool.Put(buf) - } - - return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { - if err := c.SetBlockingMode(false); err != nil { - return err - } - - defer func() { - _ = c.SetBlockingMode(true) - }() - - for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(*buf) - if n > 0 { - *buf = (*buf)[:n] - c.readQueue.pushBack(buf) - } else if n == 0 { - iobufpool.Put(buf) - } - - if err != nil { - if errors.Is(err, ErrWouldBlock) { - return nil - } else { - return err - } - } - } -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { - stopChan = make(chan struct{}) - errChan = make(chan error, 1) - - go func() { - for { - err := c.BufferReadUntilBlock() - if err != nil { - errChan <- err - return - } - - select { - case <-stopChan: - return - default: - } - } - }() - - return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { - closed := atomic.LoadInt64(&c.closed) - return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingWrite(b) - } else { - return c.realNonblockingWrite(b) - } -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWriteWaitDuration) - if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.conn.SetWriteDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetWriteDeadline(c.writeDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingRead(b) - } else { - return c.realNonblockingRead(b) - } -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - // The first 5 reads only read 1 byte at a time. This should give us 4 chances to read when we are sure the bytes are - // already in Go or the OS's receive buffer. - if c.fakeNonBlockingShortReadCount < 5 && len(b) > 0 && c.fakeNonblockingReadWaitDuration < minNonblockingReadWaitDuration { - b = b[:1] - } - - startTime := time.Now() - deadline := startTime.Add(c.fakeNonblockingReadWaitDuration) - if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.conn.SetReadDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // If the read was successful and the wait duration is not already the minimum - if err == nil && c.fakeNonblockingReadWaitDuration > minNonblockingReadWaitDuration { - endTime := time.Now() - - if n > 0 && c.fakeNonBlockingShortReadCount < 5 { - c.fakeNonBlockingShortReadCount++ - } - - // The wait duration should be 2x the fastest read that has occurred. This should give reasonable assurance that - // a Read deadline will not block a read before it has a chance to read data already in Go or the OS's receive - // buffer. - proposedWait := endTime.Sub(startTime) * 2 - if proposedWait < minNonblockingReadWaitDuration { - proposedWait = minNonblockingReadWaitDuration - } - if proposedWait < c.fakeNonblockingReadWaitDuration { - c.fakeNonblockingReadWaitDuration = proposedWait - } - } - - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetReadDeadline(c.readDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { - tc := tls.Client(conn, config) - err := tc.Handshake() - if err != nil { - return nil, err - } - - // Ensure last written part of Handshake is actually sent. - err = conn.Flush() - if err != nil { - return nil, err - } - - return &TLSConn{ - tlsConn: tc, - nbConn: conn, - }, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. -type TLSConn struct { - tlsConn *tls.Conn - nbConn *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { - // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then - // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our - // own 5 second deadline then make all set deadlines no-op. - tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) - tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - - return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/internal/nbconn/nbconn_fake_non_block.go b/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index 71c7388d..00000000 --- a/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !unix && !windows - -package nbconn - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) -} diff --git a/internal/nbconn/nbconn_real_non_block.go b/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index 863b86ad..00000000 --- a/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,86 +0,0 @@ -//go:build unix - -package nbconn - -import ( - "errors" - "io" - "syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - if c.nonblockWriteFunc == nil { - c.nonblockWriteFunc = func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - } - } - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - - err = c.rawConn.Write(c.nonblockWriteFunc) - n = c.nonblockWriteN - c.nonblockWriteBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - if c.nonblockReadFunc == nil { - c.nonblockReadFunc = func(fd uintptr) (done bool) { - c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) - return true - } - } - c.nonblockReadBuf = b - c.nonblockReadN = 0 - c.nonblockReadErr = nil - - err = c.rawConn.Read(c.nonblockReadFunc) - n = c.nonblockReadN - c.nonblockReadBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockReadErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} - -func (c *NetConn) SetBlockingMode(blocking bool) error { - // Do nothing on UNIX systems - return nil -} diff --git a/internal/nbconn/nbconn_real_non_block_windows.go b/internal/nbconn/nbconn_real_non_block_windows.go deleted file mode 100644 index fdf628f4..00000000 --- a/internal/nbconn/nbconn_real_non_block_windows.go +++ /dev/null @@ -1,227 +0,0 @@ -//go:build windows - -package nbconn - -import ( - "errors" - "fmt" - "golang.org/x/sys/windows" - "io" - "syscall" - "time" - "unsafe" -) - -var dll = syscall.MustLoadDLL("ws2_32.dll") - -// int ioctlsocket( -// -// [in] SOCKET s, -// [in] long cmd, -// [in, out] u_long *argp -// -// ); -var ioctlsocket = dll.MustFindProc("ioctlsocket") - -var deadlineExpErr = errors.New("i/o timeout") - -type sockMode int - -const ( - FIONBIO uint32 = 0x8004667e - sockModeBlocking sockMode = 0 - sockModeNonBlocking sockMode = 1 -) - -func setSockMode(fd uintptr, mode sockMode) error { - res, _, err := ioctlsocket.Call(fd, uintptr(FIONBIO), uintptr(unsafe.Pointer(&mode))) - // Upon successful completion, the ioctlsocket returns zero. - if res != 0 && err != nil { - return err - } - - return nil -} - -func (c *NetConn) isDeadlineSet(dl time.Time) bool { - return !dl.IsZero() && !dl.Equal(NonBlockingDeadline) && !dl.Equal(disableSetDeadlineDeadline) -} - -func (c *NetConn) isWriteDeadlineExpired() bool { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - return c.isDeadlineSet(c.writeDeadline) && !time.Now().Before(c.writeDeadline) -} - -func (c *NetConn) isReadDeadlineExpired() bool { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - return c.isDeadlineSet(c.readDeadline) && !time.Now().Before(c.readDeadline) -} - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - if c.nonblockWriteFunc == nil { - c.nonblockWriteFunc = func(fd uintptr) (done bool) { - var written uint32 - var buf syscall.WSABuf - buf.Buf = &c.nonblockWriteBuf[0] - buf.Len = uint32(len(c.nonblockWriteBuf)) - c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil) - c.nonblockWriteN = int(written) - - return true - } - } - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - - if c.isWriteDeadlineExpired() { - c.nonblockWriteErr = deadlineExpErr - - return 0, c.nonblockWriteErr - } - - err = c.rawConn.Write(c.nonblockWriteFunc) - n = c.nonblockWriteN - c.nonblockWriteBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, windows.WSAEWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - if c.nonblockReadFunc == nil { - c.nonblockReadFunc = func(fd uintptr) (done bool) { - var read uint32 - var flags uint32 - var buf syscall.WSABuf - buf.Buf = &c.nonblockReadBuf[0] - buf.Len = uint32(len(c.nonblockReadBuf)) - c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil) - c.nonblockReadN = int(read) - - return true - } - } - c.nonblockReadBuf = b - c.nonblockReadN = 0 - c.nonblockReadErr = nil - - if c.isReadDeadlineExpired() { - c.nonblockReadErr = deadlineExpErr - - return 0, c.nonblockReadErr - } - - err = c.rawConn.Read(c.nonblockReadFunc) - n = c.nonblockReadN - c.nonblockReadBuf = nil // ensure that no reference to b is kept. - if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, windows.WSAEWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockReadErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} - -func (c *NetConn) SetBlockingMode(blocking bool) error { - // Fake non-blocking I/O is ignored - if c.rawConn == nil { - return nil - } - - // Prevent concurrent SetBlockingMode calls - c.nbOperMu.Lock() - defer c.nbOperMu.Unlock() - - // Guard against negative value (which should never happen in practice) - if c.nbOperCnt < 0 { - c.nbOperCnt = 0 - } - - if blocking { - // Socket is already in blocking mode - if c.nbOperCnt == 0 { - return nil - } - - c.nbOperCnt-- - - // Not ready to exit from non-blocking mode, there is pending non-blocking operations - if c.nbOperCnt > 0 { - return nil - } - } else { - c.nbOperCnt++ - - // Socket is already in non-blocking mode - if c.nbOperCnt > 1 { - return nil - } - } - - mode := sockModeNonBlocking - if blocking { - mode = sockModeBlocking - } - - var ctrlErr, err error - - ctrlErr = c.rawConn.Control(func(fd uintptr) { - err = setSockMode(fd, mode) - }) - - if ctrlErr != nil || err != nil { - retErr := ctrlErr - if retErr == nil { - retErr = err - } - - // Revert counters inc/dec in case of error - if blocking { - c.nbOperCnt++ - - return fmt.Errorf("cannot set socket to blocking mode: %w", retErr) - } else { - c.nbOperCnt-- - - return fmt.Errorf("cannot set socket to non-blocking mode: %w", retErr) - } - } - - return nil -} diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go deleted file mode 100644 index 4fb2282b..00000000 --- a/internal/nbconn/nbconn_test.go +++ /dev/null @@ -1,599 +0,0 @@ -package nbconn_test - -import ( - "crypto/tls" - "io" - "net" - "runtime" - "strings" - "testing" - "time" - - "github.com/jackc/pgx/v5/internal/nbconn" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Test keys generated with: -// -// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost' - -var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE----- -MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls -b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ -BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB -ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5 -yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT -caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT -0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW -c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v -7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg -Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw -HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g -TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk -D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB -hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y -E7ZYmaKTMOhvkg== ------END CERTIFICATE-----`) - -// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in -// source code. -var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY----- -MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny -k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+ -fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px -N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav -IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM -4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX -IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8 -TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL -CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ -/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn -lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I -Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9 -YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp -RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq -MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd -3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE -Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0 -TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA -riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr -IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu -nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk -WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc -Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77 -DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD -pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG -2qWm8jTPeDC3sq+67s2oojHf+Q== ------END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY")) - -func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) { - for _, tt := range []struct { - name string - makeConns func(t *testing.T) (local, remote net.Conn) - useTLS bool - fakeNonBlockingIO bool - }{ - { - name: "Pipe", - makeConns: makePipeConns, - useTLS: false, - fakeNonBlockingIO: true, - }, - { - name: "TCP with Fake Non-blocking IO", - makeConns: makeTCPConns, - useTLS: false, - fakeNonBlockingIO: true, - }, - { - name: "TLS over TCP with Fake Non-blocking IO", - makeConns: makeTCPConns, - useTLS: true, - fakeNonBlockingIO: true, - }, - { - name: "TCP with Real Non-blocking IO", - makeConns: makeTCPConns, - useTLS: false, - fakeNonBlockingIO: false, - }, - { - name: "TLS over TCP with Real Non-blocking IO", - makeConns: makeTCPConns, - useTLS: true, - fakeNonBlockingIO: false, - }, - } { - t.Run(tt.name, func(t *testing.T) { - local, remote := tt.makeConns(t) - - // Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get - // garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never - // uses remote it may be garbage collected leading to the connection being closed. - defer local.Close() - defer remote.Close() - - var conn nbconn.Conn - netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO) - - if tt.useTLS { - cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey) - require.NoError(t, err) - - tlsServer := tls.Server(remote, &tls.Config{ - Certificates: []tls.Certificate{cert}, - }) - serverTLSHandshakeChan := make(chan error) - go func() { - err := tlsServer.Handshake() - serverTLSHandshakeChan <- err - }() - - tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true}) - require.NoError(t, err) - conn = tlsConn - - err = <-serverTLSHandshakeChan - require.NoError(t, err) - remote = tlsServer - } else { - conn = netConn - } - - f(t, conn, remote) - }) - } -} - -// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is -// useful for testing an exact sequence of reads and writes with the underlying connection blocking. -func makePipeConns(t *testing.T) (local, remote net.Conn) { - local, remote = net.Pipe() - t.Cleanup(func() { - local.Close() - remote.Close() - }) - - return local, remote -} - -// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost. -func makeTCPConns(t *testing.T) (local, remote net.Conn) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer ln.Close() - - type acceptResultT struct { - conn net.Conn - err error - } - acceptChan := make(chan acceptResultT) - - go func() { - conn, err := ln.Accept() - acceptChan <- acceptResultT{conn: conn, err: err} - }() - - local, err = net.Dial("tcp", ln.Addr().String()) - require.NoError(t, err) - - acceptResult := <-acceptChan - require.NoError(t, acceptResult.err) - - remote = acceptResult.conn - - return local, remote -} - -func TestWriteIsBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - // net.Pipe is synchronous so the Write would block if not buffered. - writeBuf := []byte("test") - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, 4, n) - - errChan := make(chan error, 1) - go func() { - err := conn.Flush() - errChan <- err - }() - - readBuf := make([]byte, len(writeBuf)) - _, err = remote.Read(readBuf) - require.NoError(t, err) - - require.NoError(t, <-errChan) - }) -} - -func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - err := conn.SetWriteDeadline(time.Now()) - require.NoError(t, err) - - writeBuf := []byte("test") - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, 4, n) - }) -} - -func TestReadFlushesWriteBuffer(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - writeBuf := []byte("test") - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, 4, n) - - errChan := make(chan error, 2) - go func() { - readBuf := make([]byte, len(writeBuf)) - _, err := remote.Read(readBuf) - errChan <- err - - _, err = remote.Write([]byte("okay")) - errChan <- err - }() - - readBuf := make([]byte, 4) - _, err = conn.Read(readBuf) - require.NoError(t, err) - require.Equal(t, []byte("okay"), readBuf) - - require.NoError(t, <-errChan) - require.NoError(t, <-errChan) - }) -} - -func TestCloseFlushesWriteBuffer(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - writeBuf := []byte("test") - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, 4, n) - - errChan := make(chan error, 1) - go func() { - readBuf := make([]byte, len(writeBuf)) - _, err := remote.Read(readBuf) - errChan <- err - }() - - err = conn.Close() - require.NoError(t, err) - - require.NoError(t, <-errChan) - }) -} - -// This test exercises the non-blocking write path. Because writes are buffered it is difficult trigger this with -// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by both sides writing -// large values. -func TestInternalNonBlockingWrite(t *testing.T) { - const deadlockSize = 4 * 1024 * 1024 - - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - writeBuf := make([]byte, deadlockSize) - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, deadlockSize, n) - - errChan := make(chan error, 1) - go func() { - remoteWriteBuf := make([]byte, deadlockSize) - _, err := remote.Write(remoteWriteBuf) - if err != nil { - errChan <- err - return - } - - readBuf := make([]byte, deadlockSize) - _, err = io.ReadFull(remote, readBuf) - errChan <- err - }() - - readBuf := make([]byte, deadlockSize) - _, err = io.ReadFull(conn, readBuf) - require.NoError(t, err) - - err = conn.Close() - require.NoError(t, err) - if runtime.GOOS == "windows" && t.Name() == "TestInternalNonBlockingWrite/TLS_over_TCP_with_Fake_Non-blocking_IO" { - // this test is expected to fail on Windows see https://github.com/golang/go/issues/58764 - require.Error(t, <-errChan) - } else { - require.NoError(t, <-errChan) - } - }) -} - -func TestInternalNonBlockingWriteWithDeadline(t *testing.T) { - const deadlockSize = 4 * 1024 * 1024 - - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - writeBuf := make([]byte, deadlockSize) - n, err := conn.Write(writeBuf) - require.NoError(t, err) - require.EqualValues(t, deadlockSize, n) - - err = conn.SetDeadline(time.Now()) - require.NoError(t, err) - - err = conn.Flush() - require.Error(t, err) - require.Contains(t, err.Error(), "i/o timeout") - }) -} - -func TestNonBlockingRead(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - err := conn.SetReadDeadline(nbconn.NonBlockingDeadline) - require.NoError(t, err) - - buf := make([]byte, 4) - n, err := conn.Read(buf) - require.ErrorIs(t, err, nbconn.ErrWouldBlock) - require.EqualValues(t, 0, n) - - errChan := make(chan error, 1) - go func() { - _, err := remote.Write([]byte("okay")) - errChan <- err - }() - - err = conn.SetReadDeadline(time.Time{}) - require.NoError(t, err) - - n, err = conn.Read(buf) - require.NoError(t, err) - require.EqualValues(t, 4, n) - }) -} - -func TestBufferNonBlockingRead(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - err := conn.BufferReadUntilBlock() - require.NoError(t, err) - - errChan := make(chan error, 1) - go func() { - err := remote.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err != nil { - errChan <- err - return - } - - _, err = remote.Write([]byte("okay")) - errChan <- err - }() - - readLoop: - for i := 0; i < 1000; i++ { - err := conn.BufferReadUntilBlock() - require.NoError(t, err) - select { - case err := <-errChan: - require.NoError(t, err) - break readLoop - default: - time.Sleep(time.Millisecond) - } - } - - buf := make([]byte, 4) - n, err := conn.Read(buf) - require.NoError(t, err) - assert.EqualValues(t, 4, n) - assert.Equal(t, []byte("okay"), buf) - }) -} - -func TestReadPreviouslyBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - - errChan := make(chan error, 1) - go func() { - err := func() error { - _, err := remote.Write([]byte("alpha")) - if err != nil { - return err - } - - readBuf := make([]byte, 4) - _, err = remote.Read(readBuf) - if err != nil { - return err - } - - return nil - }() - errChan <- err - }() - - _, err := conn.Write([]byte("test")) - require.NoError(t, err) - - // Because net.Pipe() is synchronous conn.Flush must buffer a read. - err = conn.Flush() - require.NoError(t, err) - - readBuf := make([]byte, 5) - n, err := conn.Read(readBuf) - require.NoError(t, err) - require.EqualValues(t, 5, n) - require.Equal(t, []byte("alpha"), readBuf) - }) -} - -func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - errChan := make(chan error, 1) - go func() { - err := func() error { - _, err := remote.Write([]byte("alpha")) - if err != nil { - return err - } - - readBuf := make([]byte, 4) - _, err = remote.Read(readBuf) - if err != nil { - return err - } - - return nil - }() - errChan <- err - }() - - _, err := conn.Write([]byte("test")) - require.NoError(t, err) - - // Because net.Pipe() is synchronous conn.Flush must buffer a read. - err = conn.Flush() - require.NoError(t, err) - - readBuf := make([]byte, 10) - n, err := conn.Read(readBuf) - require.NoError(t, err) - require.EqualValues(t, 5, n) - require.Equal(t, []byte("alpha"), readBuf[:n]) - }) -} - -func TestReadPreviouslyBufferedPartialRead(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - - errChan := make(chan error, 1) - go func() { - err := func() error { - _, err := remote.Write([]byte("alpha")) - if err != nil { - return err - } - - readBuf := make([]byte, 4) - _, err = remote.Read(readBuf) - if err != nil { - return err - } - - return nil - }() - errChan <- err - }() - - _, err := conn.Write([]byte("test")) - require.NoError(t, err) - - // Because net.Pipe() is synchronous conn.Flush must buffer a read. - err = conn.Flush() - require.NoError(t, err) - - readBuf := make([]byte, 2) - n, err := conn.Read(readBuf) - require.NoError(t, err) - require.EqualValues(t, 2, n) - require.Equal(t, []byte("al"), readBuf) - - readBuf = make([]byte, 3) - n, err = conn.Read(readBuf) - require.NoError(t, err) - require.EqualValues(t, 3, n) - require.Equal(t, []byte("pha"), readBuf) - }) -} - -func TestReadMultiplePreviouslyBuffered(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - errChan := make(chan error, 1) - go func() { - err := func() error { - _, err := remote.Write([]byte("alpha")) - if err != nil { - return err - } - - _, err = remote.Write([]byte("beta")) - if err != nil { - return err - } - - readBuf := make([]byte, 4) - _, err = remote.Read(readBuf) - if err != nil { - return err - } - - return nil - }() - errChan <- err - }() - - _, err := conn.Write([]byte("test")) - require.NoError(t, err) - - // Because net.Pipe() is synchronous conn.Flush must buffer a read. - err = conn.Flush() - require.NoError(t, err) - - readBuf := make([]byte, 9) - n, err := io.ReadFull(conn, readBuf) - require.NoError(t, err) - require.EqualValues(t, 9, n) - require.Equal(t, []byte("alphabeta"), readBuf) - }) -} - -func TestReadPreviouslyBufferedAndReadMore(t *testing.T) { - testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { - - flushCompleteChan := make(chan struct{}) - errChan := make(chan error, 1) - go func() { - err := func() error { - _, err := remote.Write([]byte("alpha")) - if err != nil { - return err - } - - readBuf := make([]byte, 4) - _, err = remote.Read(readBuf) - if err != nil { - return err - } - - <-flushCompleteChan - - _, err = remote.Write([]byte("beta")) - if err != nil { - return err - } - - return nil - }() - errChan <- err - }() - - _, err := conn.Write([]byte("test")) - require.NoError(t, err) - - // Because net.Pipe() is synchronous conn.Flush must buffer a read. - err = conn.Flush() - require.NoError(t, err) - - close(flushCompleteChan) - - readBuf := make([]byte, 9) - - n, err := io.ReadFull(conn, readBuf) - require.NoError(t, err) - require.EqualValues(t, 9, n) - require.Equal(t, []byte("alphabeta"), readBuf) - - err = <-errChan - require.NoError(t, err) - }) -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 6e8140a6..47871124 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -16,7 +16,6 @@ import ( "time" "github.com/jackc/pgx/v5/internal/iobufpool" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" @@ -65,7 +64,7 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + conn net.Conn pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -266,14 +265,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = nbNetConn - pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -392,7 +390,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -407,12 +405,7 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err return nil, errors.New("server refused TLS connection") } - tlsConn, err := nbconn.TLSClient(conn, tlsConfig) - if err != nil { - return nil, err - } - - return tlsConn, nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { @@ -468,10 +461,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.frontend.Receive() if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - return nil, err - } - // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -1174,11 +1163,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return CommandTag{}, err } - err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } + // err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) + // if err != nil { + // pgConn.asyncClose() + // return CommandTag{}, err + // } nonblocking := true defer func() { if nonblocking { @@ -1217,9 +1206,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pgErr == nil { msg, err := pgConn.receiveMessage() if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - break - } + // if errors.Is(err, nbconn.ErrWouldBlock) { + // break + // } pgConn.asyncClose() return CommandTag{}, normalizeTimeoutError(ctx, err) } @@ -1638,15 +1627,19 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { // connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // without the client knowing whether the server received it or not. func (pgConn *PgConn) CheckConn() error { - if err := pgConn.lock(); err != nil { - return err - } - defer pgConn.unlock() + // if err := pgConn.lock(); err != nil { + // return err + // } + // defer pgConn.unlock() - err := pgConn.conn.BufferReadUntilBlock() - if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { - return err - } + rr := pgConn.ExecParams(context.Background(), "select 1", nil, nil, nil, nil) + _, err := rr.Close() + return err + + // err := pgConn.conn.BufferReadUntilBlock() + // if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { + // return err + // } return nil } @@ -1660,7 +1653,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + Conn net.Conn PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server