refactor relay

This commit is contained in:
jeffthibault
2022-08-01 20:10:48 -04:00
parent c22917d076
commit de8d66c82d

View File

@@ -1,9 +1,9 @@
import json
import ssl
from typing import Union
from websocket import WebSocket, WebSocketConnectionClosedException, WebSocketTimeoutException
from websocket import WebSocketApp
from .event import Event
from .filter import Filters
from .message_pool import MessagePool
from .message_type import RelayMessageType
from .subscription import Subscription
@@ -23,23 +23,29 @@ class Relay:
self,
url: str,
policy: RelayPolicy,
ws: WebSocket=WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE}),
message_pool: MessagePool,
subscriptions: dict[str, Subscription]={}) -> None:
self.url = url
self.policy = policy
self.ws = ws
self.message_pool = message_pool
self.subscriptions = subscriptions
self.ws = WebSocketApp(
url,
on_open=self._on_open,
on_message=self._on_message,
on_error=self._on_error,
on_close=self._on_close)
def open_websocket_connection(self, timeout: int=None) -> None:
if timeout != None:
self.ws.connect(self.url, timeout=timeout)
else:
self.ws.connect(self.url)
def connect(self):
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def close_websocket_connection(self) -> None:
def close(self):
self.ws.close()
def add_subscription(self, id: str, filters: Filters) -> None:
def publish(self, message: str):
self.ws.send(message)
def add_subscription(self, id, filters: Filters):
self.subscriptions[id] = Subscription(id, filters)
def close_subscription(self, id: str) -> None:
@@ -49,27 +55,6 @@ class Relay:
subscription = self.subscriptions[id]
subscription.filters = filters
def publish_message(self, message: str) -> None:
self.ws.send(message)
def get_message(self) -> Union[None, str]:
while True:
try:
message = self.ws.recv()
if not self._is_valid_message(message):
continue
return message
except WebSocketConnectionClosedException:
print('received connection closed')
break
except WebSocketTimeoutException:
print('ws connection timed out')
break
return None
def to_json_object(self) -> dict:
return {
"url": self.url,
@@ -77,16 +62,27 @@ class Relay:
"subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()]
}
def _on_open(self, class_obj):
pass
def _on_close(self, class_obj, status_code, message):
pass
def _on_message(self, class_obj, message: str):
if self._is_valid_message(message):
self.message_pool.add_message(message, self.url)
def _on_error(self, class_obj, error):
pass
def _is_valid_message(self, message: str) -> bool:
if not message or message[0] != '[' or message[-1] != ']':
return False
message_json = json.loads(message)
message_type = message_json[0]
if message_type == RelayMessageType.NOTICE:
return True
if message_type == RelayMessageType.END_OF_STORED_EVENTS:
return True
if not RelayMessageType.is_valid(message_type):
return False
if message_type == RelayMessageType.EVENT:
if not len(message_json) == 3:
return False