refactor relay

This commit is contained in:
jeffthibault
2022-08-01 20:10:48 -04:00
parent c22917d076
commit de8d66c82d

View File

@@ -1,9 +1,9 @@
import json import json
import ssl import ssl
from typing import Union from websocket import WebSocketApp
from websocket import WebSocket, WebSocketConnectionClosedException, WebSocketTimeoutException
from .event import Event from .event import Event
from .filter import Filters from .filter import Filters
from .message_pool import MessagePool
from .message_type import RelayMessageType from .message_type import RelayMessageType
from .subscription import Subscription from .subscription import Subscription
@@ -22,24 +22,30 @@ class Relay:
def __init__( def __init__(
self, self,
url: str, url: str,
policy: RelayPolicy, policy: RelayPolicy,
ws: WebSocket=WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE}), message_pool: MessagePool,
subscriptions: dict[str, Subscription]={}) -> None: subscriptions: dict[str, Subscription]={}) -> None:
self.url = url self.url = url
self.policy = policy self.policy = policy
self.ws = ws self.message_pool = message_pool
self.subscriptions = subscriptions self.subscriptions = subscriptions
self.ws = WebSocketApp(
url,
on_open=self._on_open,
on_message=self._on_message,
on_error=self._on_error,
on_close=self._on_close)
def open_websocket_connection(self, timeout: int=None) -> None: def connect(self):
if timeout != None: self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
self.ws.connect(self.url, timeout=timeout)
else:
self.ws.connect(self.url)
def close_websocket_connection(self) -> None: def close(self):
self.ws.close() self.ws.close()
def add_subscription(self, id: str, filters: Filters) -> None: def publish(self, message: str):
self.ws.send(message)
def add_subscription(self, id, filters: Filters):
self.subscriptions[id] = Subscription(id, filters) self.subscriptions[id] = Subscription(id, filters)
def close_subscription(self, id: str) -> None: def close_subscription(self, id: str) -> None:
@@ -49,27 +55,6 @@ class Relay:
subscription = self.subscriptions[id] subscription = self.subscriptions[id]
subscription.filters = filters subscription.filters = filters
def publish_message(self, message: str) -> None:
self.ws.send(message)
def get_message(self) -> Union[None, str]:
while True:
try:
message = self.ws.recv()
if not self._is_valid_message(message):
continue
return message
except WebSocketConnectionClosedException:
print('received connection closed')
break
except WebSocketTimeoutException:
print('ws connection timed out')
break
return None
def to_json_object(self) -> dict: def to_json_object(self) -> dict:
return { return {
"url": self.url, "url": self.url,
@@ -77,16 +62,27 @@ class Relay:
"subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()] "subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()]
} }
def _on_open(self, class_obj):
pass
def _on_close(self, class_obj, status_code, message):
pass
def _on_message(self, class_obj, message: str):
if self._is_valid_message(message):
self.message_pool.add_message(message, self.url)
def _on_error(self, class_obj, error):
pass
def _is_valid_message(self, message: str) -> bool: def _is_valid_message(self, message: str) -> bool:
if not message or message[0] != '[' or message[-1] != ']': if not message or message[0] != '[' or message[-1] != ']':
return False return False
message_json = json.loads(message) message_json = json.loads(message)
message_type = message_json[0] message_type = message_json[0]
if message_type == RelayMessageType.NOTICE: if not RelayMessageType.is_valid(message_type):
return True return False
if message_type == RelayMessageType.END_OF_STORED_EVENTS:
return True
if message_type == RelayMessageType.EVENT: if message_type == RelayMessageType.EVENT:
if not len(message_json) == 3: if not len(message_json) == 3:
return False return False