ensure order of messages and dispatch error handler when buffer is full
This commit is contained in:
@@ -8,7 +8,7 @@ type Config struct {
|
||||
PongWait time.Duration // Timeout for waiting on pong.
|
||||
PingPeriod time.Duration // Milliseconds between pings.
|
||||
MaxMessageSize int64 // Maximum size in bytes of a message.
|
||||
MessageBufferSize int // Size of each sessions message buffer.
|
||||
MessageBufferSize int // The max amount of messages that can be in a sessions buffer before it starts dropping them.
|
||||
}
|
||||
|
||||
func newConfig() *Config {
|
||||
|
||||
@@ -36,10 +36,10 @@ loop:
|
||||
for s := range h.sessions {
|
||||
if m.filter != nil {
|
||||
if m.filter(s) {
|
||||
go s.writeMessage(m)
|
||||
s.writeMessage(m)
|
||||
}
|
||||
} else {
|
||||
go s.writeMessage(m)
|
||||
s.writeMessage(m)
|
||||
}
|
||||
}
|
||||
case <-h.exit:
|
||||
|
||||
@@ -78,15 +78,20 @@ func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
session := newSession(m.Config, conn, r)
|
||||
session := &Session{
|
||||
Request: r,
|
||||
conn: conn,
|
||||
output: make(chan *envelope, m.Config.MessageBufferSize),
|
||||
melody: m,
|
||||
}
|
||||
|
||||
m.hub.register <- session
|
||||
|
||||
go m.connectHandler(session)
|
||||
|
||||
go session.writePump(m.errorHandler)
|
||||
go session.writePump()
|
||||
|
||||
session.readPump(m.messageHandler, m.messageHandlerBinary, m.errorHandler)
|
||||
session.readPump()
|
||||
|
||||
if m.hub.open {
|
||||
m.hub.unregister <- session
|
||||
|
||||
@@ -340,3 +340,26 @@ func TestStop(t *testing.T) {
|
||||
|
||||
noecho.m.Close()
|
||||
}
|
||||
|
||||
func TestSmallMessageBuffer(t *testing.T) {
|
||||
echo := NewTestServerHandler(func(session *Session, msg []byte) {
|
||||
session.Write(msg)
|
||||
})
|
||||
echo.m.Config.MessageBufferSize = 0
|
||||
echo.m.HandleError(func(s *Session, err error) {
|
||||
if err == nil {
|
||||
t.Error("there should be a buffer full error here")
|
||||
}
|
||||
})
|
||||
server := httptest.NewServer(echo)
|
||||
defer server.Close()
|
||||
|
||||
conn, err := NewDialer(server.URL)
|
||||
defer conn.Close()
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
conn.WriteMessage(websocket.TextMessage, []byte("12345"))
|
||||
}
|
||||
|
||||
+18
-22
@@ -1,6 +1,7 @@
|
||||
package melody
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -11,24 +12,19 @@ type Session struct {
|
||||
Request *http.Request
|
||||
conn *websocket.Conn
|
||||
output chan *envelope
|
||||
config *Config
|
||||
}
|
||||
|
||||
func newSession(config *Config, conn *websocket.Conn, req *http.Request) *Session {
|
||||
return &Session{
|
||||
Request: req,
|
||||
conn: conn,
|
||||
output: make(chan *envelope, config.MessageBufferSize),
|
||||
config: config,
|
||||
}
|
||||
melody *Melody
|
||||
}
|
||||
|
||||
func (s *Session) writeMessage(message *envelope) {
|
||||
s.output <- message
|
||||
if len(s.output) < s.melody.Config.MessageBufferSize {
|
||||
s.output <- message
|
||||
} else {
|
||||
s.melody.errorHandler(s, errors.New("Message buffer full"))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) writeRaw(message *envelope) error {
|
||||
s.conn.SetWriteDeadline(time.Now().Add(s.config.WriteWait))
|
||||
s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait))
|
||||
err := s.conn.WriteMessage(message.t, message.msg)
|
||||
|
||||
if err != nil {
|
||||
@@ -54,10 +50,10 @@ func (s *Session) ping() {
|
||||
s.writeMessage(&envelope{t: websocket.PingMessage, msg: []byte{}})
|
||||
}
|
||||
|
||||
func (s *Session) writePump(errorHandler handleErrorFunc) {
|
||||
func (s *Session) writePump() {
|
||||
defer s.conn.Close()
|
||||
|
||||
ticker := time.NewTicker(s.config.PingPeriod)
|
||||
ticker := time.NewTicker(s.melody.Config.PingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
loop:
|
||||
@@ -69,7 +65,7 @@ loop:
|
||||
break loop
|
||||
}
|
||||
if err := s.writeRaw(msg); err != nil {
|
||||
go errorHandler(s, err)
|
||||
s.melody.errorHandler(s, err)
|
||||
break loop
|
||||
}
|
||||
case <-ticker.C:
|
||||
@@ -78,14 +74,14 @@ loop:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) readPump(messageHandler handleMessageFunc, messageHandlerBinary handleMessageFunc, errorHandler handleErrorFunc) {
|
||||
func (s *Session) readPump() {
|
||||
defer s.conn.Close()
|
||||
|
||||
s.conn.SetReadLimit(s.config.MaxMessageSize)
|
||||
s.conn.SetReadDeadline(time.Now().Add(s.config.PongWait))
|
||||
s.conn.SetReadLimit(s.melody.Config.MaxMessageSize)
|
||||
s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait))
|
||||
|
||||
s.conn.SetPongHandler(func(string) error {
|
||||
s.conn.SetReadDeadline(time.Now().Add(s.config.PongWait))
|
||||
s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -93,16 +89,16 @@ func (s *Session) readPump(messageHandler handleMessageFunc, messageHandlerBinar
|
||||
t, message, err := s.conn.ReadMessage()
|
||||
|
||||
if err != nil {
|
||||
go errorHandler(s, err)
|
||||
s.melody.errorHandler(s, err)
|
||||
break
|
||||
}
|
||||
|
||||
if t == websocket.TextMessage {
|
||||
go messageHandler(s, message)
|
||||
s.melody.messageHandler(s, message)
|
||||
}
|
||||
|
||||
if t == websocket.BinaryMessage {
|
||||
go messageHandlerBinary(s, message)
|
||||
s.melody.messageHandlerBinary(s, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user