diff --git a/bfxapi/websocket/_client/bfx_websocket_bucket.py b/bfxapi/websocket/_client/bfx_websocket_bucket.py index 2cc421b..e8eddc7 100644 --- a/bfxapi/websocket/_client/bfx_websocket_bucket.py +++ b/bfxapi/websocket/_client/bfx_websocket_bucket.py @@ -1,106 +1,123 @@ +from typing import \ + TYPE_CHECKING, Optional, Dict, List, Any, cast + import asyncio, json, uuid, websockets -from .._handlers import PublicChannelsHandler +from bfxapi.websocket._connection import Connection +from bfxapi.websocket._handlers import PublicChannelsHandler +from bfxapi.websocket.exceptions import TooManySubscriptions -from ..exceptions import ConnectionNotOpen, TooManySubscriptions +if TYPE_CHECKING: + from bfxapi.websocket.subscriptions import Subscription + from websockets.client import WebSocketClientProtocol + from pyee import EventEmitter -def require_websocket_connection(function): - async def wrapper(self, *args, **kwargs): - if self.websocket is None or not self.websocket.open: - raise ConnectionNotOpen("No open connection with the server.") - - await function(self, *args, **kwargs) - - return wrapper - -class BfxWebSocketBucket: +class BfxWebSocketBucket(Connection): VERSION = 2 MAXIMUM_SUBSCRIPTIONS_AMOUNT = 25 - def __init__(self, host, event_emitter): - self.host, self.websocket, self.event_emitter = \ - host, None, event_emitter + def __init__(self, host: str, event_emitter: "EventEmitter") -> None: + super().__init__(host) - self.condition, self.subscriptions, self.pendings = \ - asyncio.locks.Condition(), {}, [] + self.__event_emitter = event_emitter + self.__pendings: List[Dict[str, Any]] = [ ] + self.__subscriptions: Dict[int, "Subscription"] = { } - self.handler = PublicChannelsHandler(event_emitter=self.event_emitter) + self.__condition = asyncio.locks.Condition() - async def connect(self): - async with websockets.connect(self.host) as websocket: - self.websocket = websocket + self.__handler = PublicChannelsHandler( \ + event_emitter=self.__event_emitter) + + @property + def pendings(self) -> List[Dict[str, Any]]: + return self.__pendings + + @property + def subscriptions(self) -> Dict[int, "Subscription"]: + return self.__subscriptions + + async def connect(self) -> None: + async with websockets.client.connect(self._host) as websocket: + self._websocket = websocket await self.__recover_state() - async with self.condition: - self.condition.notify() + async with self.__condition: + self.__condition.notify(1) - async for message in websocket: + async for message in self._websocket: message = json.loads(message) if isinstance(message, dict): if message["event"] == "subscribed" and (chan_id := message["chanId"]): - self.pendings = [ pending \ - for pending in self.pendings if pending["subId"] != message["subId"] ] + self.__pendings = [ pending \ + for pending in self.__pendings \ + if pending["subId"] != message["subId"] ] - self.subscriptions[chan_id] = message + self.__subscriptions[chan_id] = cast("Subscription", message) - self.event_emitter.emit("subscribed", message) + self.__event_emitter.emit("subscribed", message) elif message["event"] == "unsubscribed" and (chan_id := message["chanId"]): if message["status"] == "OK": - del self.subscriptions[chan_id] + del self.__subscriptions[chan_id] elif message["event"] == "error": - self.event_emitter.emit("wss-error", message["code"], message["msg"]) + self.__event_emitter.emit( \ + "wss-error", message["code"], message["msg"]) if isinstance(message, list): - if (chan_id := message[0]) and message[1] != "hb": - self.handler.handle(self.subscriptions[chan_id], message[1:]) + if (chan_id := message[0]) and message[1] != Connection.HEARTBEAT: + self.__handler.handle(self.__subscriptions[chan_id], message[1:]) - async def __recover_state(self): - for pending in self.pendings: - await self.websocket.send(json.dumps(pending)) + async def __recover_state(self) -> None: + for pending in self.__pendings: + await self._websocket.send( \ + json.dumps(pending)) - for _, subscription in self.subscriptions.items(): - await self.subscribe(sub_id=subscription.pop("subId"), **subscription) + for _, subscription in self.__subscriptions.items(): + _subscription = cast(Dict[str, Any], subscription) - self.subscriptions.clear() + await self.subscribe( \ + sub_id=_subscription.pop("subId"), **_subscription) - @require_websocket_connection - async def subscribe(self, channel, sub_id=None, **kwargs): - if len(self.subscriptions) + len(self.pendings) == BfxWebSocketBucket.MAXIMUM_SUBSCRIPTIONS_AMOUNT: + self.__subscriptions.clear() + + @Connection.require_websocket_connection + async def subscribe(self, + channel: str, + sub_id: Optional[str] = None, + **kwargs: Any) -> None: + if len(self.__subscriptions) + len(self.__pendings) \ + == BfxWebSocketBucket.MAXIMUM_SUBSCRIPTIONS_AMOUNT: raise TooManySubscriptions("The client has reached the maximum number of subscriptions.") - subscription = { - **kwargs, + subscription = \ + { **kwargs, "event": "subscribe", "channel": channel } - "event": "subscribe", - "channel": channel, - "subId": sub_id or str(uuid.uuid4()), - } + subscription["subId"] = sub_id or str(uuid.uuid4()) - self.pendings.append(subscription) + self.__pendings.append(subscription) - await self.websocket.send(json.dumps(subscription)) + await self._websocket.send( \ + json.dumps(subscription)) - @require_websocket_connection - async def unsubscribe(self, chan_id): - await self.websocket.send(json.dumps({ - "event": "unsubscribe", - "chanId": chan_id - })) + @Connection.require_websocket_connection + async def unsubscribe(self, chan_id: int) -> None: + await self._websocket.send(json.dumps( \ + { "event": "unsubscribe", "chanId": chan_id })) - @require_websocket_connection - async def close(self, code=1000, reason=str()): - await self.websocket.close(code=code, reason=reason) + @Connection.require_websocket_connection + async def close(self, code: int = 1000, reason: str = str()) -> None: + await self._websocket.close(code=code, reason=reason) - def get_chan_id(self, sub_id): - for subscription in self.subscriptions.values(): + def get_chan_id(self, sub_id: str) -> Optional[int]: + for subscription in self.__subscriptions.values(): if subscription["subId"] == sub_id: return subscription["chanId"] - async def wait(self): - async with self.condition: - await self.condition.wait_for( - lambda: self.websocket is not None and \ - self.websocket.open) + return None + + async def wait(self) -> None: + async with self.__condition: + await self.__condition.wait_for( + lambda: self.open) diff --git a/bfxapi/websocket/_connection.py b/bfxapi/websocket/_connection.py new file mode 100644 index 0000000..7908ca7 --- /dev/null +++ b/bfxapi/websocket/_connection.py @@ -0,0 +1,39 @@ +from typing import \ + TYPE_CHECKING, Optional, cast + +from bfxapi.websocket.exceptions import \ + ConnectionNotOpen + +if TYPE_CHECKING: + from websockets.client import WebSocketClientProtocol + +class Connection: + HEARTBEAT = "hb" + + def __init__(self, host: str) -> None: + self._host = host + + self.__protocol: Optional["WebSocketClientProtocol"] = None + + @property + def open(self) -> bool: + return self.__protocol is not None and \ + self.__protocol.open + + @property + def _websocket(self) -> "WebSocketClientProtocol": + return cast("WebSocketClientProtocol", self.__protocol) + + @_websocket.setter + def _websocket(self, protocol: "WebSocketClientProtocol") -> None: + self.__protocol = protocol + + @staticmethod + def require_websocket_connection(function): + async def wrapper(self, *args, **kwargs): + if self.open: + return await function(self, *args, **kwargs) + + raise ConnectionNotOpen("No open connection with the server.") + + return wrapper