From 06362a47a99b79dcb024e01303e773837f1cfd73 Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Wed, 25 Jan 2023 00:35:48 +0100 Subject: [PATCH] error checking and reconnect --- nostr/relay.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/nostr/relay.py b/nostr/relay.py index ad01ff6..44c05c6 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -1,4 +1,5 @@ import json +import time from threading import Lock from websocket import WebSocketApp from .event import Event @@ -29,6 +30,11 @@ class Relay: self.policy = policy self.message_pool = message_pool self.subscriptions = subscriptions + self.connected: bool = False + self.reconnect: bool = True + self.error_counter: int = 0 + self.error_threshold: int = 0 + self.ssl_options: dict = {} self.lock = Lock() self.ws = WebSocketApp( url, @@ -38,14 +44,26 @@ class Relay: on_close=self._on_close, ) - def connect(self, ssl_options: dict = None): - self.ws.run_forever(sslopt=ssl_options) + def connect(self, ssl_options: dict = {}): + self.ssl_options = ssl_options + self.ws.run_forever(sslopt=self.ssl_options) def close(self): self.ws.close() + def check_reconnect(self): + try: + self.close() + except: + pass + self.connected = False + if self.reconnect: + time.sleep(1) + self.connect(self.ssl_options) + def publish(self, message: str): - self.ws.send(message) + if self.connected: + self.ws.send(message) def add_subscription(self, id, filters: Filters): with self.lock: @@ -71,9 +89,12 @@ class Relay: } def _on_open(self, class_obj): + self.connected = True pass def _on_close(self, class_obj, status_code, message): + self.connected = False + self.check_reconnect() pass def _on_message(self, class_obj, message: str): @@ -81,7 +102,12 @@ class Relay: self.message_pool.add_message(message, self.url) def _on_error(self, class_obj, error): - pass + self.connected = False + self.error_counter += 1 + if self.error_threshold and self.error_counter > self.error_threshold: + pass + else: + self.check_reconnect() def _is_valid_message(self, message: str) -> bool: message = message.strip("\n") @@ -117,7 +143,7 @@ class Relay: with self.lock: subscription = self.subscriptions[subscription_id] - if not subscription.filters.match(event): + if subscription.filters and not subscription.filters.match(event): return False return True