diff --git a/go.mod b/go.mod index 96680d4..5513d76 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,13 @@ module github.com/olahol/melody go 1.19 -require github.com/gorilla/websocket v1.5.0 +require ( + github.com/gorilla/websocket v1.5.0 + github.com/stretchr/testify v1.8.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index e5a03d4..ceed388 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,17 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/melody_test.go b/melody_test.go index 214c990..39ddabc 100644 --- a/melody_test.go +++ b/melody_test.go @@ -2,6 +2,7 @@ package melody import ( "bytes" + "errors" "math/rand" "net/http" "net/http/httptest" @@ -12,10 +13,14 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" ) +var TestMsg = []byte("test") + type TestServer struct { - m *Melody + withKeys bool + m *Melody } func NewTestServerHandler(handler handleMessageFunc) *TestServer { @@ -34,7 +39,11 @@ func NewTestServer() *TestServer { } func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.m.HandleRequest(w, r) + if s.withKeys { + s.m.HandleRequestWithKeys(w, r, make(map[string]any)) + } else { + s.m.HandleRequest(w, r) + } } func NewDialer(url string) (*websocket.Conn, error) { @@ -43,79 +52,267 @@ func NewDialer(url string) (*websocket.Conn, error) { return conn, err } +func MustNewDialer(url string) *websocket.Conn { + conn, err := NewDialer(url) + + if err != nil { + panic("could not dail websocket") + } + + return conn +} + func TestEcho(t *testing.T) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { + ws := NewTestServerHandler(func(session *Session, msg []byte) { session.Write(msg) }) - server := httptest.NewServer(echo) + server := httptest.NewServer(ws) defer server.Close() fn := func(msg string) bool { - conn, err := NewDialer(server.URL) + conn := MustNewDialer(server.URL) defer conn.Close() - if err != nil { - t.Error(err) - return false - } - conn.WriteMessage(websocket.TextMessage, []byte(msg)) _, ret, err := conn.ReadMessage() - if err != nil { - t.Error(err) - return false - } + assert.Nil(t, err) - if msg != string(ret) { - t.Errorf("%s should equal %s", msg, string(ret)) - return false - } + assert.Equal(t, msg, string(ret)) return true } - if err := quick.Check(fn, nil); err != nil { - t.Error(err) - } + err := quick.Check(fn, nil) + + assert.Nil(t, err) } -func TestWriteClosed(t *testing.T) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) +func TestEchoBinary(t *testing.T) { + ws := NewTestServerHandler(func(session *Session, msg []byte) { + session.WriteBinary(msg) }) - server := httptest.NewServer(echo) + server := httptest.NewServer(ws) defer server.Close() fn := func(msg string) bool { - conn, err := NewDialer(server.URL) - - if err != nil { - t.Error(err) - return false - } + conn := MustNewDialer(server.URL) + defer conn.Close() conn.WriteMessage(websocket.TextMessage, []byte(msg)) - echo.m.HandleConnect(func(s *Session) { - s.Close() - }) + _, ret, err := conn.ReadMessage() - echo.m.HandleDisconnect(func(s *Session) { - err := s.Write([]byte("hello world")) + assert.Nil(t, err) - if err == nil { - t.Error("should be an error") - } - }) + assert.True(t, bytes.Equal([]byte(msg), ret)) return true } - if err := quick.Check(fn, nil); err != nil { - t.Error(err) + err := quick.Check(fn, nil) + + assert.Nil(t, err) +} + +func TestWriteClosedServer(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.HandleConnect(func(s *Session) { + s.Close() + }) + + ws.m.HandleDisconnect(func(s *Session) { + err := s.Write(TestMsg) + + assert.NotNil(t, err) + close(done) + }) + + conn := MustNewDialer(server.URL) + conn.ReadMessage() + defer conn.Close() + + <-done +} + +func TestWriteClosedClient(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.HandleDisconnect(func(s *Session) { + err := s.Write(TestMsg) + + assert.NotNil(t, err) + close(done) + }) + + conn := MustNewDialer(server.URL) + conn.Close() + + <-done +} + +func TestUpgrader(t *testing.T) { + ws := NewTestServer() + ws.m.HandleMessage(func(session *Session, msg []byte) { + session.Write(msg) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + ws.m.Upgrader = &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return false }, } + + _, err := NewDialer(server.URL) + + assert.ErrorIs(t, err, websocket.ErrBadHandshake) +} + +func TestBroadcast(t *testing.T) { + n := 10 + msg := "test" + + test := func(h func(*TestServer), w func(*websocket.Conn)) { + + ws := NewTestServer() + + h(ws) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + listeners := make([]*websocket.Conn, n) + for i := range listeners { + listener := MustNewDialer(server.URL) + listeners[i] = listener + defer listeners[i].Close() + } + + w(conn) + + for _, listener := range listeners { + _, ret, err := listener.ReadMessage() + + assert.Nil(t, err) + + assert.Equal(t, msg, string(ret)) + } + } + + test(func(ws *TestServer) { + ws.m.HandleMessage(func(s *Session, msg []byte) { + ws.m.Broadcast(msg) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessageBinary(func(s *Session, msg []byte) { + ws.m.BroadcastBinary(msg) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessage(func(s *Session, msg []byte) { + ws.m.BroadcastFilter(msg, func(s *Session) bool { + return true + }) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessageBinary(func(s *Session, msg []byte) { + ws.m.BroadcastBinaryFilter(msg, func(s *Session) bool { + return true + }) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessage(func(s *Session, msg []byte) { + ws.m.BroadcastOthers(msg, s) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessageBinary(func(s *Session, msg []byte) { + ws.m.BroadcastBinaryOthers(msg, s) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) + }) + + test(func(ws *TestServer) { + ws.m.HandleMessage(func(s *Session, msg []byte) { + ss, _ := ws.m.Sessions() + ws.m.BroadcastMultiple(msg, ss) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + }) +} + +func TestClose(t *testing.T) { + done := make(chan bool) + ws := NewTestServer() + + server := httptest.NewServer(ws) + defer server.Close() + + n := 10 + + conns := make([]*websocket.Conn, n) + for i := range conns { + conn := MustNewDialer(server.URL) + conns[i] = conn + defer conns[i].Close() + } + + q := n + ws.m.HandleDisconnect(func(s *Session) { + q-- + + if q == 0 { + close(done) + } + }) + + ws.m.Close() + + for _, conn := range conns { + conn.ReadMessage() + } + + assert.Zero(t, ws.m.Len()) + + <-done } func TestLen(t *testing.T) { @@ -124,6 +321,7 @@ func TestLen(t *testing.T) { connect := int(rand.Int31n(100)) disconnect := rand.Float32() conns := make([]*websocket.Conn, connect) + defer func() { for _, conn := range conns { if conn != nil { @@ -132,17 +330,14 @@ func TestLen(t *testing.T) { } }() - echo := NewTestServerHandler(func(session *Session, msg []byte) {}) - server := httptest.NewServer(echo) + ws := NewTestServer() + + server := httptest.NewServer(ws) defer server.Close() disconnected := 0 for i := 0; i < connect; i++ { - conn, err := NewDialer(server.URL) - - if err != nil { - t.Error(err) - } + conn := MustNewDialer(server.URL) if rand.Float32() < disconnect { conns[i] = nil @@ -158,12 +353,10 @@ func TestLen(t *testing.T) { connected := connect - disconnected - if echo.m.Len() != connected { - t.Errorf("melody len %d should equal %d", echo.m.Len(), connected) - } + assert.Equal(t, ws.m.Len(), connected) } -func TestGetSessions(t *testing.T) { +func TestSessions(t *testing.T) { rand.Seed(time.Now().UnixNano()) connect := int(rand.Int31n(100)) @@ -177,8 +370,8 @@ func TestGetSessions(t *testing.T) { } }() - echo := NewTestServerHandler(func(session *Session, msg []byte) {}) - server := httptest.NewServer(echo) + ws := NewTestServer() + server := httptest.NewServer(ws) defer server.Close() disconnected := 0 @@ -203,573 +396,307 @@ func TestGetSessions(t *testing.T) { connected := connect - disconnected - allsess, err := echo.m.Sessions() - if err != nil { - t.Fatalf("error retrieving sessions: %v", err.Error()) - } + ss, err := ws.m.Sessions() - if len(allsess) != connected { - t.Errorf("melody sessions %d should equal %d", len(allsess), connected) - } -} + assert.Nil(t, err) -func TestEchoBinary(t *testing.T) { - echo := NewTestServer() - echo.m.HandleMessageBinary(func(session *Session, msg []byte) { - session.WriteBinary(msg) - }) - server := httptest.NewServer(echo) - defer server.Close() - - fn := func(msg string) bool { - conn, err := NewDialer(server.URL) - defer conn.Close() - - if err != nil { - t.Error(err) - return false - } - - conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) - - _, ret, err := conn.ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if msg != string(ret) { - t.Errorf("%s should equal %s", msg, string(ret)) - return false - } - - return true - } - - if err := quick.Check(fn, nil); err != nil { - t.Error(err) - } -} - -func TestHandlers(t *testing.T) { - echo := NewTestServer() - echo.m.HandleMessage(func(session *Session, msg []byte) { - session.Write(msg) - }) - server := httptest.NewServer(echo) - defer server.Close() - - var q *Session - - echo.m.HandleConnect(func(session *Session) { - q = session - session.Close() - }) - - echo.m.HandleDisconnect(func(session *Session) { - if q != session { - t.Error("disconnecting session should be the same as connecting") - } - }) - - NewDialer(server.URL) -} - -func TestMetadata(t *testing.T) { - echo := NewTestServer() - echo.m.HandleConnect(func(session *Session) { - session.Set("stamp", time.Now().UnixNano()) - }) - echo.m.HandleMessage(func(session *Session, msg []byte) { - stamp := session.MustGet("stamp").(int64) - session.Write([]byte(strconv.Itoa(int(stamp)))) - }) - server := httptest.NewServer(echo) - defer server.Close() - - fn := func(msg string) bool { - conn, err := NewDialer(server.URL) - defer conn.Close() - - if err != nil { - t.Error(err) - return false - } - - conn.WriteMessage(websocket.TextMessage, []byte(msg)) - - _, ret, err := conn.ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - stamp, err := strconv.Atoi(string(ret)) - - if err != nil { - t.Error(err) - return false - } - - diff := int(time.Now().UnixNano()) - stamp - - if diff <= 0 { - t.Errorf("diff should be above 0 %d", diff) - return false - } - - return true - } - - if err := quick.Check(fn, nil); err != nil { - t.Error(err) - } -} - -func TestUpgrader(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessage(func(session *Session, msg []byte) { - session.Write(msg) - }) - server := httptest.NewServer(broadcast) - defer server.Close() - - broadcast.m.Upgrader = &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return false }, - } - - broadcast.m.HandleError(func(session *Session, err error) { - if err == nil || err.Error() != "websocket: origin not allowed" { - t.Error("there should be a origin error") - } - }) - - _, err := NewDialer(server.URL) - - if err == nil || err.Error() != "websocket: bad handshake" { - t.Error("there should be a badhandshake error") - } -} - -func TestBroadcast(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessage(func(session *Session, msg []byte) { - broadcast.m.Broadcast(msg) - }) - server := httptest.NewServer(broadcast) - defer server.Close() - - n := 10 - - fn := func(msg string) bool { - conn, _ := NewDialer(server.URL) - defer conn.Close() - - listeners := make([]*websocket.Conn, n) - for i := 0; i < n; i++ { - listener, _ := NewDialer(server.URL) - listeners[i] = listener - defer listeners[i].Close() - } - - conn.WriteMessage(websocket.TextMessage, []byte(msg)) - - for i := 0; i < n; i++ { - _, ret, err := listeners[i].ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if msg != string(ret) { - t.Errorf("%s should equal %s", msg, string(ret)) - return false - } - } - - return true - } - - if !fn("test") { - t.Errorf("should not be false") - } -} - -func TestBroadcastBinary(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) { - broadcast.m.BroadcastBinary(msg) - }) - server := httptest.NewServer(broadcast) - defer server.Close() - - n := 10 - - fn := func(msg []byte) bool { - conn, _ := NewDialer(server.URL) - defer conn.Close() - - listeners := make([]*websocket.Conn, n) - for i := 0; i < n; i++ { - listener, _ := NewDialer(server.URL) - listeners[i] = listener - defer listeners[i].Close() - } - - conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) - - for i := 0; i < n; i++ { - messageType, ret, err := listeners[i].ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if messageType != websocket.BinaryMessage { - t.Errorf("message type should be BinaryMessage") - return false - } - - if !bytes.Equal(msg, ret) { - t.Errorf("%v should equal %v", msg, ret) - return false - } - } - - return true - } - - if !fn([]byte{2, 3, 5, 7, 11}) { - t.Errorf("should not be false") - } -} - -func TestBroadcastOthers(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessage(func(session *Session, msg []byte) { - broadcast.m.BroadcastOthers(msg, session) - }) - broadcast.m.Config.PongWait = time.Second - broadcast.m.Config.PingPeriod = time.Second * 9 / 10 - server := httptest.NewServer(broadcast) - defer server.Close() - - n := 10 - - fn := func(msg string) bool { - conn, _ := NewDialer(server.URL) - defer conn.Close() - - listeners := make([]*websocket.Conn, n) - for i := 0; i < n; i++ { - listener, _ := NewDialer(server.URL) - listeners[i] = listener - defer listeners[i].Close() - } - - conn.WriteMessage(websocket.TextMessage, []byte(msg)) - - for i := 0; i < n; i++ { - _, ret, err := listeners[i].ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if msg != string(ret) { - t.Errorf("%s should equal %s", msg, string(ret)) - return false - } - } - - return true - } - - if !fn("test") { - t.Errorf("should not be false") - } -} - -func TestBroadcastBinaryOthers(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) { - broadcast.m.BroadcastBinaryOthers(msg, session) - }) - broadcast.m.Config.PongWait = time.Second - broadcast.m.Config.PingPeriod = time.Second * 9 / 10 - server := httptest.NewServer(broadcast) - defer server.Close() - - n := 10 - - fn := func(msg []byte) bool { - conn, _ := NewDialer(server.URL) - defer conn.Close() - - listeners := make([]*websocket.Conn, n) - for i := 0; i < n; i++ { - listener, _ := NewDialer(server.URL) - listeners[i] = listener - defer listeners[i].Close() - } - - conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) - - for i := 0; i < n; i++ { - messageType, ret, err := listeners[i].ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if messageType != websocket.BinaryMessage { - t.Errorf("message type should be BinaryMessage") - return false - } - - if !bytes.Equal(msg, ret) { - t.Errorf("%v should equal %v", msg, ret) - return false - } - } - - return true - } - - if !fn([]byte{2, 3, 5, 7, 11}) { - t.Errorf("should not be false") - } + assert.Equal(t, len(ss), connected) } func TestPingPong(t *testing.T) { - noecho := NewTestServer() - noecho.m.Config.PongWait = time.Second - noecho.m.Config.PingPeriod = time.Second * 9 / 10 - server := httptest.NewServer(noecho) + done := make(chan bool) + + ws := NewTestServer() + ws.m.Config.PingPeriod = time.Millisecond + + ws.m.HandlePong(func(s *Session) { + close(done) + }) + + server := httptest.NewServer(ws) defer server.Close() - conn, err := NewDialer(server.URL) - conn.SetPingHandler(func(string) error { - return nil - }) + conn := MustNewDialer(server.URL) defer conn.Close() - if err != nil { - t.Error(err) - } + go conn.NextReader() - conn.WriteMessage(websocket.TextMessage, []byte("test")) - - _, _, err = conn.ReadMessage() - - if err == nil { - t.Error("there should be an error") - } + <-done } -func TestBroadcastFilter(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessage(func(session *Session, msg []byte) { - broadcast.m.BroadcastFilter(msg, func(q *Session) bool { - return session == q - }) +func TestHandleClose(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + ws.m.Config.PingPeriod = time.Millisecond + + ws.m.HandleClose(func(s *Session, code int, text string) error { + close(done) + return nil }) - server := httptest.NewServer(broadcast) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + + conn.WriteMessage(websocket.CloseMessage, nil) + + <-done +} + +func TestHandleError(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + + ws.m.HandleError(func(s *Session, err error) { + var closeError *websocket.CloseError + assert.ErrorAs(t, err, &closeError) + close(done) + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + + conn.Close() + + <-done +} + +func TestHandleErrorWrite(t *testing.T) { + done := make(chan bool) + + ws := NewTestServer() + ws.m.Config.WriteWait = 0 + + ws.m.HandleConnect(func(s *Session) { + err := s.Write(TestMsg) + assert.Nil(t, err) + }) + + ws.m.HandleError(func(s *Session, err error) { + assert.NotNil(t, err) + + var closeError *websocket.CloseError + if !errors.As(err, &closeError) { + close(done) + } + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + go conn.NextReader() + + <-done +} + +func TestErrClosed(t *testing.T) { + res := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + ws.m.CloseWithMsg(TestMsg) + }) + + ws.m.HandleDisconnect(func(s *Session) { + res <- s + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + go conn.ReadMessage() + + s := <-res + + assert.True(t, s.IsClosed()) + assert.True(t, ws.m.IsClosed()) + _, err := ws.m.Sessions() + assert.ErrorIs(t, err, ErrClosed) + assert.ErrorIs(t, ws.m.Close(), ErrClosed) + assert.ErrorIs(t, ws.m.CloseWithMsg(TestMsg), ErrClosed) + + assert.ErrorIs(t, ws.m.Broadcast(TestMsg), ErrClosed) + assert.ErrorIs(t, ws.m.BroadcastBinary(TestMsg), ErrClosed) + assert.ErrorIs(t, ws.m.BroadcastFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed) + assert.ErrorIs(t, ws.m.BroadcastBinaryFilter(TestMsg, func(s *Session) bool { return true }), ErrClosed) + assert.ErrorIs(t, ws.m.HandleRequest(nil, nil), ErrClosed) +} + +func TestErrSessionClosed(t *testing.T) { + res := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + s.CloseWithMsg(TestMsg) + }) + + ws.m.HandleDisconnect(func(s *Session) { + res <- s + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + go conn.ReadMessage() + + s := <-res + + assert.True(t, s.IsClosed()) + assert.ErrorIs(t, s.Write(TestMsg), ErrSessionClosed) + assert.ErrorIs(t, s.WriteBinary(TestMsg), ErrSessionClosed) + assert.ErrorIs(t, s.CloseWithMsg(TestMsg), ErrSessionClosed) + assert.ErrorIs(t, s.Close(), ErrSessionClosed) + assert.ErrorIs(t, ws.m.BroadcastMultiple(TestMsg, []*Session{s}), ErrSessionClosed) + + assert.ErrorIs(t, s.writeRaw(nil), ErrWriteClosed) + s.writeMessage(nil) +} + +func TestErrMessageBufferFull(t *testing.T) { + done := make(chan bool) + + ws := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + session.Write(msg) + }) + ws.m.Config.MessageBufferSize = 0 + ws.m.HandleError(func(s *Session, err error) { + if errors.Is(err, ErrMessageBufferFull) { + close(done) + } + }) + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + conn.WriteMessage(websocket.TextMessage, TestMsg) + + <-done +} + +func TestSessionKeys(t *testing.T) { + ws := NewTestServer() + + ws.m.HandleConnect(func(session *Session) { + session.Set("stamp", time.Now().UnixNano()) + }) + ws.m.HandleMessage(func(session *Session, msg []byte) { + stamp := session.MustGet("stamp").(int64) + session.Write([]byte(strconv.Itoa(int(stamp)))) + }) + server := httptest.NewServer(ws) defer server.Close() fn := func(msg string) bool { - conn, err := NewDialer(server.URL) + conn := MustNewDialer(server.URL) defer conn.Close() - if err != nil { - t.Error(err) - return false - } - conn.WriteMessage(websocket.TextMessage, []byte(msg)) _, ret, err := conn.ReadMessage() - if err != nil { - t.Error(err) - return false - } + assert.Nil(t, err) - if msg != string(ret) { - t.Errorf("%s should equal %s", msg, string(ret)) - return false - } + stamp, err := strconv.Atoi(string(ret)) + + assert.Nil(t, err) + + diff := int(time.Now().UnixNano()) - stamp + + assert.Greater(t, diff, 0) return true } - if !fn("test") { - t.Errorf("should not be false") - } + assert.Nil(t, quick.Check(fn, nil)) } -func TestBroadcastBinaryFilter(t *testing.T) { - broadcast := NewTestServer() - broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) { - broadcast.m.BroadcastBinaryFilter(msg, func(q *Session) bool { - return session == q - }) +func TestMisc(t *testing.T) { + res := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + res <- s }) - server := httptest.NewServer(broadcast) + + server := httptest.NewServer(ws) defer server.Close() - fn := func(msg []byte) bool { - conn, err := NewDialer(server.URL) + conn := MustNewDialer(server.URL) + defer conn.Close() + + go conn.ReadMessage() + + s := <-res + + assert.Contains(t, s.LocalAddr().String(), "127.0.0.1") + assert.Contains(t, s.RemoteAddr().String(), "127.0.0.1") + assert.Equal(t, FormatCloseMessage(websocket.CloseMessage, "test"), websocket.FormatCloseMessage(websocket.CloseMessage, "test")) + assert.Panics(t, func() { + s.MustGet("test") + }) +} + +func TestHandleSentMessage(t *testing.T) { + test := func(h func(*TestServer, chan bool), w func(*websocket.Conn)) { + done := make(chan bool) + + ws := NewTestServer() + server := httptest.NewServer(ws) + defer server.Close() + + h(ws, done) + + conn := MustNewDialer(server.URL) defer conn.Close() - if err != nil { - t.Error(err) - return false - } + w(conn) - conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) - - messageType, ret, err := conn.ReadMessage() - - if err != nil { - t.Error(err) - return false - } - - if messageType != websocket.BinaryMessage { - t.Errorf("message type should be BinaryMessage") - return false - } - - if !bytes.Equal(msg, ret) { - t.Errorf("%v should equal %v", msg, ret) - return false - } - - return true + <-done } - if !fn([]byte{2, 3, 5, 7, 11}) { - t.Errorf("should not be false") - } -} + test(func(ws *TestServer, done chan bool) { + ws.m.HandleMessage(func(s *Session, msg []byte) { + s.Write(msg) + }) -func TestStop(t *testing.T) { - noecho := NewTestServer() - server := httptest.NewServer(noecho) - defer server.Close() - - conn, err := NewDialer(server.URL) - defer conn.Close() - - if err != nil { - t.Error(err) - } - - noecho.m.Close() -} - -func TestSmallMessageBuffer(t *testing.T) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) - }) - echo.m.Config.MessageBufferSize = 0 - echo.m.HandleError(func(s *Session, err error) { - if err == nil { - t.Error("there should be a buffer full error here") - } - }) - server := httptest.NewServer(echo) - defer server.Close() - - conn, err := NewDialer(server.URL) - defer conn.Close() - - if err != nil { - t.Error(err) - } - - conn.WriteMessage(websocket.TextMessage, []byte("12345")) -} - -func TestPong(t *testing.T) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) - }) - echo.m.Config.PongWait = time.Second - echo.m.Config.PingPeriod = time.Second * 9 / 10 - server := httptest.NewServer(echo) - defer server.Close() - - conn, err := NewDialer(server.URL) - defer conn.Close() - - if err != nil { - t.Error(err) - } - - fired := false - echo.m.HandlePong(func(s *Session) { - fired = true + ws.m.HandleSentMessage(func(s *Session, msg []byte) { + assert.Equal(t, TestMsg, msg) + close(done) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.TextMessage, TestMsg) }) - conn.WriteMessage(websocket.PongMessage, nil) + test(func(ws *TestServer, done chan bool) { + ws.m.HandleMessageBinary(func(s *Session, msg []byte) { + s.WriteBinary(msg) + }) - time.Sleep(time.Millisecond) - - if !fired { - t.Error("should have fired pong handler") - } -} - -func BenchmarkSessionWrite(b *testing.B) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) + ws.m.HandleSentMessageBinary(func(s *Session, msg []byte) { + assert.Equal(t, TestMsg, msg) + close(done) + }) + }, func(conn *websocket.Conn) { + conn.WriteMessage(websocket.BinaryMessage, TestMsg) }) - server := httptest.NewServer(echo) - conn, _ := NewDialer(server.URL) - defer server.Close() - defer conn.Close() - - for n := 0; n < b.N; n++ { - conn.WriteMessage(websocket.TextMessage, []byte("test")) - conn.ReadMessage() - } -} - -func BenchmarkBroadcast(b *testing.B) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) - }) - server := httptest.NewServer(echo) - defer server.Close() - - conns := make([]*websocket.Conn, 0) - - num := 100 - - for i := 0; i < num; i++ { - conn, _ := NewDialer(server.URL) - conns = append(conns, conn) - } - - for n := 0; n < b.N; n++ { - echo.m.Broadcast([]byte("test")) - - for i := 0; i < num; i++ { - conns[i].ReadMessage() - } - } - - for i := 0; i < num; i++ { - conns[i].Close() - } }