From c6b416ead4431e57be2eb9c913bfbf72e4ba3cf0 Mon Sep 17 00:00:00 2001 From: Lorenzo Donini Date: Mon, 12 Jul 2021 00:45:01 +0200 Subject: [PATCH] Refactor websocket closing signals - improved logic when stopping the server to gracefully shutdown every connection - add ServerStopAllConnections test Signed-off-by: Lorenzo Donini --- ws/websocket.go | 160 ++++++++++++++++++++++++++++--------------- ws/websocket_test.go | 57 ++++++++++++++- 2 files changed, 158 insertions(+), 59 deletions(-) diff --git a/ws/websocket.go b/ws/websocket.go index d61ffcc2..eabff002 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -101,7 +101,8 @@ type WebSocket struct { connection *websocket.Conn id string outQueue chan []byte - closeSignal chan error // used by the readPump to notify the closed connection to the writePump + closeC chan websocket.CloseError // used to gracefully close a websocket connection. + forceCloseC chan error // used by the readPump to notify a forcefully closed connection to the writePump. pingMessage chan []byte tlsConnectionState *tls.ConnectionState } @@ -344,6 +345,7 @@ func (server *Server) Start(port int, listenPath string) { defer ln.Close() + server.httpServer.RegisterOnShutdown(server.stopConnections) if server.tlsCertificatePath != "" && server.tlsCertificateKey != "" { err = server.httpServer.ServeTLS(ln, server.tlsCertificatePath, server.tlsCertificateKey) } else { @@ -375,11 +377,16 @@ func (server *Server) StopConnection(id string, closeError websocket.CloseError) if !ok { return fmt.Errorf("couldn't stop websocket connection. No connection with id %s is open", id) } - return ws.connection.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(closeError.Code, closeError.Text), - time.Now().Add(server.timeoutConfig.WriteWait), - ) + ws.closeC <- closeError + return nil +} + +func (server *Server) stopConnections() { + server.connMutex.Lock() + defer server.connMutex.Unlock() + for _, conn := range server.connections { + conn.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""} + } } func (server *Server) Write(webSocketId string, data []byte) error { @@ -445,7 +452,8 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) { connection: conn, id: id, outQueue: make(chan []byte, 1), - closeSignal: make(chan error, 1), + closeC: make(chan websocket.CloseError, 1), + forceCloseC: make(chan error, 1), pingMessage: make(chan []byte, 1), tlsConnectionState: r.TLS, } @@ -477,11 +485,8 @@ func (server *Server) readPump(ws *WebSocket) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { server.error(fmt.Errorf("read failed for %s: %w", ws.ID(), err)) } - // Notify writePump of error. Disconnection will be handled there - ws.closeSignal <- err - if server.disconnectedHandler != nil { - server.disconnectedHandler(ws) - } + // Notify writePump of error. Force close will be handled there + ws.forceCloseC <- err return } @@ -499,29 +504,23 @@ func (server *Server) readPump(ws *WebSocket) { func (server *Server) writePump(ws *WebSocket) { conn := ws.connection - defer func() { - _ = conn.Close() - server.connMutex.Lock() - defer server.connMutex.Unlock() - delete(server.connections, ws.id) - }() for { select { case data, ok := <-ws.outQueue: _ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait)) if !ok { - // Closing connection - err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - server.error(fmt.Errorf("close failed: %w", err)) - } + // Unexpected closed queue, should never happen + server.error(fmt.Errorf("output queue for socket %v was closed, forcefully closing", ws.id)) + // Don't invoke cleanup return } - + // Send data err := conn.WriteMessage(websocket.TextMessage, data) if err != nil { server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err)) + // Invoking cleanup, as socket was forcefully closed + server.cleanupConnection(ws) return } case ping := <-ws.pingMessage: @@ -529,17 +528,46 @@ func (server *Server) writePump(ws *WebSocket) { err := conn.WriteMessage(websocket.PongMessage, ping) if err != nil { server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err)) + // Invoking cleanup, as socket was forcefully closed + server.cleanupConnection(ws) return } - case closed, ok := <-ws.closeSignal: + case closeErr, _ := <-ws.closeC: + // Closing connection gracefully + if err := conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(closeErr.Code, closeErr.Text), + time.Now().Add(server.timeoutConfig.WriteWait), + ); err != nil { + server.error(fmt.Errorf("failed to write close message for connection %s: %w", ws.id, err)) + } + // Invoking cleanup + server.cleanupConnection(ws) + return + case closed, ok := <-ws.forceCloseC: if !ok || closed != nil { - //TODO: handle signal - return + // Connection was forcefully closed, invoke cleanup + server.cleanupConnection(ws) } + return } } } +// Frees internal resources after a websocket connection was signaled to be closed. +// From this moment onwards, no new messages may be sent. +func (server *Server) cleanupConnection(ws *WebSocket) { + _ = ws.connection.Close() + server.connMutex.Lock() + close(ws.outQueue) + close(ws.closeC) + delete(server.connections, ws.id) + server.connMutex.Unlock() + if server.disconnectedHandler != nil { + server.disconnectedHandler(ws) + } +} + // ---------------------- CLIENT ---------------------- // A Websocket client, needed to connect to a websocket server. @@ -638,7 +666,7 @@ type Client struct { onReconnected func() mutex sync.Mutex errC chan error - stopped chan struct{} + reconnectC chan struct{} // used for signaling, that a reconnection attempt should be interrupted } // Creates a new simple websocket client (the channel is not secured). @@ -708,11 +736,11 @@ func (client *Client) SetHeaderValue(key string, value string) { func (client *Client) writePump() { ticker := time.NewTicker(client.timeoutConfig.PingPeriod) conn := client.webSocket.connection - // Closure function shuts down the current connection + // Closure function correctly closes the current connection closure := func(err error) { ticker.Stop() - _ = conn.Close() - client.setConnected(false) + client.cleanup() + // Invoke callback if client.onDisconnected != nil && err != nil { client.onDisconnected(err) } @@ -720,19 +748,9 @@ func (client *Client) writePump() { for { select { - case data, ok := <-client.webSocket.outQueue: + case data, _ := <-client.webSocket.outQueue: + // Send data _ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait)) - if !ok { - // Closing connection normally - err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - client.error(fmt.Errorf("close failed: %w", err)) - } - // Disconnected by user command. Not calling auto-reconnect. - // Passing nil will also not call onDisconnected - closure(nil) - return - } err := conn.WriteMessage(websocket.TextMessage, data) if err != nil { client.error(fmt.Errorf("write failed: %w", err)) @@ -744,13 +762,26 @@ func (client *Client) writePump() { // Send periodic ping _ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait)) if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { - client.error(fmt.Errorf("write failed: %w", err)) + client.error(fmt.Errorf("failed to send ping message: %w", err)) closure(err) client.handleReconnection() return } - case closed, ok := <-client.webSocket.closeSignal: - // Read pump sent a closeSignal (i.e. a message couldn't be read in that moment) + case closeErr, _ := <-client.webSocket.closeC: + // Closing connection gracefully + if err := conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(closeErr.Code, closeErr.Text), + time.Now().Add(client.timeoutConfig.WriteWait), + ); err != nil { + client.error(fmt.Errorf("failed to write close message: %w", err)) + } + // Disconnected by user command. Not calling auto-reconnect. + // Passing nil will also not call onDisconnected. + closure(nil) + return + case closed, ok := <-client.webSocket.forceCloseC: + // Read pump sent a forceClose signal (reading failed -> aborting the connection) if !ok || closed != nil { closure(closed) client.handleReconnection() @@ -772,15 +803,14 @@ func (client *Client) readPump() { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { client.error(fmt.Errorf("read failed: %w", err)) } - // Notify writePump of error. Disconnection will be handled there - client.webSocket.closeSignal <- err + // Notify writePump of error. Forced close will be handled there + client.webSocket.forceCloseC <- err return } if client.messageHandler != nil { err = client.messageHandler(message) if err != nil { - // TODO: Handle? client.error(fmt.Errorf("handle failed: %w", err)) continue } @@ -788,13 +818,25 @@ func (client *Client) readPump() { } } +// Frees internal resources after a websocket connection was signaled to be closed. +// From this moment onwards, no new messages may be sent. +func (client *Client) cleanup() { + client.setConnected(false) + ws := client.webSocket + _ = ws.connection.Close() + client.mutex.Lock() + defer client.mutex.Unlock() + close(ws.outQueue) + close(ws.closeC) +} + func (client *Client) handleReconnection() { delay := client.timeoutConfig.ReconnectBackoff for { // Wait before reconnecting select { case <-time.After(delay): - case <-client.stopped: + case <-client.reconnectC: return } err := client.Start(client.url.String()) @@ -870,11 +912,12 @@ func (client *Client) Start(urlStr string) error { client.webSocket = WebSocket{ connection: ws, id: id, - outQueue: make(chan []byte), - closeSignal: make(chan error, 1), + outQueue: make(chan []byte, 1), + closeC: make(chan websocket.CloseError, 1), + forceCloseC: make(chan error, 1), tlsConnectionState: resp.TLS, } - client.stopped = make(chan struct{}) + client.reconnectC = make(chan struct{}) client.setConnected(true) //Start reader and write routine go client.writePump() @@ -883,10 +926,13 @@ func (client *Client) Start(urlStr string) error { } func (client *Client) Stop() { - client.setConnected(false) - close(client.webSocket.outQueue) - - close(client.stopped) + if client.IsConnected() { + client.setConnected(false) + // Send signal for gracefully shutting down the connection + client.webSocket.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""} + } + // Notify reconnection goroutine to stop (if any) + close(client.reconnectC) if client.errC != nil { close(client.errC) client.errC = nil diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 5ca16108..90b1d7b7 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -252,8 +252,6 @@ func TestServerStopConnection(t *testing.T) { disconnectedServerC <- struct{}{} }) wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - strMsg := string(data) - fmt.Println(strMsg) return nil, nil }) wsClient.SetDisconnectedHandler(func(err error) { @@ -290,6 +288,61 @@ func TestServerStopConnection(t *testing.T) { wsServer.Stop() } +func TestWebsocketServerStopAllConnections(t *testing.T) { + triggerC := make(chan struct{}, 1) + numClients := 5 + disconnectedClientC := make(chan struct{}, numClients) + disconnectedServerC := make(chan struct{}, 1) + wsServer := newWebsocketServer(t, nil) + wsServer.SetNewClientHandler(func(ws Channel) { + triggerC <- struct{}{} + }) + wsServer.SetDisconnectedClientHandler(func(ws Channel) { + disconnectedServerC <- struct{}{} + }) + // Start server + go wsServer.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + // Connect clients + clients := []WsClient{} + host := fmt.Sprintf("localhost:%v", serverPort) + for i := 0; i < numClients; i++ { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { + return nil, nil + }) + wsClient.SetDisconnectedHandler(func(err error) { + require.IsType(t, &websocket.CloseError{}, err) + closeErr, _ := err.(*websocket.CloseError) + assert.Equal(t, websocket.CloseNormalClosure, closeErr.Code) + assert.Equal(t, "", closeErr.Text) + disconnectedClientC <- struct{}{} + }) + u := url.URL{Scheme: "ws", Host: host, Path: fmt.Sprintf("%v-%v", testPath, i)} + err := wsClient.Start(u.String()) + require.NoError(t, err) + clients = append(clients, wsClient) + // Wait for client to connect + _, ok := <-triggerC + require.True(t, ok) + } + // Stop server and wait for clients to disconnect + wsServer.Stop() + for disconnects := 0; disconnects < numClients; disconnects++ { + _, ok := <-disconnectedClientC + require.True(t, ok) + _, ok = <-disconnectedServerC + require.True(t, ok) + } + // Check disconnection status + for _, c := range clients { + assert.False(t, c.IsConnected()) + // Client will attempt to reconnect under the hood, but test finishes before this can happen + c.Stop() + } + time.Sleep(100 * time.Millisecond) + assert.Empty(t, wsServer.connections) +} + func TestWebsocketClientConnectionBreak(t *testing.T) { newClient := make(chan bool) disconnected := make(chan bool)