From 009a377028013720956f1c907855265c227e3ecd Mon Sep 17 00:00:00 2001 From: Dmitry K Date: Tue, 21 Mar 2023 10:15:02 +0300 Subject: [PATCH] Use mutex to guard entire `SetBlockingMode` call --- internal/nbconn/nbconn.go | 4 ++- .../nbconn/nbconn_real_non_block_windows.go | 32 +++++++++++++++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index dfcf4c94..38489a74 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -98,7 +98,9 @@ type NetConn struct { writeDeadline time.Time // nbOperCnt Tracks how many operations performing simultaneously - nbOperCnt atomic.Int32 + nbOperCnt int + // nbOperMu Used to prevent concurrent SetBlockingMode calls + nbOperMu sync.Mutex } func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { diff --git a/internal/nbconn/nbconn_real_non_block_windows.go b/internal/nbconn/nbconn_real_non_block_windows.go index aefc1d4a..6be55ba1 100644 --- a/internal/nbconn/nbconn_real_non_block_windows.go +++ b/internal/nbconn/nbconn_real_non_block_windows.go @@ -131,14 +131,32 @@ func (c *NetConn) SetBlockingMode(blocking bool) error { return nil } + // Prevent concurrent SetBlockingMode calls + c.nbOperMu.Lock() + defer c.nbOperMu.Unlock() + + // Guard against negative value (which should never happen in practice) + if c.nbOperCnt < 0 { + c.nbOperCnt = 0 + } + if blocking { - // Not ready to exit from non-blocking mode, there are pending non-blocking operations - if c.nbOperCnt.Add(-1) > 0 { + // Socket is already in blocking mode + if c.nbOperCnt == 0 { + return nil + } + + c.nbOperCnt-- + + // Not ready to exit from non-blocking mode, there is pending non-blocking operations + if c.nbOperCnt > 0 { return nil } } else { - // Socket is already in non-blocking state - if c.nbOperCnt.Add(1) > 1 { + c.nbOperCnt++ + + // Socket is already in non-blocking mode + if c.nbOperCnt > 1 { return nil } } @@ -162,11 +180,13 @@ func (c *NetConn) SetBlockingMode(blocking bool) error { // Revert counters inc/dec in case of error if blocking { - c.nbOperCnt.Add(1) + c.nbOperCnt++ + //c.nbOperCnt.Add(1) return fmt.Errorf("cannot set socket to blocking mode: %w", retErr) } else { - c.nbOperCnt.Add(-1) + c.nbOperCnt-- + //c.nbOperCnt.Add(-1) return fmt.Errorf("cannot set socket to non-blocking mode: %w", retErr) }