diff --git a/melody_test.go b/melody_test.go index 39ddabc..f5e0b0d 100644 --- a/melody_test.go +++ b/melody_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strconv" "strings" + "sync" "testing" "testing/quick" "time" @@ -280,7 +281,6 @@ func TestBroadcast(t *testing.T) { } func TestClose(t *testing.T) { - done := make(chan bool) ws := NewTestServer() server := httptest.NewServer(ws) @@ -295,13 +295,9 @@ func TestClose(t *testing.T) { defer conns[i].Close() } - q := n + q := make(chan bool) ws.m.HandleDisconnect(func(s *Session) { - q-- - - if q == 0 { - close(done) - } + q <- true }) ws.m.Close() @@ -312,7 +308,13 @@ func TestClose(t *testing.T) { assert.Zero(t, ws.m.Len()) - <-done + m := 0 + for range q { + m += 1 + if m == n { + break + } + } } func TestLen(t *testing.T) { @@ -629,6 +631,46 @@ func TestSessionKeys(t *testing.T) { assert.Nil(t, quick.Check(fn, nil)) } +func TestSessionKeysConcurrent(t *testing.T) { + ss := make(chan *Session) + + ws := NewTestServer() + + ws.m.HandleConnect(func(s *Session) { + ss <- s + }) + + server := httptest.NewServer(ws) + defer server.Close() + + conn := MustNewDialer(server.URL) + defer conn.Close() + + s := <-ss + + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + + go func() { + s.Set("test", TestMsg) + v1, exists := s.Get("test") + + assert.True(t, exists) + assert.Equal(t, v1, TestMsg) + + v2 := s.MustGet("test") + + assert.Equal(t, v1, v2) + + wg.Done() + }() + } + + wg.Wait() +} + func TestMisc(t *testing.T) { res := make(chan *Session) diff --git a/session.go b/session.go index 13b0884..d9f9b4f 100644 --- a/session.go +++ b/session.go @@ -189,6 +189,9 @@ func (s *Session) CloseWithMsg(msg []byte) error { // Set is used to store a new key/value pair exclusivelly for this session. // It also lazy initializes s.Keys if it was not used previously. func (s *Session) Set(key string, value interface{}) { + s.rwmutex.Lock() + defer s.rwmutex.Unlock() + if s.Keys == nil { s.Keys = make(map[string]interface{}) } @@ -199,6 +202,9 @@ func (s *Session) Set(key string, value interface{}) { // Get returns the value for the given key, ie: (value, true). // If the value does not exists it returns (nil, false) func (s *Session) Get(key string) (value interface{}, exists bool) { + s.rwmutex.RLock() + defer s.rwmutex.RUnlock() + if s.Keys != nil { value, exists = s.Keys[key] }