diff --git a/nostr/relay.py b/nostr/relay.py index 80fca7e..6571d85 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -1,9 +1,9 @@ import json import ssl -from typing import Union -from websocket import WebSocket, WebSocketConnectionClosedException, WebSocketTimeoutException +from websocket import WebSocketApp from .event import Event from .filter import Filters +from .message_pool import MessagePool from .message_type import RelayMessageType from .subscription import Subscription @@ -22,24 +22,30 @@ class Relay: def __init__( self, url: str, - policy: RelayPolicy, - ws: WebSocket=WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE}), + policy: RelayPolicy, + message_pool: MessagePool, subscriptions: dict[str, Subscription]={}) -> None: self.url = url self.policy = policy - self.ws = ws + self.message_pool = message_pool self.subscriptions = subscriptions + self.ws = WebSocketApp( + url, + on_open=self._on_open, + on_message=self._on_message, + on_error=self._on_error, + on_close=self._on_close) - def open_websocket_connection(self, timeout: int=None) -> None: - if timeout != None: - self.ws.connect(self.url, timeout=timeout) - else: - self.ws.connect(self.url) + def connect(self): + self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) - def close_websocket_connection(self) -> None: + def close(self): self.ws.close() - def add_subscription(self, id: str, filters: Filters) -> None: + def publish(self, message: str): + self.ws.send(message) + + def add_subscription(self, id, filters: Filters): self.subscriptions[id] = Subscription(id, filters) def close_subscription(self, id: str) -> None: @@ -49,27 +55,6 @@ class Relay: subscription = self.subscriptions[id] subscription.filters = filters - def publish_message(self, message: str) -> None: - self.ws.send(message) - - def get_message(self) -> Union[None, str]: - while True: - try: - message = self.ws.recv() - if not self._is_valid_message(message): - continue - - return message - - except WebSocketConnectionClosedException: - print('received connection closed') - break - except WebSocketTimeoutException: - print('ws connection timed out') - break - - return None - def to_json_object(self) -> dict: return { "url": self.url, @@ -77,16 +62,27 @@ class Relay: "subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()] } + def _on_open(self, class_obj): + pass + + def _on_close(self, class_obj, status_code, message): + pass + + def _on_message(self, class_obj, message: str): + if self._is_valid_message(message): + self.message_pool.add_message(message, self.url) + + def _on_error(self, class_obj, error): + pass + def _is_valid_message(self, message: str) -> bool: if not message or message[0] != '[' or message[-1] != ']': return False message_json = json.loads(message) message_type = message_json[0] - if message_type == RelayMessageType.NOTICE: - return True - if message_type == RelayMessageType.END_OF_STORED_EVENTS: - return True + if not RelayMessageType.is_valid(message_type): + return False if message_type == RelayMessageType.EVENT: if not len(message_json) == 3: return False