diff --git a/.gitignore b/.gitignore index 8f5fd87..c610222 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .history +/.idea \ No newline at end of file diff --git a/README.md b/README.md index 5fc85ea..a146202 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,57 @@ # recws + +[![Go Report Card](https://goreportcard.com/badge/github.com/loeffel-io/recws)](https://goreportcard.com/report/github.com/loeffel-io/recws) +[![GitHub license](https://img.shields.io/github/license/Naereen/StrapDown.js.svg)](https://github.com/Naereen/StrapDown.js/blob/master/LICENSE) + Reconnecting WebSocket is a websocket client based on [gorilla/websocket](https://github.com/gorilla/websocket) that will automatically reconnect if the connection is dropped. + +## Basic example + +```go +package main + +import ( + "context" + "github.com/loeffel-io/recws" + "log" + "time" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + ws := recws.RecConn{} + ws.Dial("wss://echo.websocket.org", nil) + + go func() { + time.Sleep(2 * time.Second) + cancel() + }() + + for { + select { + case <-ctx.Done(): + go ws.Close() + log.Printf("Websocket closed %s", ws.GetURL()) + return + default: + if !ws.IsConnected() { + log.Printf("Websocket disconnected %s", ws.GetURL()) + continue + } + + if err := ws.WriteMessage(1, []byte("Incoming")); err != nil { + log.Printf("Error: WriteMessage %s", ws.GetURL()) + return + } + + _, message, err := ws.ReadMessage() + if err != nil { + log.Printf("Error: ReadMessage %s", ws.GetURL()) + return + } + + log.Printf("Success: %s", message) + } + } +} +``` \ No newline at end of file diff --git a/examples/basic.go b/examples/basic.go new file mode 100644 index 0000000..6784f4b --- /dev/null +++ b/examples/basic.go @@ -0,0 +1,46 @@ +package main + +import ( + "context" + "github.com/loeffel-io/recws" + "log" + "time" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + ws := recws.RecConn{} + ws.Dial("wss://echo.websocket.org", nil) + + go func() { + time.Sleep(2 * time.Second) + cancel() + }() + + for { + select { + case <-ctx.Done(): + go ws.Close() + log.Printf("Websocket closed %s", ws.GetURL()) + return + default: + if !ws.IsConnected() { + log.Printf("Websocket disconnected %s", ws.GetURL()) + continue + } + + if err := ws.WriteMessage(1, []byte("Incoming")); err != nil { + log.Printf("Error: WriteMessage %s", ws.GetURL()) + return + } + + _, message, err := ws.ReadMessage() + if err != nil { + log.Printf("Error: ReadMessage %s", ws.GetURL()) + return + } + + log.Printf("Success: %s", message) + } + } +} diff --git a/recws.go b/recws.go index d1f0f30..4fee538 100644 --- a/recws.go +++ b/recws.go @@ -35,8 +35,10 @@ type RecConn struct { HandshakeTimeout time.Duration // NonVerbose suppress connecting/reconnecting messages. NonVerbose bool + // SubscribeHandler fires after the connection successfully establish. + SubscribeHandler func() error - mu sync.Mutex + mu sync.RWMutex url string reqHeader http.Header httpResp *http.Response @@ -47,24 +49,30 @@ type RecConn struct { *websocket.Conn } -// CloseAndRecconect will try to reconnect. -func (rc *RecConn) closeAndRecconect() { +// CloseAndReconnect will try to reconnect. +func (rc *RecConn) closeAndReconnect() { rc.Close() - go func() { - rc.connect() - }() + go rc.connect() +} +// setIsConnected sets state for isConnected +func (rc *RecConn) setIsConnected(state bool) { + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.isConnected = state } // Close closes the underlying network connection without // sending or waiting for a close frame. func (rc *RecConn) Close() { - rc.mu.Lock() + rc.mu.RLock() + defer rc.mu.RUnlock() if rc.Conn != nil { rc.Conn.Close() } - rc.isConnected = false - rc.mu.Unlock() + + rc.setIsConnected(false) } // ReadMessage is a helper method for getting a reader @@ -76,7 +84,7 @@ func (rc *RecConn) ReadMessage() (messageType int, message []byte, err error) { if rc.IsConnected() { messageType, message, err = rc.Conn.ReadMessage() if err != nil { - rc.closeAndRecconect() + rc.closeAndReconnect() } } @@ -92,7 +100,7 @@ func (rc *RecConn) WriteMessage(messageType int, data []byte) error { if rc.IsConnected() { err = rc.Conn.WriteMessage(messageType, data) if err != nil { - rc.closeAndRecconect() + rc.closeAndReconnect() } } @@ -110,79 +118,182 @@ func (rc *RecConn) WriteJSON(v interface{}) error { if rc.IsConnected() { err = rc.Conn.WriteJSON(v) if err != nil { - rc.closeAndRecconect() + rc.closeAndReconnect() } } return err } +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +// +// If the connection is closed ErrNotConnected is returned +func (rc *RecConn) ReadJSON(v interface{}) error { + err := ErrNotConnected + if rc.IsConnected() { + err = rc.Conn.ReadJSON(v) + if err != nil { + rc.closeAndReconnect() + } + } + + return err +} + +func (rc *RecConn) setURL(url string) { + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.url = url +} + +// parseURL parses current url +func (rc *RecConn) parseURL(urlStr string) (string, error) { + if urlStr == "" { + return "", errors.New("dial: url cannot be empty") + } + + u, err := url.Parse(urlStr) + + if err != nil { + return "", errors.New("url: " + err.Error()) + } + + if u.Scheme != "ws" && u.Scheme != "wss" { + return "", errors.New("url: websocket uris must start with ws or wss scheme") + } + + if u.User != nil { + return "", errors.New("url: user name and password are not allowed in websocket URIs") + } + + return urlStr, nil +} + +func (rc *RecConn) setDefaultRecIntvlMin() { + rc.mu.Lock() + defer rc.mu.Unlock() + + if rc.RecIntvlMin == 0 { + rc.RecIntvlMin = 2 * time.Second + } +} + +func (rc *RecConn) setDefaultRecIntvlMax() { + rc.mu.Lock() + defer rc.mu.Unlock() + + if rc.RecIntvlMax == 0 { + rc.RecIntvlMax = 30 * time.Second + } +} + +func (rc *RecConn) setDefaultRecIntvlFactor() { + rc.mu.Lock() + defer rc.mu.Unlock() + + if rc.RecIntvlFactor == 0 { + rc.RecIntvlFactor = 1.5 + } +} + +func (rc *RecConn) setDefaultHandshakeTimeout() { + rc.mu.Lock() + defer rc.mu.Unlock() + + if rc.HandshakeTimeout == 0 { + rc.HandshakeTimeout = 2 * time.Second + } +} + + +func (rc *RecConn) setDefaultDialer(handshakeTimeout time.Duration) { + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.dialer = &websocket.Dialer{ + HandshakeTimeout: handshakeTimeout, + } +} + +func (rc *RecConn) getHandshakeTimeout() time.Duration { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return rc.HandshakeTimeout +} + // Dial creates a new client connection. // The URL url specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies // (Cookie). Use GetHTTPResponse() method for the response.Header to get // the selected subprotocol (Sec-WebSocket-Protocol) and cookies (Set-Cookie). func (rc *RecConn) Dial(urlStr string, reqHeader http.Header) { - if urlStr == "" { - log.Fatal("Dial: Url cannot be empty") - } - u, err := url.Parse(urlStr) + urlStr, err := rc.parseURL(urlStr) if err != nil { - log.Fatal("Url:", err) + log.Fatalf("Dial: %v", err) } - if u.Scheme != "ws" && u.Scheme != "wss" { - log.Fatal("Url: websocket URIs must start with ws or wss scheme") - } + // Config + rc.setURL(urlStr) + rc.setDefaultRecIntvlMin() + rc.setDefaultRecIntvlMax() + rc.setDefaultRecIntvlFactor() + rc.setDefaultHandshakeTimeout() + rc.setDefaultDialer(rc.getHandshakeTimeout()) - if u.User != nil { - log.Fatal("Url: user name and password are not allowed in websocket URIs") - } - - rc.url = urlStr - - if rc.RecIntvlMin == 0 { - rc.RecIntvlMin = 2 * time.Second - } - - if rc.RecIntvlMax == 0 { - rc.RecIntvlMax = 30 * time.Second - } - - if rc.RecIntvlFactor == 0 { - rc.RecIntvlFactor = 1.5 - } - - if rc.HandshakeTimeout == 0 { - rc.HandshakeTimeout = 2 * time.Second - } - - rc.dialer = websocket.DefaultDialer - rc.dialer.HandshakeTimeout = rc.HandshakeTimeout - rc.reqHeader = reqHeader - - go func() { - rc.connect() - }() + // Connect + go rc.connect() // wait on first attempt - time.Sleep(rc.HandshakeTimeout) + time.Sleep(rc.getHandshakeTimeout()) } -func (rc *RecConn) connect() { - b := &backoff.Backoff{ +// GetURL returns current connection url +func (rc *RecConn) GetURL() string { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return rc.url +} + +func (rc *RecConn) getNonVerbose() bool { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return rc.NonVerbose +} + +func (rc *RecConn) getBackoff() *backoff.Backoff { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return &backoff.Backoff{ Min: rc.RecIntvlMin, Max: rc.RecIntvlMax, Factor: rc.RecIntvlFactor, Jitter: true, } +} +func (rc *RecConn) hasSubscribeHandler() bool { + rc.mu.RLock() + defer rc.mu.RUnlock() + + return rc.SubscribeHandler != nil +} + +func (rc *RecConn) connect() { + b := rc.getBackoff() rand.Seed(time.Now().UTC().UnixNano()) for { nextItvl := b.Duration() - wsConn, httpResp, err := rc.dialer.Dial(rc.url, rc.reqHeader) rc.mu.Lock() @@ -193,15 +304,26 @@ func (rc *RecConn) connect() { rc.mu.Unlock() if err == nil { - if !rc.NonVerbose { + if !rc.getNonVerbose() { log.Printf("Dial: connection was successfully established with %s\n", rc.url) + + if !rc.hasSubscribeHandler() { + return + } + + if err := rc.SubscribeHandler(); err != nil { + log.Fatalf("Dial: connect handler failed with %s", err.Error()) + } + + log.Printf("Dial: connect handler was successfully established with %s\n", rc.url) } - break - } else { - if !rc.NonVerbose { - log.Println(err) - log.Println("Dial: will try again in", nextItvl, "seconds.") - } + + return + } + + if !rc.getNonVerbose() { + log.Println(err) + log.Println("Dial: will try again in", nextItvl, "seconds.") } time.Sleep(nextItvl) @@ -212,8 +334,8 @@ func (rc *RecConn) connect() { // Useful when WebSocket handshake fails, // so that callers can handle redirects, authentication, etc. func (rc *RecConn) GetHTTPResponse() *http.Response { - rc.mu.Lock() - defer rc.mu.Unlock() + rc.mu.RLock() + defer rc.mu.RUnlock() return rc.httpResp } @@ -221,16 +343,16 @@ func (rc *RecConn) GetHTTPResponse() *http.Response { // GetDialError returns the last dialer error. // nil on successful connection. func (rc *RecConn) GetDialError() error { - rc.mu.Lock() - defer rc.mu.Unlock() + rc.mu.RLock() + defer rc.mu.RUnlock() return rc.dialErr } // IsConnected returns the WebSocket connection state func (rc *RecConn) IsConnected() bool { - rc.mu.Lock() - defer rc.mu.Unlock() + rc.mu.RLock() + defer rc.mu.RUnlock() return rc.isConnected }