diff --git a/hub.go b/hub.go index 4a8cffa..a3b2977 100644 --- a/hub.go +++ b/hub.go @@ -2,90 +2,124 @@ package melody import ( "sync" + "sync/atomic" ) +type sessionSet struct { + mu sync.RWMutex + members map[*Session]struct{} +} + +func (ss *sessionSet) add(s *Session) { + ss.mu.Lock() + defer ss.mu.Unlock() + + ss.members[s] = struct{}{} +} + +func (ss *sessionSet) del(s *Session) { + ss.mu.Lock() + defer ss.mu.Unlock() + + delete(ss.members, s) +} + +func (ss *sessionSet) clear() { + ss.mu.Lock() + defer ss.mu.Unlock() + + ss.members = make(map[*Session]struct{}) +} + +func (ss *sessionSet) each(cb func(*Session)) { + ss.mu.RLock() + defer ss.mu.RUnlock() + + for s := range ss.members { + cb(s) + } +} + +func (ss *sessionSet) len() int { + ss.mu.RLock() + defer ss.mu.RUnlock() + + return len(ss.members) +} + +func (ss *sessionSet) all() []*Session { + ss.mu.RLock() + defer ss.mu.RUnlock() + + s := make([]*Session, 0, len(ss.members)) + for k := range ss.members { + s = append(s, k) + } + + return s +} + type hub struct { - sessions map[*Session]bool + sessions sessionSet broadcast chan *envelope register chan *Session unregister chan *Session exit chan *envelope - open bool - rwmutex *sync.RWMutex + open atomic.Bool } func newHub() *hub { return &hub{ - sessions: make(map[*Session]bool), + sessions: sessionSet{ + members: make(map[*Session]struct{}), + }, broadcast: make(chan *envelope), register: make(chan *Session), unregister: make(chan *Session), exit: make(chan *envelope), - open: true, - rwmutex: &sync.RWMutex{}, } } func (h *hub) run() { + h.open.Store(true) + loop: for { select { case s := <-h.register: - h.rwmutex.Lock() - h.sessions[s] = true - h.rwmutex.Unlock() + h.sessions.add(s) case s := <-h.unregister: - if _, ok := h.sessions[s]; ok { - h.rwmutex.Lock() - delete(h.sessions, s) - h.rwmutex.Unlock() - } + h.sessions.del(s) case m := <-h.broadcast: - h.rwmutex.RLock() - for s := range h.sessions { - if m.filter != nil { - if m.filter(s) { - s.writeMessage(m) - } - } else { + h.sessions.each(func(s *Session) { + if m.filter == nil { + s.writeMessage(m) + } else if m.filter(s) { s.writeMessage(m) } - } - h.rwmutex.RUnlock() + }) case m := <-h.exit: - h.rwmutex.Lock() - for s := range h.sessions { + h.sessions.each(func(s *Session) { s.writeMessage(m) - delete(h.sessions, s) s.Close() - } - h.open = false - h.rwmutex.Unlock() + }) + + h.sessions.clear() + h.open.Store(false) + break loop } } } func (h *hub) closed() bool { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - return !h.open + return !h.open.Load() } func (h *hub) len() int { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - - return len(h.sessions) + return h.sessions.len() } func (h *hub) all() []*Session { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - - s := make([]*Session, 0, len(h.sessions)) - for k := range h.sessions { - s = append(s, k) - } - return s + return h.sessions.all() }