This commit is contained in:
callebtc
2022-12-26 15:23:14 +01:00
parent b5e99fc708
commit 77af490acf
8 changed files with 587 additions and 31 deletions

View File

@@ -7,24 +7,24 @@ from .message_pool import MessagePool
from .message_type import RelayMessageType
from .subscription import Subscription
class RelayPolicy:
def __init__(self, should_read: bool=True, should_write: bool=True) -> None:
def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
self.should_read = should_read
self.should_write = should_write
def to_json_object(self) -> dict[str, bool]:
return {
"read": self.should_read,
"write": self.should_write
}
return {"read": self.should_read, "write": self.should_write}
class Relay:
def __init__(
self,
url: str,
policy: RelayPolicy,
message_pool: MessagePool,
subscriptions: dict[str, Subscription]={}) -> None:
self,
url: str,
policy: RelayPolicy,
message_pool: MessagePool,
subscriptions: dict[str, Subscription] = {},
) -> None:
self.url = url
self.policy = policy
self.message_pool = message_pool
@@ -35,9 +35,10 @@ class Relay:
on_open=self._on_open,
on_message=self._on_message,
on_error=self._on_error,
on_close=self._on_close)
on_close=self._on_close,
)
def connect(self, ssl_options: dict=None):
def connect(self, ssl_options: dict = None):
self.ws.run_forever(sslopt=ssl_options)
def close(self):
@@ -63,7 +64,10 @@ class Relay:
return {
"url": self.url,
"policy": self.policy.to_json_object(),
"subscriptions": [subscription.to_json_object() for subscription in self.subscriptions.values()]
"subscriptions": [
subscription.to_json_object()
for subscription in self.subscriptions.values()
],
}
def _on_open(self, class_obj):
@@ -75,12 +79,13 @@ class Relay:
def _on_message(self, class_obj, message: str):
if self._is_valid_message(message):
self.message_pool.add_message(message, self.url)
def _on_error(self, class_obj, error):
pass
def _is_valid_message(self, message: str) -> bool:
if not message or message[0] != '[' or message[-1] != ']':
message = message.strip("\n")
if not message or message[0] != "[" or message[-1] != "]":
return False
message_json = json.loads(message)
@@ -90,14 +95,22 @@ class Relay:
if message_type == RelayMessageType.EVENT:
if not len(message_json) == 3:
return False
subscription_id = message_json[1]
with self.lock:
if subscription_id not in self.subscriptions:
return False
e = message_json[2]
event = Event(e['pubkey'], e['content'], e['created_at'], e['kind'], e['tags'], e['id'], e['sig'])
event = Event(
e["pubkey"],
e["content"],
e["created_at"],
e["kind"],
e["tags"],
e["id"],
e["sig"],
)
if not event.verify():
return False