Skip to content

Commit

Permalink
Refactor websocket closing signals
Browse files Browse the repository at this point in the history
- improved logic when stopping the server to gracefully shutdown every connection
- add ServerStopAllConnections test

Signed-off-by: Lorenzo Donini <[email protected]>
  • Loading branch information
lorenzodonini committed Jul 17, 2021
1 parent a4c7494 commit c6b416e
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 59 deletions.
160 changes: 103 additions & 57 deletions ws/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ type WebSocket struct {
connection *websocket.Conn
id string
outQueue chan []byte
closeSignal chan error // used by the readPump to notify the closed connection to the writePump
closeC chan websocket.CloseError // used to gracefully close a websocket connection.
forceCloseC chan error // used by the readPump to notify a forcefully closed connection to the writePump.
pingMessage chan []byte
tlsConnectionState *tls.ConnectionState
}
Expand Down Expand Up @@ -344,6 +345,7 @@ func (server *Server) Start(port int, listenPath string) {

defer ln.Close()

server.httpServer.RegisterOnShutdown(server.stopConnections)
if server.tlsCertificatePath != "" && server.tlsCertificateKey != "" {
err = server.httpServer.ServeTLS(ln, server.tlsCertificatePath, server.tlsCertificateKey)
} else {
Expand Down Expand Up @@ -375,11 +377,16 @@ func (server *Server) StopConnection(id string, closeError websocket.CloseError)
if !ok {
return fmt.Errorf("couldn't stop websocket connection. No connection with id %s is open", id)
}
return ws.connection.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(closeError.Code, closeError.Text),
time.Now().Add(server.timeoutConfig.WriteWait),
)
ws.closeC <- closeError
return nil
}

func (server *Server) stopConnections() {
server.connMutex.Lock()
defer server.connMutex.Unlock()
for _, conn := range server.connections {
conn.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}
}
}

func (server *Server) Write(webSocketId string, data []byte) error {
Expand Down Expand Up @@ -445,7 +452,8 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
connection: conn,
id: id,
outQueue: make(chan []byte, 1),
closeSignal: make(chan error, 1),
closeC: make(chan websocket.CloseError, 1),
forceCloseC: make(chan error, 1),
pingMessage: make(chan []byte, 1),
tlsConnectionState: r.TLS,
}
Expand Down Expand Up @@ -477,11 +485,8 @@ func (server *Server) readPump(ws *WebSocket) {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
server.error(fmt.Errorf("read failed for %s: %w", ws.ID(), err))
}
// Notify writePump of error. Disconnection will be handled there
ws.closeSignal <- err
if server.disconnectedHandler != nil {
server.disconnectedHandler(ws)
}
// Notify writePump of error. Force close will be handled there
ws.forceCloseC <- err
return
}

Expand All @@ -499,47 +504,70 @@ func (server *Server) readPump(ws *WebSocket) {

func (server *Server) writePump(ws *WebSocket) {
conn := ws.connection
defer func() {
_ = conn.Close()
server.connMutex.Lock()
defer server.connMutex.Unlock()
delete(server.connections, ws.id)
}()

for {
select {
case data, ok := <-ws.outQueue:
_ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait))
if !ok {
// Closing connection
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
server.error(fmt.Errorf("close failed: %w", err))
}
// Unexpected closed queue, should never happen
server.error(fmt.Errorf("output queue for socket %v was closed, forcefully closing", ws.id))
// Don't invoke cleanup
return
}

// Send data
err := conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err))
// Invoking cleanup, as socket was forcefully closed
server.cleanupConnection(ws)
return
}
case ping := <-ws.pingMessage:
_ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait))
err := conn.WriteMessage(websocket.PongMessage, ping)
if err != nil {
server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err))
// Invoking cleanup, as socket was forcefully closed
server.cleanupConnection(ws)
return
}
case closed, ok := <-ws.closeSignal:
case closeErr, _ := <-ws.closeC:
// Closing connection gracefully
if err := conn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(closeErr.Code, closeErr.Text),
time.Now().Add(server.timeoutConfig.WriteWait),
); err != nil {
server.error(fmt.Errorf("failed to write close message for connection %s: %w", ws.id, err))
}
// Invoking cleanup
server.cleanupConnection(ws)
return
case closed, ok := <-ws.forceCloseC:
if !ok || closed != nil {
//TODO: handle signal
return
// Connection was forcefully closed, invoke cleanup
server.cleanupConnection(ws)
}
return
}
}
}

