error checking and reconnect

This commit is contained in:
callebtc
2023-01-25 00:35:48 +01:00
parent d7fb45f6a1
commit 06362a47a9

View File

@@ -1,4 +1,5 @@
import json import json
import time
from threading import Lock from threading import Lock
from websocket import WebSocketApp from websocket import WebSocketApp
from .event import Event from .event import Event
@@ -29,6 +30,11 @@ class Relay:
self.policy = policy self.policy = policy
self.message_pool = message_pool self.message_pool = message_pool
self.subscriptions = subscriptions self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
self.error_counter: int = 0
self.error_threshold: int = 0
self.ssl_options: dict = {}
self.lock = Lock() self.lock = Lock()
self.ws = WebSocketApp( self.ws = WebSocketApp(
url, url,
@@ -38,13 +44,25 @@ class Relay:
on_close=self._on_close, on_close=self._on_close,
) )
def connect(self, ssl_options: dict = None): def connect(self, ssl_options: dict = {}):
self.ws.run_forever(sslopt=ssl_options) self.ssl_options = ssl_options
self.ws.run_forever(sslopt=self.ssl_options)
def close(self): def close(self):
self.ws.close() self.ws.close()
def check_reconnect(self):
try:
self.close()
except:
pass
self.connected = False
if self.reconnect:
time.sleep(1)
self.connect(self.ssl_options)
def publish(self, message: str): def publish(self, message: str):
if self.connected:
self.ws.send(message) self.ws.send(message)
def add_subscription(self, id, filters: Filters): def add_subscription(self, id, filters: Filters):
@@ -71,9 +89,12 @@ class Relay:
} }
def _on_open(self, class_obj): def _on_open(self, class_obj):
self.connected = True
pass pass
def _on_close(self, class_obj, status_code, message): def _on_close(self, class_obj, status_code, message):
self.connected = False
self.check_reconnect()
pass pass
def _on_message(self, class_obj, message: str): def _on_message(self, class_obj, message: str):
@@ -81,7 +102,12 @@ class Relay:
self.message_pool.add_message(message, self.url) self.message_pool.add_message(message, self.url)
def _on_error(self, class_obj, error): def _on_error(self, class_obj, error):
self.connected = False
self.error_counter += 1
if self.error_threshold and self.error_counter > self.error_threshold:
pass pass
else:
self.check_reconnect()
def _is_valid_message(self, message: str) -> bool: def _is_valid_message(self, message: str) -> bool:
message = message.strip("\n") message = message.strip("\n")
@@ -117,7 +143,7 @@ class Relay:
with self.lock: with self.lock:
subscription = self.subscriptions[subscription_id] subscription = self.subscriptions[subscription_id]
if not subscription.filters.match(event): if subscription.filters and not subscription.filters.match(event):
return False return False
return True return True