Set socket to non-blocking mode in Read, Flush and BufferReadUntilBlock operations
This commit is contained in:
committed by
Jack Christensen
parent
3db7d1774e
commit
b2b4fbcf57
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jackc/pgx/v5/internal/nbconn"
|
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
@@ -132,17 +131,6 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||||||
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
|
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
if realNbConn, ok := ct.conn.pgConn.Conn().(*nbconn.NetConn); ok {
|
|
||||||
if err := realNbConn.SetBlockingMode(false); err != nil {
|
|
||||||
return 0, fmt.Errorf("cannot set socket non-blocking mode: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
// TODO: Deal with it
|
|
||||||
_ = realNbConn.SetBlockingMode(true)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
r, w := io.Pipe()
|
r, w := io.Pipe()
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ package nbconn
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -97,8 +98,8 @@ type NetConn struct {
|
|||||||
writeDeadlineLock sync.Mutex
|
writeDeadlineLock sync.Mutex
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
|
||||||
// Indicates that underlying socket connection mode explicitly set to be non-blocking
|
// nbOperCnt Tracks how many operations performing simultaneously
|
||||||
isNonBlocking bool
|
nbOperCnt int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
|
||||||
@@ -160,6 +161,18 @@ func (c *NetConn) Read(b []byte) (n int, err error) {
|
|||||||
|
|
||||||
var readN int
|
var readN int
|
||||||
if readNonblocking {
|
if readNonblocking {
|
||||||
|
if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil {
|
||||||
|
err = fmt.Errorf("cannot set socket to non-blocking mode: %w", setSockModeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = c.SetBlockingMode(true)
|
||||||
|
}()
|
||||||
|
|
||||||
readN, err = c.nonblockingRead(b[n:])
|
readN, err = c.nonblockingRead(b[n:])
|
||||||
} else {
|
} else {
|
||||||
readN, err = c.conn.Read(b[n:])
|
readN, err = c.conn.Read(b[n:])
|
||||||
@@ -284,6 +297,14 @@ func (c *NetConn) flush() error {
|
|||||||
var stopChan chan struct{}
|
var stopChan chan struct{}
|
||||||
var errChan chan error
|
var errChan chan error
|
||||||
|
|
||||||
|
if err := c.SetBlockingMode(false); err != nil {
|
||||||
|
return fmt.Errorf("cannot set socket to non-blocking mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = c.SetBlockingMode(true)
|
||||||
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if stopChan != nil {
|
if stopChan != nil {
|
||||||
select {
|
select {
|
||||||
@@ -327,6 +348,14 @@ func (c *NetConn) flush() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *NetConn) BufferReadUntilBlock() error {
|
func (c *NetConn) BufferReadUntilBlock() error {
|
||||||
|
if err := c.SetBlockingMode(false); err != nil {
|
||||||
|
return fmt.Errorf("cannot set socket to non-blocking mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = c.SetBlockingMode(true)
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := iobufpool.Get(8 * 1024)
|
buf := iobufpool.Get(8 * 1024)
|
||||||
n, err := c.nonblockingRead(*buf)
|
n, err := c.nonblockingRead(*buf)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"io"
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@@ -43,14 +44,6 @@ func setSockMode(fd uintptr, mode sockMode) error {
|
|||||||
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
||||||
if c.nonblockWriteFunc == nil {
|
if c.nonblockWriteFunc == nil {
|
||||||
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
|
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
|
||||||
if !c.isNonBlocking {
|
|
||||||
// Make sock non-blocking
|
|
||||||
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
|
|
||||||
c.nonblockWriteErr = err
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var written uint32
|
var written uint32
|
||||||
var buf syscall.WSABuf
|
var buf syscall.WSABuf
|
||||||
buf.Buf = &c.nonblockWriteBuf[0]
|
buf.Buf = &c.nonblockWriteBuf[0]
|
||||||
@@ -58,14 +51,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
|||||||
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
|
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
|
||||||
c.nonblockWriteN = int(written)
|
c.nonblockWriteN = int(written)
|
||||||
|
|
||||||
if !c.isNonBlocking {
|
|
||||||
// Make sock blocking again
|
|
||||||
if err := setSockMode(fd, sockModeBlocking); err != nil {
|
|
||||||
c.nonblockWriteErr = err
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,14 +83,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
|
|||||||
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
||||||
if c.nonblockReadFunc == nil {
|
if c.nonblockReadFunc == nil {
|
||||||
c.nonblockReadFunc = func(fd uintptr) (done bool) {
|
c.nonblockReadFunc = func(fd uintptr) (done bool) {
|
||||||
if !c.isNonBlocking {
|
|
||||||
// Make sock non-blocking
|
|
||||||
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
|
|
||||||
c.nonblockReadErr = err
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var read uint32
|
var read uint32
|
||||||
var flags uint32
|
var flags uint32
|
||||||
var buf syscall.WSABuf
|
var buf syscall.WSABuf
|
||||||
@@ -114,14 +91,6 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
|||||||
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
|
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
|
||||||
c.nonblockReadN = int(read)
|
c.nonblockReadN = int(read)
|
||||||
|
|
||||||
if !c.isNonBlocking {
|
|
||||||
// Make sock blocking again
|
|
||||||
if err := setSockMode(fd, sockModeBlocking); err != nil {
|
|
||||||
c.nonblockReadErr = err
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -157,22 +126,52 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *NetConn) SetBlockingMode(blocking bool) error {
|
func (c *NetConn) SetBlockingMode(blocking bool) error {
|
||||||
|
// Fake non-blocking I/O is ignored
|
||||||
|
if c.rawConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if blocking {
|
||||||
|
// No ready to exit from non-blocking mode, there are pending non-blocking operations
|
||||||
|
if atomic.AddInt32(&c.nbOperCnt, -1) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Socket is already in non-blocking state
|
||||||
|
if atomic.AddInt32(&c.nbOperCnt, 1) > 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//fmt.Println("socket reverting to blocking mode")
|
||||||
|
}
|
||||||
|
|
||||||
mode := sockModeNonBlocking
|
mode := sockModeNonBlocking
|
||||||
if blocking {
|
if blocking {
|
||||||
mode = sockModeBlocking
|
mode = sockModeBlocking
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var ctrlErr, err error
|
||||||
|
|
||||||
if ctrlErr := c.rawConn.Control(func(fd uintptr) {
|
ctrlErr = c.rawConn.Control(func(fd uintptr) {
|
||||||
err = setSockMode(fd, mode)
|
err = setSockMode(fd, mode)
|
||||||
}); ctrlErr != nil {
|
})
|
||||||
return ctrlErr
|
|
||||||
|
if ctrlErr != nil || err != nil {
|
||||||
|
// Revert counters inc/dec in case of error
|
||||||
|
if blocking {
|
||||||
|
atomic.AddInt32(&c.nbOperCnt, 1)
|
||||||
|
} else {
|
||||||
|
atomic.AddInt32(&c.nbOperCnt, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctrlErr != nil {
|
||||||
|
return ctrlErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
return nil
|
||||||
c.isNonBlocking = !blocking
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user