diff --git a/keepalive.go b/keepalive.go index bb14889..55d71fa 100644 --- a/keepalive.go +++ b/keepalive.go @@ -6,7 +6,8 @@ import ( ) type keepAliveResponse struct { - lastResponse time.Time + allowDataResponse bool + lastResponse time.Time sync.RWMutex } @@ -17,6 +18,22 @@ func (k *keepAliveResponse) setLastResponse() { k.lastResponse = time.Now() } +func (k *keepAliveResponse) getAllowDataResponse() bool { + k.RLock() + allow := k.allowDataResponse + k.RUnlock() + return allow +} + +func (k *keepAliveResponse) setLastDataResponse() { + allow := k.getAllowDataResponse() + if allow { + k.Lock() + k.lastResponse = time.Now() + k.Unlock() + } +} + func (k *keepAliveResponse) getLastResponse() time.Time { k.RLock() defer k.RUnlock() diff --git a/recws.go b/recws.go index 29ca31f..157c0c8 100644 --- a/recws.go +++ b/recws.go @@ -49,14 +49,17 @@ type RecConn struct { LogHandler func(v LogValues) // NonVerbose suppress connecting/reconnecting messages. NonVerbose bool + // AllowKeepAliveDataResponse allows recognize data response like keepalive response + AllowKeepAliveDataResponse bool - isConnected bool - mu sync.RWMutex - url string - reqHeader http.Header - httpResp *http.Response - dialErr error - dialer *websocket.Dialer + isConnected bool + mu sync.RWMutex + url string + reqHeader http.Header + httpResp *http.Response + dialErr error + dialer *websocket.Dialer + keepAliveResponse *keepAliveResponse *websocket.Conn } @@ -133,6 +136,9 @@ func (rc *RecConn) ReadMessage() (messageType int, message []byte, err error) { rc.CloseAndReconnect() } } + if err == nil { + rc.getKeepAliveResponse().setLastDataResponse() + } return } @@ -206,6 +212,13 @@ func (rc *RecConn) ReadJSON(v interface{}) error { return err } +func (rc *RecConn) getKeepAliveResponse() *keepAliveResponse { + rc.mu.RLock() + ka := rc.keepAliveResponse + rc.mu.RUnlock() + return ka +} + func (rc *RecConn) setURL(url string) { rc.mu.Lock() defer rc.mu.Unlock() @@ -409,13 +422,12 @@ func (rc *RecConn) writeControlPingMessage() error { func (rc *RecConn) keepAlive() { var ( - keepAliveResponse = new(keepAliveResponse) - ticker = time.NewTicker(rc.getKeepAliveTimeout()) + ticker = time.NewTicker(rc.getKeepAliveTimeout()) ) rc.mu.Lock() rc.Conn.SetPongHandler(func(msg string) error { - keepAliveResponse.setLastResponse() + rc.getKeepAliveResponse().setLastResponse() return nil }) rc.mu.Unlock() @@ -434,7 +446,7 @@ func (rc *RecConn) keepAlive() { <-ticker.C timeoutOffset := time.Millisecond * 500 - if time.Since(keepAliveResponse.getLastResponse()) > rc.getKeepAliveTimeout()+timeoutOffset { + if time.Since(rc.getKeepAliveResponse().getLastResponse()) > rc.getKeepAliveTimeout()+timeoutOffset { rc.log(LogValues{Err: errors.New("keepalive timeout"), Msg: "Reconnect", Url: rc.url}) rc.CloseAndReconnect() return @@ -456,6 +468,8 @@ func (rc *RecConn) connect() { rc.dialErr = err rc.isConnected = err == nil rc.httpResp = httpResp + rc.keepAliveResponse = new(keepAliveResponse) + rc.keepAliveResponse.allowDataResponse = rc.AllowKeepAliveDataResponse rc.mu.Unlock() if err == nil {