diff --git a/src/api/api.go b/src/api/api.go index b649e5f..fde1f89 100644 --- a/src/api/api.go +++ b/src/api/api.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "strconv" + "time" "github.com/gabriel-vasile/mimetype" "github.com/gin-gonic/gin" @@ -15,6 +16,17 @@ import ( utils "github.com/bbernhard/signal-cli-rest-api/utils" ) +const ( + // Time allowed to write the file to the client. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the client. + pongWait = 60 * time.Second + + // Send pings to client with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + type GroupPermissions struct { AddMembers string `json:"add_members" enums:"only-admins,every-member"` EditGroup string `json:"edit_group" enums:"only-admins,every-member"` @@ -267,6 +279,57 @@ func (a *Api) SendV2(c *gin.Context) { c.JSON(201, SendMessageResponse{Timestamp: strconv.FormatInt((*timestamps)[0].Timestamp, 10)}) } + +func (a *Api) handleSignalReceive(ws *websocket.Conn, number string) { + for { + data, err := a.signalClient.Receive(number, 0) + if err == nil { + err = ws.WriteMessage(websocket.TextMessage, []byte(data)) + if err != nil { + log.Error("Couldn't write message: " + err.Error()) + return + } + } else { + errorMsg := Error{Msg: err.Error()} + errorMsgBytes, err := json.Marshal(errorMsg) + if err != nil { + log.Error("Couldn't serialize error message: " + err.Error()) + return + } + err = ws.WriteMessage(websocket.TextMessage, errorMsgBytes) + if err != nil { + log.Error("Couldn't write message: " + err.Error()) + return + } + } + } +} + +func wsPong(ws *websocket.Conn) { + ws.SetReadLimit(512) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, _, err := ws.ReadMessage() + if err != nil { + break + } + } +} + +func wsPing(ws *websocket.Conn) { + pingTicker := time.NewTicker(pingPeriod) + for { + select { + case <-pingTicker.C: + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + // @Summary Receive Signal Messages. // @Tags Messages // @Description Receives Signal Messages from the Signal Network. @@ -287,28 +350,9 @@ func (a *Api) Receive(c *gin.Context) { return } defer ws.Close() - for { - data, err := a.signalClient.Receive(number, 0) - if err == nil { - err = ws.WriteMessage(websocket.TextMessage, []byte(data)) - if err != nil { - log.Error("Couldn't write message: " + err.Error()) - return - } - } else { - errorMsg := Error{Msg: err.Error()} - errorMsgBytes, err := json.Marshal(errorMsg) - if err != nil { - log.Error("Couldn't serialize error message: " + err.Error()) - return - } - err = ws.WriteMessage(websocket.TextMessage, errorMsgBytes) - if err != nil { - log.Error("Couldn't write message: " + err.Error()) - return - } - } - } + go a.handleSignalReceive(ws, number) + go wsPing(ws) + wsPong(ws) } else { timeout := c.DefaultQuery("timeout", "1") timeoutInt, err := strconv.ParseInt(timeout, 10, 32)