From 087b8b2ba8681657182bdadedabb24337552dc84 Mon Sep 17 00:00:00 2001 From: Dmitry K Date: Sun, 26 Feb 2023 00:00:37 +0300 Subject: [PATCH] Try to make windows non-blocking I/O --- go.mod | 1 + go.sum | 2 + .../nbconn/nbconn_real_non_block_windows.go | 61 ++++++++++++++++++- 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index b16d6cd5..2f035e7f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/puddle/v2 v2.2.0 github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.6.0 + golang.org/x/sys v0.5.0 golang.org/x/text v0.7.0 ) diff --git a/go.sum b/go.sum index 35df01ee..cd6e49b1 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/nbconn/nbconn_real_non_block_windows.go b/internal/nbconn/nbconn_real_non_block_windows.go index 259c8f8a..5211f12b 100644 --- a/internal/nbconn/nbconn_real_non_block_windows.go +++ b/internal/nbconn/nbconn_real_non_block_windows.go @@ -4,20 +4,64 @@ package nbconn import ( "errors" + "golang.org/x/sys/windows" "io" "syscall" + "unsafe" ) +var dll = syscall.MustLoadDLL("ws2_32.dll") + +// int ioctlsocket( +// +// [in] SOCKET s, +// [in] long cmd, +// [in, out] u_long *argp +// +// ); +var ioctlsocket = dll.MustFindProc("ioctlsocket") + +type sockMode int + +const ( + FIONBIO int = 0x8004667e + sockModeBlocking sockMode = 0 + sockModeNonBlocking sockMode = 1 +) + +func setSockMode(fd uintptr, mode sockMode) error { + res, _, err := ioctlsocket.Call(fd, uintptr(FIONBIO), uintptr(unsafe.Pointer(&mode))) + // Upon successful completion, the ioctlsocket returns zero. + if res != 0 && err != nil { + return err + } + + return nil +} + // realNonblockingWrite does a non-blocking write. readFlushLock must already be held. func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { if c.nonblockWriteFunc == nil { c.nonblockWriteFunc = func(fd uintptr) (done bool) { + // Make sock non-blocking + if err := setSockMode(fd, sockModeNonBlocking); err != nil { + c.nonblockWriteErr = err + return true + } + var written uint32 var buf syscall.WSABuf buf.Buf = &c.nonblockWriteBuf[0] buf.Len = uint32(len(c.nonblockWriteBuf)) c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil) c.nonblockWriteN = int(written) + + // Make sock blocking again + if err := setSockMode(fd, sockModeBlocking); err != nil { + c.nonblockWriteErr = err + return true + } + return true } } @@ -29,7 +73,7 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { n = c.nonblockWriteN c.nonblockWriteBuf = nil // ensure that no reference to b is kept. if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { + if errors.Is(c.nonblockWriteErr, windows.WSAEWOULDBLOCK) { err = ErrWouldBlock } else { err = c.nonblockWriteErr @@ -50,6 +94,12 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { if c.nonblockReadFunc == nil { c.nonblockReadFunc = func(fd uintptr) (done bool) { + // Make sock non-blocking + if err := setSockMode(fd, sockModeNonBlocking); err != nil { + c.nonblockWriteErr = err + return true + } + var read uint32 var flags uint32 var buf syscall.WSABuf @@ -57,6 +107,13 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { buf.Len = uint32(len(c.nonblockReadBuf)) c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil) c.nonblockReadN = int(read) + + // Make sock blocking again + if err := setSockMode(fd, sockModeBlocking); err != nil { + c.nonblockWriteErr = err + return true + } + return true } } @@ -68,7 +125,7 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { n = c.nonblockReadN c.nonblockReadBuf = nil // ensure that no reference to b is kept. if err == nil && c.nonblockReadErr != nil { - if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { + if errors.Is(c.nonblockReadErr, windows.WSAEWOULDBLOCK) { err = ErrWouldBlock } else { err = c.nonblockReadErr