From bd09cc4ae499573cf6ed2a6b25a28b4fa1333c2e Mon Sep 17 00:00:00 2001 From: Davide Casale Date: Wed, 8 Mar 2023 19:31:48 +0100 Subject: [PATCH] Remove skip parameter in _Serializer::parse, add flat argument in _Serializer::__init__. Add _Serializer::__flatten class method. Fix small bugs in bfxapi.rest.endpoints.rest_public_endpoints and bfxapi.rest.endpoints.rest_authenticated_endpoints. --- .pylintrc | 4 ++ bfxapi/labeler.py | 60 +++++++++++-------- bfxapi/notification.py | 6 +- .../endpoints/rest_authenticated_endpoints.py | 18 ++++-- .../rest/endpoints/rest_public_endpoints.py | 6 +- bfxapi/rest/serializers.py | 11 +++- bfxapi/rest/types.py | 8 +-- bfxapi/tests/test_labeler.py | 3 - 8 files changed, 70 insertions(+), 46 deletions(-) diff --git a/.pylintrc b/.pylintrc index 0127cfd..2b14e62 100644 --- a/.pylintrc +++ b/.pylintrc @@ -17,6 +17,10 @@ disable= dangerous-default-value, inconsistent-return-statements, +[SIMILARITIES] + +min-similarity-lines=6 + [VARIABLES] allowed-redefined-builtins=type,dir,id,all,format,len diff --git a/bfxapi/labeler.py b/bfxapi/labeler.py index 88b6d9f..dfb4881 100644 --- a/bfxapi/labeler.py +++ b/bfxapi/labeler.py @@ -1,4 +1,4 @@ -from typing import Type, Generic, TypeVar, Iterable, Optional, Dict, List, Tuple, Any, cast +from typing import Type, Generic, TypeVar, Iterable, Dict, List, Tuple, Any, cast from .exceptions import LabelerSerializerException @@ -35,53 +35,63 @@ class _Type: class _Serializer(Generic[T]): def __init__(self, name: str, klass: Type[_Type], labels: List[str], - *, ignore: List[str] = [ "_PLACEHOLDER" ]): - self.name, self.klass, self.__labels, self.__ignore = name, klass, labels, ignore + *, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]): + self.name, self.klass, self.__labels, self.__flat, self.__ignore = name, klass, labels, flat, ignore - def _serialize(self, *args: Any, skip: Optional[List[str]] = None) -> Iterable[Tuple[str, Any]]: - labels, skips = [], [] + def _serialize(self, *args: Any) -> Iterable[Tuple[str, Any]]: + if self.__flat: + args = tuple(_Serializer.__flatten(list(args))) - for label in self.__labels: - (labels, skips)[label in (skip or [])].append(label) - - if len(labels) > len(args): + if len(self.__labels) > len(args): raise LabelerSerializerException(f"{self.name} -> and <*args> " \ "arguments should contain the same amount of elements.") - for index, label in enumerate(labels): + for index, label in enumerate(self.__labels): if label not in self.__ignore: yield label, args[index] - for skip in skips: - yield skip, None - - def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T: - return cast(T, self.klass(**dict(self._serialize(*values, skip=skip)))) + def parse(self, *values: Any) -> T: + return cast(T, self.klass(**dict(self._serialize(*values)))) def get_labels(self) -> List[str]: return [ label for label in self.__labels if label not in self.__ignore ] + @classmethod + def __flatten(cls, array: List[Any]) -> List[Any]: + if len(array) == 0: + return array + + if isinstance(array[0], list): + return cls.__flatten(array[0]) + cls.__flatten(array[1:]) + + return array[:1] + cls.__flatten(array[1:]) + class _RecursiveSerializer(_Serializer, Generic[T]): def __init__(self, name: str, klass: Type[_Type], labels: List[str], - *, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = ["_PLACEHOLDER"]): - super().__init__(name, klass, labels, ignore = ignore) + *, serializers: Dict[str, _Serializer[Any]], + flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]): + super().__init__(name, klass, labels, flat=flat, ignore=ignore) self.serializers = serializers - def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T: - serialization = dict(self._serialize(*values, skip=skip)) + def parse(self, *values: Any) -> T: + serialization = dict(self._serialize(*values)) for key in serialization: if key in self.serializers.keys(): - serialization[key] = self.serializers[key].parse(*serialization[key], skip=skip) + serialization[key] = self.serializers[key].parse(*serialization[key]) return cast(T, self.klass(**serialization)) def generate_labeler_serializer(name: str, klass: Type[T], labels: List[str], - *, ignore: List[str] = [ "_PLACEHOLDER" ]) -> _Serializer[T]: - return _Serializer[T](name, klass, labels, ignore=ignore) + *, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ] + ) -> _Serializer[T]: + return _Serializer[T](name, klass, labels, \ + flat=flat, ignore=ignore) def generate_recursive_serializer(name: str, klass: Type[T], labels: List[str], - *, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = [ "_PLACEHOLDER" ] - ) -> _RecursiveSerializer[T]: - return _RecursiveSerializer[T](name, klass, labels, serializers=serializers, ignore=ignore) + *, serializers: Dict[str, _Serializer[Any]], + flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ] + ) -> _RecursiveSerializer[T]: + return _RecursiveSerializer[T](name, klass, labels, \ + serializers=serializers, flat=flat, ignore=ignore) diff --git a/bfxapi/notification.py b/bfxapi/notification.py index 601b3f8..ae02259 100644 --- a/bfxapi/notification.py +++ b/bfxapi/notification.py @@ -22,7 +22,7 @@ class _Notification(_Serializer, Generic[T]): self.serializer, self.is_iterable = serializer, is_iterable - def parse(self, *values: Any, skip: Optional[List[str]] = None) -> Notification[T]: + def parse(self, *values: Any) -> Notification[T]: notification = cast(Notification[T], Notification(**dict(self._serialize(*values)))) if isinstance(self.serializer, _Serializer): @@ -32,7 +32,7 @@ class _Notification(_Serializer, Generic[T]): if len(data) == 1 and isinstance(data[0], list): data = data[0] - notification.data = self.serializer.parse(*data, skip=skip) - else: notification.data = cast(T, [ self.serializer.parse(*sub_data, skip=skip) for sub_data in data ]) + notification.data = self.serializer.parse(*data) + else: notification.data = cast(T, [ self.serializer.parse(*sub_data) for sub_data in data ]) return notification diff --git a/bfxapi/rest/endpoints/rest_authenticated_endpoints.py b/bfxapi/rest/endpoints/rest_authenticated_endpoints.py index a78858e..8a2e61e 100644 --- a/bfxapi/rest/endpoints/rest_authenticated_endpoints.py +++ b/bfxapi/rest/endpoints/rest_authenticated_endpoints.py @@ -145,7 +145,7 @@ class RestAuthenticatedEndpoints(Middleware): endpoint = "auth/r/orders/hist" else: endpoint = f"auth/r/orders/{symbol}/hist" - body = { + body = { "id": ids, "start": start, "end": end, "limit": limit } @@ -185,7 +185,7 @@ class RestAuthenticatedEndpoints(Middleware): start: Optional[str] = None, end: Optional[str] = None, limit: Optional[int] = None) -> List[Ledger]: - body = { + body = { "category": category, "start": start, "end": end, "limit": limit } @@ -199,11 +199,10 @@ class RestAuthenticatedEndpoints(Middleware): def get_symbol_margin_info(self, symbol: str) -> SymbolMarginInfo: return serializers.SymbolMarginInfo \ - .parse(*(self._post(f"auth/r/info/margin/{symbol}")[2]), \ - skip=["symbol"]) + .parse(*self._post(f"auth/r/info/margin/{symbol}")) def get_all_symbols_margin_info(self) -> List[SymbolMarginInfo]: - return [ serializers.SymbolMarginInfo.parse(*([sub_data[1]] + sub_data[2])) \ + return [ serializers.SymbolMarginInfo.parse(*sub_data) \ for sub_data in self._post("auth/r/info/margin/sym_all") ] def get_positions(self) -> List[Position]: @@ -225,6 +224,13 @@ class RestAuthenticatedEndpoints(Middleware): .parse(*self._post("auth/w/position/increase", \ body={ "symbol": symbol, "amount": amount })) + def get_increase_position_info(self, + symbol: str, + amount: Union[Decimal, float, str]) -> PositionIncreaseInfo: + return serializers.PositionIncreaseInfo \ + .parse(*self._post("auth/r/position/increase/info", \ + body={ "symbol": symbol, "amount": amount })) + def get_positions_history(self, *, start: Optional[str] = None, @@ -249,7 +255,7 @@ class RestAuthenticatedEndpoints(Middleware): start: Optional[str] = None, end: Optional[str] = None, limit: Optional[int] = None) -> List[PositionAudit]: - body = { + body = { "ids": ids, "start": start, "end": end, "limit": limit } diff --git a/bfxapi/rest/endpoints/rest_public_endpoints.py b/bfxapi/rest/endpoints/rest_public_endpoints.py index 1b76da3..ac494b6 100644 --- a/bfxapi/rest/endpoints/rest_public_endpoints.py +++ b/bfxapi/rest/endpoints/rest_public_endpoints.py @@ -49,10 +49,10 @@ class RestPublicEndpoints(Middleware): return cast(List[FundingCurrencyTicker], data) def get_t_ticker(self, pair: str) -> TradingPairTicker: - return serializers.TradingPairTicker.parse(*self._get(f"ticker/{pair}"), skip=["symbol"]) + return serializers.TradingPairTicker.parse(*([pair] + self._get(f"ticker/{pair}"))) def get_f_ticker(self, currency: str) -> FundingCurrencyTicker: - return serializers.FundingCurrencyTicker.parse(*self._get(f"ticker/{currency}"), skip=["symbol"]) + return serializers.FundingCurrencyTicker.parse(*([currency] + self._get(f"ticker/{currency}"))) def get_tickers_history(self, symbols: List[str], @@ -183,7 +183,7 @@ class RestPublicEndpoints(Middleware): limit: Optional[int] = None) -> List[DerivativesStatus]: params = { "sort": sort, "start": start, "end": end, "limit": limit } data = self._get(f"status/{type}/{symbol}/hist", params=params) - return [ serializers.DerivativesStatus.parse(*sub_data, skip=[ "key" ]) for sub_data in data ] + return [ serializers.DerivativesStatus.parse(*([symbol] + sub_data)) for sub_data in data ] def get_liquidations(self, *, diff --git a/bfxapi/rest/serializers.py b/bfxapi/rest/serializers.py index cee664d..1fd4684 100644 --- a/bfxapi/rest/serializers.py +++ b/bfxapi/rest/serializers.py @@ -786,12 +786,15 @@ SymbolMarginInfo = generate_labeler_serializer( name="SymbolMarginInfo", klass=types.SymbolMarginInfo, labels=[ + "_PLACEHOLDER", "symbol", "tradable_balance", "gross_balance", "buy", "sell" - ] + ], + + flat=True ) BaseMarginInfo = generate_labeler_serializer( @@ -849,11 +852,15 @@ PositionIncreaseInfo = generate_labeler_serializer( "_PLACEHOLDER", "_PLACEHOLDER", "funding_avail", + "_PLACEHOLDER", + "_PLACEHOLDER", "funding_value", "funding_required", "funding_value_currency", "funding_required_currency" - ] + ], + + flat=True ) PositionIncrease = generate_labeler_serializer( diff --git a/bfxapi/rest/types.py b/bfxapi/rest/types.py index f8988ea..34fdb46 100644 --- a/bfxapi/rest/types.py +++ b/bfxapi/rest/types.py @@ -20,7 +20,7 @@ class PlatformStatus(_Type): @dataclass class TradingPairTicker(_Type): - symbol: Optional[str] + symbol: str bid: float bid_size: float ask: float @@ -34,7 +34,7 @@ class TradingPairTicker(_Type): @dataclass class FundingCurrencyTicker(_Type): - symbol: Optional[str] + symbol: str frr: float bid: float bid_period: int @@ -114,7 +114,7 @@ class Candle(_Type): @dataclass class DerivativesStatus(_Type): - key: Optional[str] + key: str mts: int deriv_price: float spot_price: float @@ -466,7 +466,7 @@ class Movement(_Type): @dataclass class SymbolMarginInfo(_Type): - symbol: Optional[str] + symbol: str tradable_balance: float gross_balance: float buy: float diff --git a/bfxapi/tests/test_labeler.py b/bfxapi/tests/test_labeler.py index cb88528..c375798 100644 --- a/bfxapi/tests/test_labeler.py +++ b/bfxapi/tests/test_labeler.py @@ -21,9 +21,6 @@ class TestLabeler(unittest.TestCase): self.assertEqual(serializer.parse(5, None, 65.0, None, "X"), Test(5, 65.0, "X"), msg="_Serializer should produce the right result.") - self.assertEqual(serializer.parse(None, 65.0, None, "X", skip=[ "A" ]), Test(None, 65.0, "X"), - msg="_Serializer should produce the right result when skip parameter is given.") - self.assertListEqual(serializer.get_labels(), [ "A", "B", "C" ], msg="_Serializer::get_labels() should return the right list of labels.")