protect websocket write with mutex

* Gorilla Websocket only allows one concurrent writer. As there are
  multiple goroutines that could write concurrently to the websocket,
  the websocket write needs to be protected by a Mutext. This is not
  particular nice, but a complete rewrite of the connection handling
  would be quite a lot of work.

see #556
This commit is contained in:
Bernhard B
2024-07-09 21:19:49 +02:00
parent 911b686778
commit cd996e1814

View File

@@ -8,6 +8,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"sync"
"github.com/gabriel-vasile/mimetype" "github.com/gabriel-vasile/mimetype"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -197,6 +198,7 @@ type AddStickerPackRequest struct {
type Api struct { type Api struct {
signalClient *client.SignalClient signalClient *client.SignalClient
wsMutex sync.Mutex
} }
func NewApi(signalClient *client.SignalClient) *Api { func NewApi(signalClient *client.SignalClient) *Api {
@@ -471,6 +473,8 @@ func (a *Api) handleSignalReceive(ws *websocket.Conn, number string, stop chan s
} }
if response.Account == number { if response.Account == number {
a.wsMutex.Lock()
defer a.wsMutex.Unlock()
err = ws.WriteMessage(websocket.TextMessage, []byte(data)) err = ws.WriteMessage(websocket.TextMessage, []byte(data))
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
@@ -487,6 +491,8 @@ func (a *Api) handleSignalReceive(ws *websocket.Conn, number string, stop chan s
log.Error("Couldn't serialize error message: " + err.Error()) log.Error("Couldn't serialize error message: " + err.Error())
return return
} }
a.wsMutex.Lock()
defer a.wsMutex.Unlock()
err = ws.WriteMessage(websocket.TextMessage, errorMsgBytes) err = ws.WriteMessage(websocket.TextMessage, errorMsgBytes)
if err != nil { if err != nil {
log.Error("Couldn't write message: " + err.Error()) log.Error("Couldn't write message: " + err.Error())
@@ -513,7 +519,7 @@ func wsPong(ws *websocket.Conn, stop chan struct{}) {
} }
} }
func wsPing(ws *websocket.Conn, stop chan struct{}) { func (a *Api) wsPing(ws *websocket.Conn, stop chan struct{}) {
pingTicker := time.NewTicker(pingPeriod) pingTicker := time.NewTicker(pingPeriod)
for { for {
select { select {
@@ -521,6 +527,8 @@ func wsPing(ws *websocket.Conn, stop chan struct{}) {
ws.Close() ws.Close()
return return
case <-pingTicker.C: case <-pingTicker.C:
a.wsMutex.Lock()
defer a.wsMutex.Unlock()
if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return return
} }
@@ -561,7 +569,7 @@ func (a *Api) Receive(c *gin.Context) {
defer ws.Close() defer ws.Close()
var stop = make(chan struct{}) var stop = make(chan struct{})
go a.handleSignalReceive(ws, number, stop) go a.handleSignalReceive(ws, number, stop)
go wsPing(ws, stop) go a.wsPing(ws, stop)
wsPong(ws, stop) wsPong(ws, stop)
} else { } else {
timeout := c.DefaultQuery("timeout", "1") timeout := c.DefaultQuery("timeout", "1")