// Frees internal resources after a websocket connection was signaled to be closed.
// From this moment onwards, no new messages may be sent.
func (server *Server) cleanupConnection(ws *WebSocket) {
_ = ws.connection.Close()
server.connMutex.Lock()
close(ws.outQueue)
close(ws.closeC)
delete(server.connections, ws.id)
server.connMutex.Unlock()
if server.disconnectedHandler != nil {
server.disconnectedHandler(ws)
}
}

// ---------------------- CLIENT ----------------------

// A Websocket client, needed to connect to a websocket server.
Expand Down Expand Up @@ -638,7 +666,7 @@ type Client struct {
onReconnected func()
mutex sync.Mutex
errC chan error
stopped chan struct{}
reconnectC chan struct{} // used for signaling, that a reconnection attempt should be interrupted
}

// Creates a new simple websocket client (the channel is not secured).
Expand Down Expand Up @@ -708,31 +736,21 @@ func (client *Client) SetHeaderValue(key string, value string) {
func (client *Client) writePump() {
ticker := time.NewTicker(client.timeoutConfig.PingPeriod)
conn := client.webSocket.connection
// Closure function shuts down the current connection
// Closure function correctly closes the current connection
closure := func(err error) {
ticker.Stop()
_ = conn.Close()
client.setConnected(false)
client.cleanup()
// Invoke callback
if client.onDisconnected != nil && err != nil {
client.onDisconnected(err)
}
}

for {
select {
case data, ok := <-client.webSocket.outQueue:
case data, _ := <-client.webSocket.outQueue:
// Send data
_ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait))
if !ok {
// Closing connection normally
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
client.error(fmt.Errorf("close failed: %w", err))
}
// Disconnected by user command. Not calling auto-reconnect.
// Passing nil will also not call onDisconnected
closure(nil)
return
}
err := conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
client.error(fmt.Errorf("write failed: %w", err))
Expand All @@ -744,13 +762,26 @@ func (client *Client) writePump() {
// Send periodic ping
_ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait))
if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
client.error(fmt.Errorf("write failed: %w", err))
client.error(fmt.Errorf("failed to send ping message: %w", err))
closure(err)
client.handleReconnection()
return
}
case closed, ok := <-client.webSocket.closeSignal:
// Read pump sent a closeSignal (i.e. a message couldn't be read in that moment)
case closeErr, _ := <-client.webSocket.closeC:
// Closing connection gracefully
if err := conn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(closeErr.Code, closeErr.Text),
time.Now().Add(client.timeoutConfig.WriteWait),
); err != nil {
client.error(fmt.Errorf("failed to write close message: %w", err))
}
// Disconnected by user command. Not calling auto-reconnect.
// Passing nil will also not call onDisconnected.
closure(nil)
return
case closed, ok := <-client.webSocket.forceCloseC:
// Read pump sent a forceClose signal (reading failed -> aborting the connection)
if !ok || closed != nil {
closure(closed)
client.handleReconnection()
Expand All @@ -772,29 +803,40 @@ func (client *Client) readPump() {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
client.error(fmt.Errorf("read failed: %w", err))
}
// Notify writePump of error. Disconnection will be handled there
client.webSocket.closeSignal <- err
// Notify writePump of error. Forced close will be handled there
client.webSocket.forceCloseC <- err
return
}

if client.messageHandler != nil {
err = client.messageHandler(message)
if err != nil {
// TODO: Handle?
client.error(fmt.Errorf("handle failed: %w", err))
continue
}
}
}
}

// Frees internal resources after a websocket connection was signaled to be closed.
// From this moment onwards, no new messages may be sent.
func (client *Client) cleanup() {
client.setConnected(false)
ws := client.webSocket
_ = ws.connection.Close()
client.mutex.Lock()
defer client.mutex.Unlock()
close(ws.outQueue)
close(ws.closeC)
}

