Apply some refactoring to sub-package bfxapi.websocket.

This commit is contained in:
Davide Casale
2023-10-16 05:59:24 +02:00
parent 133db74a72
commit ddce83be0c
5 changed files with 102 additions and 96 deletions

View File

@@ -100,7 +100,7 @@ class BfxWebSocketBucket(Connection):
await self._websocket.send(json.dumps( \ await self._websocket.send(json.dumps( \
{ "event": "conf", "flags": sum(flags) })) { "event": "conf", "flags": sum(flags) }))
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def subscribe(self, async def subscribe(self,
channel: str, channel: str,
sub_id: Optional[str] = None, sub_id: Optional[str] = None,
@@ -115,7 +115,7 @@ class BfxWebSocketBucket(Connection):
await self._websocket.send(message = \ await self._websocket.send(message = \
json.dumps(subscription)) json.dumps(subscription))
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def unsubscribe(self, sub_id: str) -> None: async def unsubscribe(self, sub_id: str) -> None:
for chan_id, subscription in list(self.__subscriptions.items()): for chan_id, subscription in list(self.__subscriptions.items()):
if subscription["sub_id"] == sub_id: if subscription["sub_id"] == sub_id:
@@ -128,7 +128,7 @@ class BfxWebSocketBucket(Connection):
await self._websocket.send(message = \ await self._websocket.send(message = \
json.dumps(unsubscription)) json.dumps(unsubscription))
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def resubscribe(self, sub_id: str) -> None: async def resubscribe(self, sub_id: str) -> None:
for subscription in self.__subscriptions.values(): for subscription in self.__subscriptions.values():
if subscription["sub_id"] == sub_id: if subscription["sub_id"] == sub_id:
@@ -136,7 +136,7 @@ class BfxWebSocketBucket(Connection):
await self.subscribe(**subscription) await self.subscribe(**subscription)
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def close(self, code: int = 1000, reason: str = str()) -> None: async def close(self, code: int = 1000, reason: str = str()) -> None:
await self._websocket.close(code, reason) await self._websocket.close(code, reason)

View File

@@ -1,17 +1,16 @@
from typing import \ from typing import \
TypedDict, List, Dict, \ TypedDict, List, Dict, \
Optional, Any, no_type_check Optional, Any
from logging import Logger from logging import Logger
from datetime import datetime from datetime import datetime
from socket import gaierror from socket import gaierror
from asyncio import Task from asyncio import Task
import \ import \
traceback, json, asyncio, \ traceback, json, asyncio, \
hmac, hashlib, random, \ random, websockets
websockets
import websockets.client import websockets.client
@@ -214,8 +213,8 @@ class BfxWebSocketClient(Connection):
self.__event_emitter.emit("open") self.__event_emitter.emit("open")
if self.__credentials: if self.__credentials:
authentication = BfxWebSocketClient. \ authentication = Connection. \
__build_authentication_message(**self.__credentials) _get_authentication_message(**self.__credentials)
await self._websocket.send(authentication) await self._websocket.send(authentication)
@@ -235,7 +234,7 @@ class BfxWebSocketClient(Connection):
raise ConnectionClosedError(rcvd=rcvd, sent=None) raise ConnectionClosedError(rcvd=rcvd, sent=None)
elif message["event"] == "auth": elif message["event"] == "auth":
if message["status"] != "OK": if message["status"] != "OK":
raise InvalidCredentialError("Cannot authenticate " + \ raise InvalidCredentialError("Can't authenticate " + \
"with given API-KEY and API-SECRET.") "with given API-KEY and API-SECRET.")
self.__event_emitter.emit("authenticated", message) self.__event_emitter.emit("authenticated", message)
@@ -257,7 +256,7 @@ class BfxWebSocketClient(Connection):
return bucket return bucket
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def subscribe(self, async def subscribe(self,
channel: str, channel: str,
sub_id: Optional[str] = None, sub_id: Optional[str] = None,
@@ -281,7 +280,7 @@ class BfxWebSocketClient(Connection):
return await bucket.subscribe( \ return await bucket.subscribe( \
channel, sub_id, **kwargs) channel, sub_id, **kwargs)
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def unsubscribe(self, sub_id: str) -> None: async def unsubscribe(self, sub_id: str) -> None:
# pylint: disable-next=consider-using-dict-items # pylint: disable-next=consider-using-dict-items
for bucket in self.__buckets: for bucket in self.__buckets:
@@ -297,7 +296,7 @@ class BfxWebSocketClient(Connection):
raise UnknownSubscriptionError("Unable to find " + \ raise UnknownSubscriptionError("Unable to find " + \
f"a subscription with sub_id <{sub_id}>.") f"a subscription with sub_id <{sub_id}>.")
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def resubscribe(self, sub_id: str) -> None: async def resubscribe(self, sub_id: str) -> None:
for bucket in self.__buckets: for bucket in self.__buckets:
if bucket.has(sub_id): if bucket.has(sub_id):
@@ -306,7 +305,7 @@ class BfxWebSocketClient(Connection):
raise UnknownSubscriptionError("Unable to find " + \ raise UnknownSubscriptionError("Unable to find " + \
f"a subscription with sub_id <{sub_id}>.") f"a subscription with sub_id <{sub_id}>.")
@Connection.require_websocket_connection @Connection._require_websocket_connection
async def close(self, code: int = 1000, reason: str = str()) -> None: async def close(self, code: int = 1000, reason: str = str()) -> None:
for bucket in self.__buckets: for bucket in self.__buckets:
await bucket.close(code=code, reason=reason) await bucket.close(code=code, reason=reason)
@@ -315,7 +314,7 @@ class BfxWebSocketClient(Connection):
await self._websocket.close( \ await self._websocket.close( \
code=code, reason=reason) code=code, reason=reason)
@Connection.require_websocket_authentication @Connection._require_websocket_authentication
async def notify(self, async def notify(self,
info: Any, info: Any,
message_id: Optional[int] = None, message_id: Optional[int] = None,
@@ -324,30 +323,10 @@ class BfxWebSocketClient(Connection):
json.dumps([ 0, "n", message_id, json.dumps([ 0, "n", message_id,
{ "type": "ucm-test", "info": info, **kwargs } ])) { "type": "ucm-test", "info": info, **kwargs } ]))
@Connection.require_websocket_authentication @Connection._require_websocket_authentication
async def __handle_websocket_input(self, event: str, data: Any) -> None: async def __handle_websocket_input(self, event: str, data: Any) -> None:
await self._websocket.send(json.dumps( \ await self._websocket.send(json.dumps( \
[ 0, event, None, data], cls=JSONEncoder)) [ 0, event, None, data], cls=JSONEncoder))
@no_type_check
def on(self, event, f = None): def on(self, event, f = None):
return self.__event_emitter.on(event, f=f) return self.__event_emitter.on(event, f=f)
@staticmethod
def __build_authentication_message(api_key: str,
api_secret: str,
filters: Optional[List[str]] = None) -> str:
message: Dict[str, Any] = \
{ "event": "auth", "filter": filters, "apiKey": api_key }
message["authNonce"] = round(datetime.now().timestamp() * 1_000_000)
message["authPayload"] = f"AUTH{message['authNonce']}"
message["authSig"] = hmac.new(
key=api_secret.encode("utf8"),
msg=message["authPayload"].encode("utf8"),
digestmod=hashlib.sha384
).hexdigest()
return json.dumps(message)

View File

@@ -1,18 +1,24 @@
from typing import \ from typing import \
TYPE_CHECKING, TypeVar, Callable, \ TypeVar, Callable, Awaitable, \
Awaitable, Optional, Any, \ List, Dict, Optional, \
cast Any, cast
from abc import ABC, abstractmethod # pylint: disable-next=wrong-import-order
from typing_extensions import \
ParamSpec, Concatenate
from typing_extensions import ParamSpec, Concatenate from abc import \
ABC, abstractmethod
from datetime import datetime
import hmac, hashlib, json
from websockets.client import WebSocketClientProtocol
from bfxapi.websocket.exceptions import \ from bfxapi.websocket.exceptions import \
ConnectionNotOpen, ActionRequiresAuthentication ConnectionNotOpen, ActionRequiresAuthentication
if TYPE_CHECKING:
from websockets.client import WebSocketClientProtocol
_S = TypeVar("_S", bound="Connection") _S = TypeVar("_S", bound="Connection")
_R = TypeVar("_R") _R = TypeVar("_R")
@@ -27,7 +33,7 @@ class Connection(ABC):
self._authentication: bool = False self._authentication: bool = False
self.__protocol: Optional["WebSocketClientProtocol"] = None self.__protocol: Optional[WebSocketClientProtocol] = None
@property @property
def open(self) -> bool: def open(self) -> bool:
@@ -39,11 +45,11 @@ class Connection(ABC):
return self._authentication return self._authentication
@property @property
def _websocket(self) -> "WebSocketClientProtocol": def _websocket(self) -> WebSocketClientProtocol:
return cast("WebSocketClientProtocol", self.__protocol) return cast(WebSocketClientProtocol, self.__protocol)
@_websocket.setter @_websocket.setter
def _websocket(self, protocol: "WebSocketClientProtocol") -> None: def _websocket(self, protocol: WebSocketClientProtocol) -> None:
self.__protocol = protocol self.__protocol = protocol
@abstractmethod @abstractmethod
@@ -51,9 +57,9 @@ class Connection(ABC):
... ...
@staticmethod @staticmethod
def require_websocket_connection( def _require_websocket_connection(
function: Callable[Concatenate[_S, _P], Awaitable[_R]] function: Callable[Concatenate[_S, _P], Awaitable[_R]]
) -> Callable[Concatenate[_S, _P], Awaitable["_R"]]: ) -> Callable[Concatenate[_S, _P], Awaitable[_R]]:
async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R:
if self.open: if self.open:
return await function(self, *args, **kwargs) return await function(self, *args, **kwargs)
@@ -63,7 +69,7 @@ class Connection(ABC):
return wrapper return wrapper
@staticmethod @staticmethod
def require_websocket_authentication( def _require_websocket_authentication(
function: Callable[Concatenate[_S, _P], Awaitable[_R]] function: Callable[Concatenate[_S, _P], Awaitable[_R]]
) -> Callable[Concatenate[_S, _P], Awaitable[_R]]: ) -> Callable[Concatenate[_S, _P], Awaitable[_R]]:
async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R: async def wrapper(self: _S, *args: Any, **kwargs: Any) -> _R:
@@ -71,8 +77,31 @@ class Connection(ABC):
raise ActionRequiresAuthentication("To perform this action you need to " \ raise ActionRequiresAuthentication("To perform this action you need to " \
"authenticate using your API_KEY and API_SECRET.") "authenticate using your API_KEY and API_SECRET.")
internal = Connection.require_websocket_connection(function) internal = Connection._require_websocket_connection(function)
return await internal(self, *args, **kwargs) return await internal(self, *args, **kwargs)
return wrapper return wrapper
@staticmethod
def _get_authentication_message(
api_key: str,
api_secret: str,
filters: Optional[List[str]] = None
) -> str:
message: Dict[str, Any] = \
{ "event": "auth", "filter": filters, "apiKey": api_key }
message["authNonce"] = round(datetime.now().timestamp() * 1_000_000)
message["authPayload"] = f"AUTH{message['authNonce']}"
auth_sig = hmac.new(
key=api_secret.encode("utf8"),
msg=message["authPayload"].encode("utf8"),
digestmod=hashlib.sha384
)
message["authSig"] = auth_sig.hexdigest()
return json.dumps(message)

View File

@@ -1,15 +1,14 @@
from typing import TYPE_CHECKING, \ from typing import \
Dict, Tuple, Any Dict, Tuple, Any
from pyee.base import EventEmitter
from bfxapi.types import serializers from bfxapi.types import serializers
from bfxapi.types.serializers import _Notification from bfxapi.types.serializers import _Notification
if TYPE_CHECKING: from bfxapi.types.dataclasses import \
from bfxapi.types.dataclasses import \ Order, FundingOffer
Order, FundingOffer
from pyee.base import EventEmitter
class AuthEventsHandler: class AuthEventsHandler:
__ABBREVIATIONS = { __ABBREVIATIONS = {
@@ -23,24 +22,24 @@ class AuthEventsHandler:
"flc": "funding_loan_close", "ws": "wallet_snapshot", "wu": "wallet_update" "flc": "funding_loan_close", "ws": "wallet_snapshot", "wu": "wallet_update"
} }
def __init__(self, event_emitter: "EventEmitter") -> None: __SERIALIZERS: Dict[Tuple[str, ...], serializers._Serializer] = {
self.__event_emitter = event_emitter ("os", "on", "ou", "oc"): serializers.Order,
("ps", "pn", "pu", "pc"): serializers.Position,
("te", "tu"): serializers.Trade,
("fos", "fon", "fou", "foc"): serializers.FundingOffer,
("fcs", "fcn", "fcu", "fcc"): serializers.FundingCredit,
("fls", "fln", "flu", "flc"): serializers.FundingLoan,
("ws", "wu"): serializers.Wallet
}
self.__serializers: Dict[Tuple[str, ...], serializers._Serializer] = { def __init__(self, event_emitter: EventEmitter) -> None:
("os", "on", "ou", "oc",): serializers.Order, self.__event_emitter = event_emitter
("ps", "pn", "pu", "pc",): serializers.Position,
("te", "tu"): serializers.Trade,
("fos", "fon", "fou", "foc",): serializers.FundingOffer,
("fcs", "fcn", "fcu", "fcc",): serializers.FundingCredit,
("fls", "fln", "flu", "flc",): serializers.FundingLoan,
("ws", "wu",): serializers.Wallet
}
def handle(self, abbrevation: str, stream: Any) -> None: def handle(self, abbrevation: str, stream: Any) -> None:
if abbrevation == "n": if abbrevation == "n":
return self.__notification(stream) return self.__notification(stream)
for abbrevations, serializer in self.__serializers.items(): for abbrevations, serializer in AuthEventsHandler.__SERIALIZERS.items():
if abbrevation in abbrevations: if abbrevation in abbrevations:
event = AuthEventsHandler.__ABBREVIATIONS[abbrevation] event = AuthEventsHandler.__ABBREVIATIONS[abbrevation]
@@ -57,12 +56,12 @@ class AuthEventsHandler:
serializer: _Notification = _Notification[None](serializer=None) serializer: _Notification = _Notification[None](serializer=None)
if stream[1] == "on-req" or stream[1] == "ou-req" or stream[1] == "oc-req": if stream[1] in ("on-req", "ou-req", "oc-req"):
event, serializer = f"{stream[1]}-notification", \ event, serializer = f"{stream[1]}-notification", \
_Notification["Order"](serializer=serializers.Order) _Notification[Order](serializer=serializers.Order)
if stream[1] == "fon-req" or stream[1] == "foc-req": if stream[1] in ("fon-req", "foc-req"):
event, serializer = f"{stream[1]}-notification", \ event, serializer = f"{stream[1]}-notification", \
_Notification["FundingOffer"](serializer=serializers.FundingOffer) _Notification[FundingOffer](serializer=serializers.FundingOffer)
self.__event_emitter.emit(event, serializer.parse(*stream)) self.__event_emitter.emit(event, serializer.parse(*stream))

View File

@@ -1,28 +1,27 @@
from typing import \ from typing import \
TYPE_CHECKING, List, Any, \ List, Any, cast
cast
from pyee.base import EventEmitter
from bfxapi.types import serializers from bfxapi.types import serializers
if TYPE_CHECKING: from bfxapi.websocket.subscriptions import \
from bfxapi.websocket.subscriptions import Subscription, \ Subscription, Ticker, Trades, \
Ticker, Trades, Book, Candles, Status Book, Candles, Status
from pyee.base import EventEmitter
_CHECKSUM = "cs" _CHECKSUM = "cs"
class PublicChannelsHandler: class PublicChannelsHandler:
def __init__(self, event_emitter: "EventEmitter") -> None: def __init__(self, event_emitter: EventEmitter) -> None:
self.__event_emitter = event_emitter self.__event_emitter = event_emitter
def handle(self, subscription: "Subscription", stream: List[Any]) -> None: def handle(self, subscription: Subscription, stream: List[Any]) -> None:
if subscription["channel"] == "ticker": if subscription["channel"] == "ticker":
self.__ticker_channel_handler(cast("Ticker", subscription), stream) self.__ticker_channel_handler(cast(Ticker, subscription), stream)
elif subscription["channel"] == "trades": elif subscription["channel"] == "trades":
self.__trades_channel_handler(cast("Trades", subscription), stream) self.__trades_channel_handler(cast(Trades, subscription), stream)
elif subscription["channel"] == "book": elif subscription["channel"] == "book":
subscription = cast("Book", subscription) subscription = cast(Book, subscription)
if stream[0] == _CHECKSUM: if stream[0] == _CHECKSUM:
self.__checksum_handler(subscription, stream[1]) self.__checksum_handler(subscription, stream[1])
@@ -32,11 +31,11 @@ class PublicChannelsHandler:
else: else:
self.__raw_book_channel_handler(subscription, stream) self.__raw_book_channel_handler(subscription, stream)
elif subscription["channel"] == "candles": elif subscription["channel"] == "candles":
self.__candles_channel_handler(cast("Candles", subscription), stream) self.__candles_channel_handler(cast(Candles, subscription), stream)
elif subscription["channel"] == "status": elif subscription["channel"] == "status":
self.__status_channel_handler(cast("Status", subscription), stream) self.__status_channel_handler(cast(Status, subscription), stream)
def __ticker_channel_handler(self, subscription: "Ticker", stream: List[Any]): def __ticker_channel_handler(self, subscription: Ticker, stream: List[Any]):
if subscription["symbol"].startswith("t"): if subscription["symbol"].startswith("t"):
return self.__event_emitter.emit("t_ticker_update", subscription, \ return self.__event_emitter.emit("t_ticker_update", subscription, \
serializers.TradingPairTicker.parse(*stream[0])) serializers.TradingPairTicker.parse(*stream[0]))
@@ -45,7 +44,7 @@ class PublicChannelsHandler:
return self.__event_emitter.emit("f_ticker_update", subscription, \ return self.__event_emitter.emit("f_ticker_update", subscription, \
serializers.FundingCurrencyTicker.parse(*stream[0])) serializers.FundingCurrencyTicker.parse(*stream[0]))
def __trades_channel_handler(self, subscription: "Trades", stream: List[Any]): def __trades_channel_handler(self, subscription: Trades, stream: List[Any]):
if (event := stream[0]) and event in [ "te", "tu", "fte", "ftu" ]: if (event := stream[0]) and event in [ "te", "tu", "fte", "ftu" ]:
events = { "te": "t_trade_execution", "tu": "t_trade_execution_update", \ events = { "te": "t_trade_execution", "tu": "t_trade_execution_update", \
"fte": "f_trade_execution", "ftu": "f_trade_execution_update" } "fte": "f_trade_execution", "ftu": "f_trade_execution_update" }
@@ -68,7 +67,7 @@ class PublicChannelsHandler:
[ serializers.FundingCurrencyTrade.parse(*sub_stream) \ [ serializers.FundingCurrencyTrade.parse(*sub_stream) \
for sub_stream in stream[0] ]) for sub_stream in stream[0] ])
def __book_channel_handler(self, subscription: "Book", stream: List[Any]): def __book_channel_handler(self, subscription: Book, stream: List[Any]):
if subscription["symbol"].startswith("t"): if subscription["symbol"].startswith("t"):
if all(isinstance(sub_stream, list) for sub_stream in stream[0]): if all(isinstance(sub_stream, list) for sub_stream in stream[0]):
return self.__event_emitter.emit("t_book_snapshot", subscription, \ return self.__event_emitter.emit("t_book_snapshot", subscription, \
@@ -87,7 +86,7 @@ class PublicChannelsHandler:
return self.__event_emitter.emit("f_book_update", subscription, \ return self.__event_emitter.emit("f_book_update", subscription, \
serializers.FundingCurrencyBook.parse(*stream[0])) serializers.FundingCurrencyBook.parse(*stream[0]))
def __raw_book_channel_handler(self, subscription: "Book", stream: List[Any]): def __raw_book_channel_handler(self, subscription: Book, stream: List[Any]):
if subscription["symbol"].startswith("t"): if subscription["symbol"].startswith("t"):
if all(isinstance(sub_stream, list) for sub_stream in stream[0]): if all(isinstance(sub_stream, list) for sub_stream in stream[0]):
return self.__event_emitter.emit("t_raw_book_snapshot", subscription, \ return self.__event_emitter.emit("t_raw_book_snapshot", subscription, \
@@ -106,7 +105,7 @@ class PublicChannelsHandler:
return self.__event_emitter.emit("f_raw_book_update", subscription, \ return self.__event_emitter.emit("f_raw_book_update", subscription, \
serializers.FundingCurrencyRawBook.parse(*stream[0])) serializers.FundingCurrencyRawBook.parse(*stream[0]))
def __candles_channel_handler(self, subscription: "Candles", stream: List[Any]): def __candles_channel_handler(self, subscription: Candles, stream: List[Any]):
if all(isinstance(sub_stream, list) for sub_stream in stream[0]): if all(isinstance(sub_stream, list) for sub_stream in stream[0]):
return self.__event_emitter.emit("candles_snapshot", subscription, \ return self.__event_emitter.emit("candles_snapshot", subscription, \
[ serializers.Candle.parse(*sub_stream) \ [ serializers.Candle.parse(*sub_stream) \
@@ -115,7 +114,7 @@ class PublicChannelsHandler:
return self.__event_emitter.emit("candles_update", subscription, \ return self.__event_emitter.emit("candles_update", subscription, \
serializers.Candle.parse(*stream[0])) serializers.Candle.parse(*stream[0]))
def __status_channel_handler(self, subscription: "Status", stream: List[Any]): def __status_channel_handler(self, subscription: Status, stream: List[Any]):
if subscription["key"].startswith("deriv:"): if subscription["key"].startswith("deriv:"):
return self.__event_emitter.emit("derivatives_status_update", subscription, \ return self.__event_emitter.emit("derivatives_status_update", subscription, \
serializers.DerivativesStatus.parse(*stream[0])) serializers.DerivativesStatus.parse(*stream[0]))
@@ -124,6 +123,6 @@ class PublicChannelsHandler:
return self.__event_emitter.emit("liquidation_feed_update", subscription, \ return self.__event_emitter.emit("liquidation_feed_update", subscription, \
serializers.Liquidation.parse(*stream[0][0])) serializers.Liquidation.parse(*stream[0][0]))
def __checksum_handler(self, subscription: "Book", value: int): def __checksum_handler(self, subscription: Book, value: int):
return self.__event_emitter.emit( \ return self.__event_emitter.emit( \
"checksum", subscription, value & 0xFFFFFFFF) "checksum", subscription, value & 0xFFFFFFFF)