Add pgconn.CheckConn
This commit is contained in:
+36
-13
@@ -13,6 +13,7 @@ package nbconn
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -46,11 +47,16 @@ func (*wouldBlockError) Error() string {
|
||||
func (*wouldBlockError) Timeout() bool { return true }
|
||||
func (*wouldBlockError) Temporary() bool { return true }
|
||||
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the
|
||||
// underlying connection.
|
||||
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
|
||||
// the underlying connection.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
|
||||
// Flush flushes any buffered writes.
|
||||
Flush() error
|
||||
|
||||
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
|
||||
BufferReadUntilBlock() error
|
||||
}
|
||||
|
||||
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
|
||||
@@ -303,24 +309,35 @@ func (c *NetConn) flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *NetConn) BufferReadUntilBlock() error {
|
||||
for {
|
||||
buf := iobufpool.Get(8 * 1024)
|
||||
n, err := c.nonblockingRead(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
c.readQueue.pushBack(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrWouldBlock) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
|
||||
stopChan = make(chan struct{})
|
||||
errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
buf := iobufpool.Get(8 * 1024)
|
||||
n, err := c.nonblockingRead(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
c.readQueue.pushBack(buf)
|
||||
}
|
||||
|
||||
err := c.BufferReadUntilBlock()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrWouldBlock) {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -456,6 +473,11 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// syscall read did not return an error and 0 bytes were read means EOF.
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -494,6 +516,7 @@ type TLSConn struct {
|
||||
|
||||
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
|
||||
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
|
||||
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
|
||||
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
|
||||
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
|
||||
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
|
||||
|
||||
@@ -2,6 +2,7 @@ package nbconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -345,6 +346,34 @@ func TestNonBlockingRead(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBufferNonBlockingRead(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
err := conn.BufferReadUntilBlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := remote.Write([]byte("okay"))
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
err = conn.BufferReadUntilBlock()
|
||||
if !errors.Is(err, nbconn.ErrWouldBlock) {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, err := conn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, n)
|
||||
require.Equal(t, []byte("okay"), buf)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadPreviouslyBuffered(t *testing.T) {
|
||||
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user