diff --git a/recws.go b/recws.go index 1d4b27a..4faa7af 100644 --- a/recws.go +++ b/recws.go @@ -91,6 +91,18 @@ func (rc *RecConn) Close() { rc.setIsConnected(false) } +// Shutdown gracefully closes the connection by sending the websocket.CloseMessage. +// The writeWait param defines the duration before the deadline of the write operation is hit. +func (rc *RecConn) Shutdown(writeWait time.Duration) { + msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := rc.WriteControl(websocket.CloseMessage, msg, time.Now().Add(writeWait)) + if err != nil && err != websocket.ErrCloseSent { + // If close message could not be sent, then close without the handshake. + log.Printf("Shutdown: %v", err) + rc.Close() + } +} + // ReadMessage is a helper method for getting a reader // using NextReader and reading from that reader to a buffer. // @@ -99,6 +111,10 @@ func (rc *RecConn) ReadMessage() (messageType int, message []byte, err error) { err = ErrNotConnected if rc.IsConnected() { messageType, message, err = rc.Conn.ReadMessage() + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + rc.Close() + return messageType, message, nil + } if err != nil { rc.CloseAndReconnect() } @@ -117,6 +133,10 @@ func (rc *RecConn) WriteMessage(messageType int, data []byte) error { rc.mu.Lock() err = rc.Conn.WriteMessage(messageType, data) rc.mu.Unlock() + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + rc.Close() + return nil + } if err != nil { rc.CloseAndReconnect() } @@ -137,6 +157,10 @@ func (rc *RecConn) WriteJSON(v interface{}) error { rc.mu.Lock() err = rc.Conn.WriteJSON(v) rc.mu.Unlock() + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + rc.Close() + return nil + } if err != nil { rc.CloseAndReconnect() } @@ -156,6 +180,10 @@ func (rc *RecConn) ReadJSON(v interface{}) error { err := ErrNotConnected if rc.IsConnected() { err = rc.Conn.ReadJSON(v) + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + rc.Close() + return nil + } if err != nil { rc.CloseAndReconnect() } @@ -421,7 +449,7 @@ func (rc *RecConn) connect() { if rc.getKeepAliveTimeout() != 0 { rc.keepAlive() } - + return }