From f1e678e0438d25137a941323f684a7ee5ac9bf67 Mon Sep 17 00:00:00 2001 From: Davide Casale Date: Mon, 19 Jun 2023 04:57:33 +0200 Subject: [PATCH] Add type hinting support to bfxapi.websocket.client.bfx_websocket_client. --- bfxapi/client.py | 25 +- bfxapi/rest/endpoints/bfx_rest_interface.py | 4 +- .../websocket/_client/bfx_websocket_bucket.py | 24 +- .../websocket/_client/bfx_websocket_client.py | 441 ++++++++++-------- bfxapi/websocket/_connection.py | 99 ++-- 5 files changed, 341 insertions(+), 252 deletions(-) diff --git a/bfxapi/client.py b/bfxapi/client.py index dac3649..5e5292d 100644 --- a/bfxapi/client.py +++ b/bfxapi/client.py @@ -9,29 +9,16 @@ class Client: self, api_key: Optional[str] = None, api_secret: Optional[str] = None, - filters: Optional[List[str]] = None, *, rest_host: str = REST_HOST, wss_host: str = WSS_HOST, + filters: Optional[List[str]] = None, wss_timeout: Optional[float] = 60 * 15, log_filename: Optional[str] = None, log_level: Literal["ERROR", "WARNING", "INFO", "DEBUG"] = "INFO" - ): - credentials = None + ) -> None: + self.rest = BfxRestInterface(rest_host, api_key, api_secret) - if api_key and api_secret: - credentials = { "api_key": api_key, "api_secret": api_secret, "filters": filters } - - self.rest = BfxRestInterface( - host=rest_host, - credentials=credentials - ) - - self.wss = BfxWebSocketClient( - host=wss_host, - credentials=credentials, - wss_timeout=wss_timeout, - log_filename=log_filename, - log_level=log_level - ) - \ No newline at end of file + self.wss = BfxWebSocketClient(wss_host, api_key, api_secret, + filters=filters, wss_timeout=wss_timeout, log_filename=log_filename, + log_level=log_level) diff --git a/bfxapi/rest/endpoints/bfx_rest_interface.py b/bfxapi/rest/endpoints/bfx_rest_interface.py index 73ec603..1d17d26 100644 --- a/bfxapi/rest/endpoints/bfx_rest_interface.py +++ b/bfxapi/rest/endpoints/bfx_rest_interface.py @@ -5,9 +5,7 @@ from .rest_merchant_endpoints import RestMerchantEndpoints class BfxRestInterface: VERSION = 2 - def __init__(self, host, credentials = None): - api_key, api_secret = (credentials['api_key'], credentials['api_secret']) if credentials else (None, None) - + def __init__(self, host, api_key = None, api_secret = None): self.public = RestPublicEndpoints(host=host) self.auth = RestAuthEndpoints(host=host, api_key=api_key, api_secret=api_secret) self.merchant = RestMerchantEndpoints(host=host, api_key=api_key, api_secret=api_secret) diff --git a/bfxapi/websocket/_client/bfx_websocket_bucket.py b/bfxapi/websocket/_client/bfx_websocket_bucket.py index e8eddc7..c8dcf97 100644 --- a/bfxapi/websocket/_client/bfx_websocket_bucket.py +++ b/bfxapi/websocket/_client/bfx_websocket_bucket.py @@ -1,7 +1,9 @@ from typing import \ TYPE_CHECKING, Optional, Dict, List, Any, cast -import asyncio, json, uuid, websockets +import asyncio, json, uuid + +from websockets.legacy.client import connect as _websockets__connect from bfxapi.websocket._connection import Connection from bfxapi.websocket._handlers import PublicChannelsHandler @@ -38,7 +40,7 @@ class BfxWebSocketBucket(Connection): return self.__subscriptions async def connect(self) -> None: - async with websockets.client.connect(self._host) as websocket: + async with _websockets__connect(self._host) as websocket: self._websocket = websocket await self.__recover_state() @@ -102,20 +104,26 @@ class BfxWebSocketBucket(Connection): json.dumps(subscription)) @Connection.require_websocket_connection - async def unsubscribe(self, chan_id: int) -> None: - await self._websocket.send(json.dumps( \ - { "event": "unsubscribe", "chanId": chan_id })) + async def unsubscribe(self, sub_id: str) -> None: + for subscription in self.__subscriptions.values(): + if subscription["subId"] == sub_id: + data = { "event": "unsubscribe", \ + "chanId": subscription["subId"] } + + message = json.dumps(data) + + await self._websocket.send(message) @Connection.require_websocket_connection async def close(self, code: int = 1000, reason: str = str()) -> None: await self._websocket.close(code=code, reason=reason) - def get_chan_id(self, sub_id: str) -> Optional[int]: + def has(self, sub_id: str) -> bool: for subscription in self.__subscriptions.values(): if subscription["subId"] == sub_id: - return subscription["chanId"] + return True - return None + return False async def wait(self) -> None: async with self.__condition: diff --git a/bfxapi/websocket/_client/bfx_websocket_client.py b/bfxapi/websocket/_client/bfx_websocket_client.py index 28fc9ea..38a2c37 100644 --- a/bfxapi/websocket/_client/bfx_websocket_client.py +++ b/bfxapi/websocket/_client/bfx_websocket_client.py @@ -1,55 +1,54 @@ -from collections import namedtuple +from typing import \ + TYPE_CHECKING, TypeVar, TypedDict,\ + Callable, Optional, Literal,\ + Tuple, List, Dict, \ + Any from datetime import datetime -import traceback, json, asyncio, hmac, hashlib, time, socket, random, websockets +from socket import gaierror -from .bfx_websocket_bucket import require_websocket_connection, BfxWebSocketBucket +import \ + traceback, json, asyncio, \ + hmac, hashlib, random, \ + websockets + +from websockets.legacy.client import connect as _websockets__connect + +from bfxapi.utils.json_encoder import JSONEncoder + +from bfxapi.utils.logger import \ + ColorLogger, FileLogger + +from bfxapi.websocket._handlers import \ + PublicChannelsHandler, AuthEventsHandler + +from bfxapi.websocket._connection import Connection + +from bfxapi.websocket._event_emitter import BfxEventEmitter + +from bfxapi.websocket.exceptions import \ + InvalidAuthenticationCredentials, EventNotSupported, ZeroConnectionsError, \ + ReconnectionTimeoutError, OutdatedClientVersion + +from .bfx_websocket_bucket import BfxWebSocketBucket from .bfx_websocket_inputs import BfxWebSocketInputs -from .._handlers import PublicChannelsHandler, AuthEventsHandler -from ..exceptions import ActionRequiresAuthentication, InvalidAuthenticationCredentials, EventNotSupported, \ - ZeroConnectionsError, ReconnectionTimeoutError, OutdatedClientVersion -from .._event_emitter import BfxEventEmitter +if TYPE_CHECKING: + from logging import Logger -from ...utils.json_encoder import JSONEncoder + from asyncio import Task -from ...utils.logger import ColorLogger, FileLogger + _T = TypeVar("_T", bound=Callable[..., None]) -def require_websocket_authentication(function): - async def wrapper(self, *args, **kwargs): - if hasattr(self, "authentication") and not self.authentication: - raise ActionRequiresAuthentication("To perform this action you need to " \ - "authenticate using your API_KEY and API_SECRET.") + _Credentials = TypedDict("_Credentials", \ + { "api_key": str, "api_secret": str, "filters": Optional[List[str]] }) - await require_websocket_connection(function) \ - (self, *args, **kwargs) + _Reconnection = TypedDict("_Reconnection", + { "attempts": int, "reason": str, "timestamp": datetime }) - return wrapper - -class _Delay: - BACKOFF_MIN, BACKOFF_MAX = 1.92, 60.0 - - BACKOFF_INITIAL = 5.0 - - def __init__(self, backoff_factor): - self.__backoff_factor = backoff_factor - self.__backoff_delay = _Delay.BACKOFF_MIN - self.__initial_delay = random.random() * _Delay.BACKOFF_INITIAL - - def next(self): - backoff_delay = self.peek() - __backoff_delay = self.__backoff_delay * self.__backoff_factor - self.__backoff_delay = min(__backoff_delay, _Delay.BACKOFF_MAX) - - return backoff_delay - - def peek(self): - return (self.__backoff_delay == _Delay.BACKOFF_MIN) \ - and self.__initial_delay or self.__backoff_delay - -class BfxWebSocketClient: +class BfxWebSocketClient(Connection, Connection.Authenticable): VERSION = BfxWebSocketBucket.VERSION MAXIMUM_CONNECTIONS_AMOUNT = 20 @@ -66,223 +65,299 @@ class BfxWebSocketClient: *AuthEventsHandler.ON_EVENTS ] - def __init__(self, host, credentials, *, wss_timeout = 60 * 15, log_filename = None, log_level = "INFO"): - self.websocket, self.authentication, self.buckets = None, False, [] + def __init__(self, + host: str, + api_key: Optional[str] = None, + api_secret: Optional[str] = None, + *, + filters: Optional[List[str]] = None, + wss_timeout: Optional[float] = 60 * 15, + log_filename: Optional[str] = None, + log_level: Literal["ERROR", "WARNING", "INFO", "DEBUG"] = "INFO") -> None: + super().__init__(host) - self.host, self.credentials, self.wss_timeout = host, credentials, wss_timeout + self.__credentials: Optional["_Credentials"] = None - self.event_emitter = BfxEventEmitter(targets= \ + if api_key and api_secret: + self.__credentials = \ + { "api_key": api_key, "api_secret": api_secret, "filters": filters } + + self.__wss_timeout = wss_timeout + + self.__event_emitter = BfxEventEmitter(targets = \ PublicChannelsHandler.ONCE_PER_SUBSCRIPTION + \ ["subscribed"]) - self.handler = AuthEventsHandler(event_emitter=self.event_emitter) + self.__handler = AuthEventsHandler(\ + event_emitter=self.__event_emitter) - self.inputs = BfxWebSocketInputs(handle_websocket_input=self.__handle_websocket_input) + self.__inputs = BfxWebSocketInputs(\ + handle_websocket_input=self.__handle_websocket_input) + + self.__buckets: Dict[BfxWebSocketBucket, Optional["Task"]] = { } + + self.__reconnection: Optional[_Reconnection] = None + + self.__logger: "Logger" if log_filename is None: - self.logger = ColorLogger("BfxWebSocketClient", level=log_level) - else: self.logger = FileLogger("BfxWebSocketClient", level=log_level, filename=log_filename) + self.__logger = ColorLogger("BfxWebSocketClient", level=log_level) + else: self.__logger = FileLogger("BfxWebSocketClient", level=log_level, filename=log_filename) - self.event_emitter.add_listener("error", - lambda exception: self.logger.error(f"{type(exception).__name__}: {str(exception)}" + "\n" + + self.__event_emitter.add_listener("error", + lambda exception: self.__logger.error(f"{type(exception).__name__}: {str(exception)}" + "\n" + str().join(traceback.format_exception(type(exception), exception, exception.__traceback__))[:-1]) ) - def run(self, connections = 5): + @property + def inputs(self) -> BfxWebSocketInputs: + return self.__inputs + + def run(self, connections: int = 5) -> None: return asyncio.run(self.start(connections)) - async def start(self, connections = 5): + async def start(self, connections: int = 5) -> None: if connections == 0: - self.logger.info("With connections set to 0 it will not be possible to subscribe to any public channel. " \ - "Attempting a subscription will cause a ZeroConnectionsError to be thrown.") + self.__logger.info("With connections set to 0 it will not be possible to subscribe to any " \ + "public channel. Attempting a subscription will cause a ZeroConnectionsError to be thrown.") if connections > BfxWebSocketClient.MAXIMUM_CONNECTIONS_AMOUNT: - self.logger.warning(f"It is not safe to use more than {BfxWebSocketClient.MAXIMUM_CONNECTIONS_AMOUNT} " \ - f"buckets from the same connection ({connections} in use), the server could momentarily " \ - "block the client with <429 Too Many Requests>.") + self.__logger.warning(f"It is not safe to use more than {BfxWebSocketClient.MAXIMUM_CONNECTIONS_AMOUNT} " \ + f"buckets from the same connection ({connections} in use), the server could momentarily " \ + "block the client with <429 Too Many Requests>.") for _ in range(connections): - self.buckets += [BfxWebSocketBucket(self.host, self.event_emitter)] + _bucket = BfxWebSocketBucket( \ + self._host, self.__event_emitter) + + self.__buckets.update( { _bucket: None }) await self.__connect() - #pylint: disable-next=too-many-statements,too-many-branches - async def __connect(self): - Reconnection = namedtuple("Reconnection", ["status", "attempts", "timestamp"]) - reconnection = Reconnection(status=False, attempts=0, timestamp=None) - timer, tasks, on_timeout_event = None, [], asyncio.locks.Event() + #pylint: disable-next=too-many-branches + async def __connect(self) -> None: + class _Delay: + BACKOFF_MIN, BACKOFF_MAX = 1.92, 60.0 - delay = None + BACKOFF_INITIAL = 5.0 - def _on_wss_timeout(): - on_timeout_event.set() + def __init__(self, backoff_factor: float) -> None: + self.__backoff_factor = backoff_factor + self.__backoff_delay = _Delay.BACKOFF_MIN + self.__initial_delay = random.random() * _Delay.BACKOFF_INITIAL - #pylint: disable-next=too-many-branches - async def _connection(): - nonlocal reconnection, timer, tasks + def next(self) -> float: + _backoff_delay = self.peek() + __backoff_delay = self.__backoff_delay * self.__backoff_factor + self.__backoff_delay = min(__backoff_delay, _Delay.BACKOFF_MAX) - async with websockets.connect(self.host, ping_interval=None) as websocket: - if reconnection.status: - self.logger.info(f"Reconnection attempt successful (no.{reconnection.attempts}): The " \ - f"client has been offline for a total of {datetime.now() - reconnection.timestamp} " \ - f"(connection lost on: {reconnection.timestamp:%d-%m-%Y at %H:%M:%S}).") + return _backoff_delay - reconnection = Reconnection(status=False, attempts=0, timestamp=None) + def peek(self) -> float: + return (self.__backoff_delay == _Delay.BACKOFF_MIN) \ + and self.__initial_delay or self.__backoff_delay - if isinstance(timer, asyncio.events.TimerHandle): - timer.cancel() + def reset(self) -> None: + self.__backoff_delay = _Delay.BACKOFF_MIN - self.websocket = websocket + _delay = _Delay(backoff_factor=1.618) - coroutines = [ BfxWebSocketBucket.connect(bucket) for bucket in self.buckets ] + _on_wss_timeout = asyncio.locks.Event() - tasks = [ asyncio.create_task(coroutine) for coroutine in coroutines ] - - if len(self.buckets) == 0 or \ - (await asyncio.gather(*[bucket.wait() for bucket in self.buckets])): - self.event_emitter.emit("open") - - if self.credentials: - await self.__authenticate(**self.credentials) - - async for message in websocket: - message = json.loads(message) - - if isinstance(message, dict): - if message["event"] == "info" and "version" in message: - if BfxWebSocketClient.VERSION != message["version"]: - raise OutdatedClientVersion("Mismatch between the client version and the server " \ - "version. Update the library to the latest version to continue (client version: " \ - f"{BfxWebSocketClient.VERSION}, server version: {message['version']}).") - elif message["event"] == "info" and message["code"] == 20051: - rcvd = websockets.frames.Close(code=1012, - reason="Stop/Restart WebSocket Server (please reconnect).") - - raise websockets.exceptions.ConnectionClosedError(rcvd=rcvd, sent=None) - elif message["event"] == "auth": - if message["status"] != "OK": - raise InvalidAuthenticationCredentials( - "Cannot authenticate with given API-KEY and API-SECRET.") - - self.event_emitter.emit("authenticated", message) - - self.authentication = True - elif message["event"] == "error": - self.event_emitter.emit("wss-error", message["code"], message["msg"]) - - if isinstance(message, list): - if message[0] == 0 and message[1] != "hb": - self.handler.handle(message[1], message[2]) + def on_wss_timeout(): + if not self.open: + _on_wss_timeout.set() while True: - if reconnection.status: - await asyncio.sleep(delay.next()) + if self.__reconnection: + await asyncio.sleep(_delay.next()) - if on_timeout_event.is_set(): + if _on_wss_timeout.is_set(): raise ReconnectionTimeoutError("Connection has been offline for too long " \ - f"without being able to reconnect (wss_timeout: {self.wss_timeout}s).") + f"without being able to reconnect (wss_timeout: {self.__wss_timeout}s).") try: - await _connection() - except (websockets.exceptions.ConnectionClosedError, socket.gaierror) as error: - for task in tasks: - task.cancel() + await self.__connection() + except (websockets.exceptions.ConnectionClosedError, gaierror) as error: + for bucket in self.__buckets: + if (_task := self.__buckets[bucket]): + _task.cancel() if isinstance(error, websockets.exceptions.ConnectionClosedError) and error.code in (1006, 1012): if error.code == 1006: - self.logger.error("Connection lost: no close frame received " \ - "or sent (1006). Trying to reconnect...") + self.__logger.error("Connection lost: no close frame " \ + "received or sent (1006). Trying to reconnect...") if error.code == 1012: - self.logger.info("WSS server is about to restart, clients need " \ + self.__logger.info("WSS server is about to restart, clients need " \ "to reconnect (server sent 20051). Reconnection attempt in progress...") - reconnection = Reconnection(status=True, attempts=1, timestamp=datetime.now()) + if self.__wss_timeout is not None: + asyncio.get_event_loop().call_later( + self.__wss_timeout, on_wss_timeout) - if self.wss_timeout is not None: - timer = asyncio.get_event_loop().call_later(self.wss_timeout, _on_wss_timeout) + self.__reconnection = \ + { "attempts": 1, "reason": error.reason, "timestamp": datetime.now() } - delay = _Delay(backoff_factor=1.618) + self._authentication = False - self.authentication = False - elif isinstance(error, socket.gaierror) and reconnection.status: - self.logger.warning(f"Reconnection attempt was unsuccessful (no.{reconnection.attempts}). " \ - f"Next reconnection attempt in {delay.peek():.2f} seconds. (at the moment " \ - f"the client has been offline for {datetime.now() - reconnection.timestamp})") + _delay.reset() + elif isinstance(error, gaierror) and self.__reconnection: + self.__logger.warning( + f"_Reconnection attempt was unsuccessful (no.{self.__reconnection['attempts']}). " \ + f"Next reconnection attempt in {_delay.peek():.2f} seconds. (at the moment " \ + f"the client has been offline for {datetime.now() - self.__reconnection['timestamp']})") - reconnection = reconnection._replace(attempts=reconnection.attempts + 1) + self.__reconnection["attempts"] += 1 else: raise error - if not reconnection.status: - self.event_emitter.emit("disconnection", - self.websocket.close_code, self.websocket.close_reason) + if not self.__reconnection: + self.__event_emitter.emit("disconnection", + self._websocket.close_code, self._websocket.close_reason) break - async def __authenticate(self, api_key, api_secret, filters=None): - data = { "event": "auth", "filter": filters, "apiKey": api_key } + async def __connection(self) -> None: + async with _websockets__connect(self._host) as websocket: + if self.__reconnection: + self.__logger.info(f"_Reconnection attempt successful (no.{self.__reconnection['attempts']}): The " \ + f"client has been offline for a total of {datetime.now() - self.__reconnection['timestamp']} " \ + f"(connection lost on: {self.__reconnection['timestamp']:%d-%m-%Y at %H:%M:%S}).") - data["authNonce"] = str(round(time.time() * 1_000_000)) + self.__reconnection = None - data["authPayload"] = "AUTH" + data["authNonce"] + self._websocket = websocket - data["authSig"] = hmac.new( - api_secret.encode("utf8"), - data["authPayload"].encode("utf8"), - hashlib.sha384 - ).hexdigest() + self.__buckets = { + bucket: asyncio.create_task(_c) + for bucket in self.__buckets + if (_c := bucket.connect()) + } - await self.websocket.send(json.dumps(data)) + if len(self.__buckets) == 0 or \ + (await asyncio.gather(*[bucket.wait() for bucket in self.__buckets])): + self.__event_emitter.emit("open") - async def subscribe(self, channel, **kwargs): - if len(self.buckets) == 0: - raise ZeroConnectionsError("Unable to subscribe: the number of connections must be greater than 0.") + if self.__credentials: + authentication = BfxWebSocketClient. \ + __build_authentication_message(**self.__credentials) - counters = [ len(bucket.pendings) + len(bucket.subscriptions) for bucket in self.buckets ] + await self._websocket.send(authentication) + + async for message in self._websocket: + message = json.loads(message) + + if isinstance(message, dict): + if message["event"] == "info" and "version" in message: + if BfxWebSocketClient.VERSION != message["version"]: + raise OutdatedClientVersion("Mismatch between the client version and the server " \ + "version. Update the library to the latest version to continue (client version: " \ + f"{BfxWebSocketClient.VERSION}, server version: {message['version']}).") + elif message["event"] == "info" and message["code"] == 20051: + rcvd = websockets.frames.Close(code=1012, + reason="Stop/Restart WebSocket Server (please reconnect).") + + raise websockets.exceptions.ConnectionClosedError(rcvd=rcvd, sent=None) + elif message["event"] == "auth": + if message["status"] != "OK": + raise InvalidAuthenticationCredentials( + "Cannot authenticate with given API-KEY and API-SECRET.") + + self.__event_emitter.emit("authenticated", message) + + self._authentication = True + elif message["event"] == "error": + self.__event_emitter.emit("wss-error", message["code"], message["msg"]) + + if isinstance(message, list) and \ + message[0] == 0 and message[1] != Connection.HEARTBEAT: + self.__handler.handle(message[1], message[2]) + + @Connection.require_websocket_connection + async def subscribe(self, + channel: str, + sub_id: Optional[str] = None, + **kwargs: Any) -> None: + if len(self.__buckets) == 0: + raise ZeroConnectionsError("Unable to subscribe: " \ + "the number of connections must be greater than 0.") + + _buckets = list(self.__buckets.keys()) + + counters = [ len(bucket.pendings) + len(bucket.subscriptions) + for bucket in _buckets ] index = counters.index(min(counters)) - await self.buckets[index].subscribe(channel, **kwargs) + await _buckets[index] \ + .subscribe(channel, sub_id, **kwargs) - async def unsubscribe(self, sub_id): - for bucket in self.buckets: - if (chan_id := bucket.get_chan_id(sub_id)): - await bucket.unsubscribe(chan_id=chan_id) + @Connection.require_websocket_connection + async def unsubscribe(self, sub_id: str) -> None: + for bucket in self.__buckets: + if bucket.has(sub_id=sub_id): + await bucket.unsubscribe(sub_id=sub_id) - async def close(self, code=1000, reason=str()): - for bucket in self.buckets: + @Connection.require_websocket_connection + async def close(self, code: int = 1000, reason: str = str()) -> None: + for bucket in self.__buckets: await bucket.close(code=code, reason=reason) - if self.websocket is not None and self.websocket.open: - await self.websocket.close(code=code, reason=reason) + if self._websocket.open: + await self._websocket.close( \ + code=code, reason=reason) - @require_websocket_authentication - async def notify(self, info, message_id=None, **kwargs): - await self.websocket.send(json.dumps([ 0, "n", message_id, { "type": "ucm-test", "info": info, **kwargs } ])) + @Connection.Authenticable.require_websocket_authentication + async def notify(self, + info: Any, + message_id: Optional[int] = None, + **kwargs: Any) -> None: + await self._websocket.send( + json.dumps([ 0, "n", message_id, + { "type": "ucm-test", "info": info, **kwargs } ])) - @require_websocket_authentication - async def __handle_websocket_input(self, event, data): - await self.websocket.send(json.dumps([ 0, event, None, data], cls=JSONEncoder)) + @Connection.Authenticable.require_websocket_authentication + async def __handle_websocket_input(self, event: str, data: Any) -> None: + await self._websocket.send(json.dumps(\ + [ 0, event, None, data], cls=JSONEncoder)) - def on(self, *events, callback = None): + def on(self, *events: str, callback: Optional["_T"] = None) -> Callable[["_T"], None]: for event in events: if event not in BfxWebSocketClient.EVENTS: raise EventNotSupported(f"Event <{event}> is not supported. To get a list " \ - "of available events see BfxWebSocketClient.EVENTS.") + "of available events see BfxWebSocketClient.EVENTS.") - def _register_event(event, function): - if event in BfxWebSocketClient.__ONCE_EVENTS: - self.event_emitter.once(event, function) - else: self.event_emitter.on(event, function) - - if callback is not None: + def _register_events(function: "_T", events: Tuple[str, ...]) -> None: for event in events: - _register_event(event, callback) + if event in BfxWebSocketClient.__ONCE_EVENTS: + self.__event_emitter.once(event, function) + else: + self.__event_emitter.on(event, function) - if callback is None: - def handler(function): - for event in events: - _register_event(event, function) + if callback: + _register_events(callback, events) - return handler + def _handler(function: "_T") -> None: + _register_events(function, events) + + return _handler + + @staticmethod + def __build_authentication_message(api_key: str, + api_secret: str, + filters: Optional[List[str]] = None) -> str: + message: Dict[str, Any] = \ + { "event": "auth", "filter": filters, "apiKey": api_key } + + message["authNonce"] = round(datetime.now().timestamp() * 1_000_000) + + message["authPayload"] = f"AUTH{message['authNonce']}" + + message["authSig"] = hmac.new( + key=api_secret.encode("utf8"), + msg=message["authPayload"].encode("utf8"), + digestmod=hashlib.sha384 + ).hexdigest() + + return json.dumps(message) diff --git a/bfxapi/websocket/_connection.py b/bfxapi/websocket/_connection.py index 7908ca7..1562b43 100644 --- a/bfxapi/websocket/_connection.py +++ b/bfxapi/websocket/_connection.py @@ -1,39 +1,60 @@ -from typing import \ - TYPE_CHECKING, Optional, cast - -from bfxapi.websocket.exceptions import \ - ConnectionNotOpen - -if TYPE_CHECKING: - from websockets.client import WebSocketClientProtocol - -class Connection: - HEARTBEAT = "hb" - - def __init__(self, host: str) -> None: - self._host = host - - self.__protocol: Optional["WebSocketClientProtocol"] = None - - @property - def open(self) -> bool: - return self.__protocol is not None and \ - self.__protocol.open - - @property - def _websocket(self) -> "WebSocketClientProtocol": - return cast("WebSocketClientProtocol", self.__protocol) - - @_websocket.setter - def _websocket(self, protocol: "WebSocketClientProtocol") -> None: - self.__protocol = protocol - - @staticmethod - def require_websocket_connection(function): - async def wrapper(self, *args, **kwargs): - if self.open: - return await function(self, *args, **kwargs) - - raise ConnectionNotOpen("No open connection with the server.") - - return wrapper +from typing import \ + TYPE_CHECKING, Optional, cast + +from bfxapi.websocket.exceptions import \ + ConnectionNotOpen, ActionRequiresAuthentication + +if TYPE_CHECKING: + from websockets.client import WebSocketClientProtocol + +class Connection: + HEARTBEAT = "hb" + + class Authenticable: + def __init__(self) -> None: + self._authentication: bool = False + + @property + def authentication(self) -> bool: + return self._authentication + + @staticmethod + def require_websocket_authentication(function): + async def wrapper(self, *args, **kwargs): + if not self.authentication: + raise ActionRequiresAuthentication("To perform this action you need to " \ + "authenticate using your API_KEY and API_SECRET.") + + internal = Connection.require_websocket_connection(function) + + return await internal(self, *args, **kwargs) + + return wrapper + + def __init__(self, host: str) -> None: + self._host = host + + self.__protocol: Optional["WebSocketClientProtocol"] = None + + @property + def open(self) -> bool: + return self.__protocol is not None and \ + self.__protocol.open + + @property + def _websocket(self) -> "WebSocketClientProtocol": + return cast("WebSocketClientProtocol", self.__protocol) + + @_websocket.setter + def _websocket(self, protocol: "WebSocketClientProtocol") -> None: + self.__protocol = protocol + + @staticmethod + def require_websocket_connection(function): + async def wrapper(self, *args, **kwargs): + if self.open: + return await function(self, *args, **kwargs) + + raise ConnectionNotOpen("No open connection with the server.") + + return wrapper