func (client *Client) handleReconnection() {
delay := client.timeoutConfig.ReconnectBackoff
for {
// Wait before reconnecting
select {
case <-time.After(delay):
case <-client.stopped:
case <-client.reconnectC:
return
}
err := client.Start(client.url.String())
Expand Down Expand Up @@ -870,11 +912,12 @@ func (client *Client) Start(urlStr string) error {
client.webSocket = WebSocket{
connection: ws,
id: id,
outQueue: make(chan []byte),
closeSignal: make(chan error, 1),
outQueue: make(chan []byte, 1),
closeC: make(chan websocket.CloseError, 1),
forceCloseC: make(chan error, 1),
tlsConnectionState: resp.TLS,
}
client.stopped = make(chan struct{})
client.reconnectC = make(chan struct{})
client.setConnected(true)
//Start reader and write routine
go client.writePump()
Expand All @@ -883,10 +926,13 @@ func (client *Client) Start(urlStr string) error {
}

func (client *Client) Stop() {
client.setConnected(false)
close(client.webSocket.outQueue)

close(client.stopped)
if client.IsConnected() {
client.setConnected(false)
// Send signal for gracefully shutting down the connection
client.webSocket.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}
}
// Notify reconnection goroutine to stop (if any)
close(client.reconnectC)
if client.errC != nil {
close(client.errC)
client.errC = nil
Expand Down
57 changes: 55 additions & 2 deletions ws/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,6 @@ func TestServerStopConnection(t *testing.T) {
disconnectedServerC <- struct{}{}
})
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
strMsg := string(data)
fmt.Println(strMsg)
return nil, nil
})
wsClient.SetDisconnectedHandler(func(err error) {
Expand Down Expand Up @@ -290,6 +288,61 @@ func TestServerStopConnection(t *testing.T) {
wsServer.Stop()
}

func TestWebsocketServerStopAllConnections(t *testing.T) {
triggerC := make(chan struct{}, 1)
numClients := 5
disconnectedClientC := make(chan struct{}, numClients)
disconnectedServerC := make(chan struct{}, 1)
wsServer := newWebsocketServer(t, nil)
wsServer.SetNewClientHandler(func(ws Channel) {
triggerC <- struct{}{}
})
wsServer.SetDisconnectedClientHandler(func(ws Channel) {
disconnectedServerC <- struct{}{}
})
// Start server
go wsServer.Start(serverPort, serverPath)
time.Sleep(100 * time.Millisecond)
// Connect clients
clients := []WsClient{}
host := fmt.Sprintf("localhost:%v", serverPort)
for i := 0; i < numClients; i++ {
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
return nil, nil
})
wsClient.SetDisconnectedHandler(func(err error) {
require.IsType(t, &websocket.CloseError{}, err)
closeErr, _ := err.(*websocket.CloseError)
assert.Equal(t, websocket.CloseNormalClosure, closeErr.Code)
assert.Equal(t, "", closeErr.Text)
disconnectedClientC <- struct{}{}
})
u := url.URL{Scheme: "ws", Host: host, Path: fmt.Sprintf("%v-%v", testPath, i)}
err := wsClient.Start(u.String())
require.NoError(t, err)
clients = append(clients, wsClient)
// Wait for client to connect
_, ok := <-triggerC
require.True(t, ok)
}
// Stop server and wait for clients to disconnect
wsServer.Stop()
for disconnects := 0; disconnects < numClients; disconnects++ {
_, ok := <-disconnectedClientC
require.True(t, ok)
_, ok = <-disconnectedServerC
require.True(t, ok)
}
// Check disconnection status
for _, c := range clients {
assert.False(t, c.IsConnected())
// Client will attempt to reconnect under the hood, but test finishes before this can happen
c.Stop()
}
time.Sleep(100 * time.Millisecond)
assert.Empty(t, wsServer.connections)
}

func TestWebsocketClientConnectionBreak(t *testing.T) {
newClient := make(chan bool)
disconnected := make(chan bool)
Expand Down

0 comments on commit c6b416e

Please sign in to comment.