diff --git a/melody_test.go b/melody_test.go index f5e0b0d..653945a 100644 --- a/melody_test.go +++ b/melody_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "net/http" "net/http/httptest" + "os" "strconv" "strings" "sync" @@ -469,7 +470,8 @@ func TestHandleError(t *testing.T) { } func TestHandleErrorWrite(t *testing.T) { - done := make(chan bool) + writeError := make(chan struct{}) + disconnect := make(chan struct{}) ws := NewTestServer() ws.m.Config.WriteWait = 0 @@ -482,12 +484,15 @@ func TestHandleErrorWrite(t *testing.T) { ws.m.HandleError(func(s *Session, err error) { assert.NotNil(t, err) - var closeError *websocket.CloseError - if !errors.As(err, &closeError) { - close(done) + if os.IsTimeout(err) { + close(writeError) } }) + ws.m.HandleDisconnect(func(s *Session) { + close(disconnect) + }) + server := httptest.NewServer(ws) defer server.Close() @@ -496,7 +501,8 @@ func TestHandleErrorWrite(t *testing.T) { go conn.NextReader() - <-done + <-writeError + <-disconnect } func TestErrClosed(t *testing.T) { diff --git a/session.go b/session.go index 8bb81c9..879a724 100644 --- a/session.go +++ b/session.go @@ -105,6 +105,8 @@ loop: } } } + + s.close() } func (s *Session) readPump() {