Initial AUTH support

This commit is contained in:
vic
2023-10-25 20:15:18 -04:00
parent 8086e4d8ba
commit d42448f300
4 changed files with 31 additions and 5 deletions

View File

@@ -24,11 +24,18 @@ class EndOfStoredEventsMessage:
self.url = url self.url = url
class AuthMessage:
def __init__(self, challenge: str, url: str) -> None:
self.challenge = challenge
self.url = url
class MessagePool: class MessagePool:
def __init__(self) -> None: def __init__(self) -> None:
self.events: Queue[EventMessage] = Queue() self.events: Queue[EventMessage] = Queue()
self.notices: Queue[NoticeMessage] = Queue() self.notices: Queue[NoticeMessage] = Queue()
self.eose_notices: Queue[EndOfStoredEventsMessage] = Queue() self.eose_notices: Queue[EndOfStoredEventsMessage] = Queue()
self.auths: Queue[AuthMessage] = Queue()
self._unique_events: set = set() self._unique_events: set = set()
self.lock: Lock = Lock() self.lock: Lock = Lock()
@@ -44,6 +51,9 @@ class MessagePool:
def get_eose_notice(self): def get_eose_notice(self):
return self.eose_notices.get() return self.eose_notices.get()
def get_auth(self):
return self.auth.get()
def has_events(self): def has_events(self):
return self.events.qsize() > 0 return self.events.qsize() > 0
@@ -53,6 +63,9 @@ class MessagePool:
def has_eose_notices(self): def has_eose_notices(self):
return self.eose_notices.qsize() > 0 return self.eose_notices.qsize() > 0
def has_auths(self):
return self.auths.qsize() > 0
def _process_message(self, message: str, url: str): def _process_message(self, message: str, url: str):
message_json = json.loads(message) message_json = json.loads(message)
message_type = message_json[0] message_type = message_json[0]
@@ -75,3 +88,5 @@ class MessagePool:
self.notices.put(NoticeMessage(message_json[1], url)) self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS: elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url)) self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url))
elif message_type == RelayMessageType.AUTH:
self.auths.put(AuthMessage(message_json[1], url))

View File

@@ -2,14 +2,16 @@ class ClientMessageType:
EVENT = "EVENT" EVENT = "EVENT"
REQUEST = "REQ" REQUEST = "REQ"
CLOSE = "CLOSE" CLOSE = "CLOSE"
AUTH = "AUTH"
class RelayMessageType: class RelayMessageType:
EVENT = "EVENT" EVENT = "EVENT"
NOTICE = "NOTICE" NOTICE = "NOTICE"
END_OF_STORED_EVENTS = "EOSE" END_OF_STORED_EVENTS = "EOSE"
AUTH = "AUTH"
@staticmethod @staticmethod
def is_valid(type: str) -> bool: def is_valid(type: str) -> bool:
if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS: if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS:
return True return True
return False return False

View File

@@ -100,7 +100,7 @@ class Relay:
def close_subscription(self, id: str) -> None: def close_subscription(self, id: str) -> None:
with self.lock: with self.lock:
self.subscriptions.pop(id) self.subscriptions.pop(id, None)
def update_subscription(self, id: str, filters: Filters) -> None: def update_subscription(self, id: str, filters: Filters) -> None:
with self.lock: with self.lock:

View File

@@ -52,12 +52,13 @@ class RelayManager:
for relay in self.relays.values(): for relay in self.relays.values():
relay.close() relay.close()
def publish_message(self, message: str): def publish_message(self, message: str, url:str=None):
for relay in self.relays.values(): for relay in self.relays.values():
if relay.policy.should_write: if relay.policy.should_write:
relay.publish(message) if url is None or relay.url == url:
relay.publish(message)
def publish_event(self, event: Event): def verify_event(self, event: Event):
"""Verifies that the Event is publishable before submitting it to relays""" """Verifies that the Event is publishable before submitting it to relays"""
if event.signature is None: if event.signature is None:
raise RelayException(f"Could not publish {event.id}: must be signed") raise RelayException(f"Could not publish {event.id}: must be signed")
@@ -66,4 +67,12 @@ class RelayManager:
raise RelayException( raise RelayException(
f"Could not publish {event.id}: failed to verify signature {event.signature}" f"Could not publish {event.id}: failed to verify signature {event.signature}"
) )
def publish_event(self, event: Event):
self.verify_event(event)
self.publish_message(event.to_message()) self.publish_message(event.to_message())
def publish_auth(self, event: Event, url: str):
self.verify_event(event)
self.publish_message(event.to_message(), url)