Remove skip parameter in _Serializer::parse, add flat argument in _Serializer::__init__. Add _Serializer::__flatten class method. Fix small bugs in bfxapi.rest.endpoints.rest_public_endpoints and bfxapi.rest.endpoints.rest_authenticated_endpoints.

This commit is contained in:
Davide Casale
2023-03-08 19:31:48 +01:00
parent 87ea765281
commit bd09cc4ae4
8 changed files with 70 additions and 46 deletions

View File

@@ -17,6 +17,10 @@ disable=
dangerous-default-value, dangerous-default-value,
inconsistent-return-statements, inconsistent-return-statements,
[SIMILARITIES]
min-similarity-lines=6
[VARIABLES] [VARIABLES]
allowed-redefined-builtins=type,dir,id,all,format,len allowed-redefined-builtins=type,dir,id,all,format,len

View File

@@ -1,4 +1,4 @@
from typing import Type, Generic, TypeVar, Iterable, Optional, Dict, List, Tuple, Any, cast from typing import Type, Generic, TypeVar, Iterable, Dict, List, Tuple, Any, cast
from .exceptions import LabelerSerializerException from .exceptions import LabelerSerializerException
@@ -35,53 +35,63 @@ class _Type:
class _Serializer(Generic[T]): class _Serializer(Generic[T]):
def __init__(self, name: str, klass: Type[_Type], labels: List[str], def __init__(self, name: str, klass: Type[_Type], labels: List[str],
*, ignore: List[str] = [ "_PLACEHOLDER" ]): *, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]):
self.name, self.klass, self.__labels, self.__ignore = name, klass, labels, ignore self.name, self.klass, self.__labels, self.__flat, self.__ignore = name, klass, labels, flat, ignore
def _serialize(self, *args: Any, skip: Optional[List[str]] = None) -> Iterable[Tuple[str, Any]]: def _serialize(self, *args: Any) -> Iterable[Tuple[str, Any]]:
labels, skips = [], [] if self.__flat:
args = tuple(_Serializer.__flatten(list(args)))
for label in self.__labels: if len(self.__labels) > len(args):
(labels, skips)[label in (skip or [])].append(label)
if len(labels) > len(args):
raise LabelerSerializerException(f"{self.name} -> <labels> and <*args> " \ raise LabelerSerializerException(f"{self.name} -> <labels> and <*args> " \
"arguments should contain the same amount of elements.") "arguments should contain the same amount of elements.")
for index, label in enumerate(labels): for index, label in enumerate(self.__labels):
if label not in self.__ignore: if label not in self.__ignore:
yield label, args[index] yield label, args[index]
for skip in skips: def parse(self, *values: Any) -> T:
yield skip, None return cast(T, self.klass(**dict(self._serialize(*values))))
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T:
return cast(T, self.klass(**dict(self._serialize(*values, skip=skip))))
def get_labels(self) -> List[str]: def get_labels(self) -> List[str]:
return [ label for label in self.__labels if label not in self.__ignore ] return [ label for label in self.__labels if label not in self.__ignore ]
@classmethod
def __flatten(cls, array: List[Any]) -> List[Any]:
if len(array) == 0:
return array
if isinstance(array[0], list):
return cls.__flatten(array[0]) + cls.__flatten(array[1:])
return array[:1] + cls.__flatten(array[1:])
class _RecursiveSerializer(_Serializer, Generic[T]): class _RecursiveSerializer(_Serializer, Generic[T]):
def __init__(self, name: str, klass: Type[_Type], labels: List[str], def __init__(self, name: str, klass: Type[_Type], labels: List[str],
*, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = ["_PLACEHOLDER"]): *, serializers: Dict[str, _Serializer[Any]],
super().__init__(name, klass, labels, ignore = ignore) flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]):
super().__init__(name, klass, labels, flat=flat, ignore=ignore)
self.serializers = serializers self.serializers = serializers
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T: def parse(self, *values: Any) -> T:
serialization = dict(self._serialize(*values, skip=skip)) serialization = dict(self._serialize(*values))
for key in serialization: for key in serialization:
if key in self.serializers.keys(): if key in self.serializers.keys():
serialization[key] = self.serializers[key].parse(*serialization[key], skip=skip) serialization[key] = self.serializers[key].parse(*serialization[key])
return cast(T, self.klass(**serialization)) return cast(T, self.klass(**serialization))
def generate_labeler_serializer(name: str, klass: Type[T], labels: List[str], def generate_labeler_serializer(name: str, klass: Type[T], labels: List[str],
*, ignore: List[str] = [ "_PLACEHOLDER" ]) -> _Serializer[T]: *, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]
return _Serializer[T](name, klass, labels, ignore=ignore) ) -> _Serializer[T]:
return _Serializer[T](name, klass, labels, \
flat=flat, ignore=ignore)
def generate_recursive_serializer(name: str, klass: Type[T], labels: List[str], def generate_recursive_serializer(name: str, klass: Type[T], labels: List[str],
*, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = [ "_PLACEHOLDER" ] *, serializers: Dict[str, _Serializer[Any]],
flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]
) -> _RecursiveSerializer[T]: ) -> _RecursiveSerializer[T]:
return _RecursiveSerializer[T](name, klass, labels, serializers=serializers, ignore=ignore) return _RecursiveSerializer[T](name, klass, labels, \
serializers=serializers, flat=flat, ignore=ignore)

View File

@@ -22,7 +22,7 @@ class _Notification(_Serializer, Generic[T]):
self.serializer, self.is_iterable = serializer, is_iterable self.serializer, self.is_iterable = serializer, is_iterable
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> Notification[T]: def parse(self, *values: Any) -> Notification[T]:
notification = cast(Notification[T], Notification(**dict(self._serialize(*values)))) notification = cast(Notification[T], Notification(**dict(self._serialize(*values))))
if isinstance(self.serializer, _Serializer): if isinstance(self.serializer, _Serializer):
@@ -32,7 +32,7 @@ class _Notification(_Serializer, Generic[T]):
if len(data) == 1 and isinstance(data[0], list): if len(data) == 1 and isinstance(data[0], list):
data = data[0] data = data[0]
notification.data = self.serializer.parse(*data, skip=skip) notification.data = self.serializer.parse(*data)
else: notification.data = cast(T, [ self.serializer.parse(*sub_data, skip=skip) for sub_data in data ]) else: notification.data = cast(T, [ self.serializer.parse(*sub_data) for sub_data in data ])
return notification return notification

View File

@@ -199,11 +199,10 @@ class RestAuthenticatedEndpoints(Middleware):
def get_symbol_margin_info(self, symbol: str) -> SymbolMarginInfo: def get_symbol_margin_info(self, symbol: str) -> SymbolMarginInfo:
return serializers.SymbolMarginInfo \ return serializers.SymbolMarginInfo \
.parse(*(self._post(f"auth/r/info/margin/{symbol}")[2]), \ .parse(*self._post(f"auth/r/info/margin/{symbol}"))
skip=["symbol"])
def get_all_symbols_margin_info(self) -> List[SymbolMarginInfo]: def get_all_symbols_margin_info(self) -> List[SymbolMarginInfo]:
return [ serializers.SymbolMarginInfo.parse(*([sub_data[1]] + sub_data[2])) \ return [ serializers.SymbolMarginInfo.parse(*sub_data) \
for sub_data in self._post("auth/r/info/margin/sym_all") ] for sub_data in self._post("auth/r/info/margin/sym_all") ]
def get_positions(self) -> List[Position]: def get_positions(self) -> List[Position]:
@@ -225,6 +224,13 @@ class RestAuthenticatedEndpoints(Middleware):
.parse(*self._post("auth/w/position/increase", \ .parse(*self._post("auth/w/position/increase", \
body={ "symbol": symbol, "amount": amount })) body={ "symbol": symbol, "amount": amount }))
def get_increase_position_info(self,
symbol: str,
amount: Union[Decimal, float, str]) -> PositionIncreaseInfo:
return serializers.PositionIncreaseInfo \
.parse(*self._post("auth/r/position/increase/info", \
body={ "symbol": symbol, "amount": amount }))
def get_positions_history(self, def get_positions_history(self,
*, *,
start: Optional[str] = None, start: Optional[str] = None,

View File

@@ -49,10 +49,10 @@ class RestPublicEndpoints(Middleware):
return cast(List[FundingCurrencyTicker], data) return cast(List[FundingCurrencyTicker], data)
def get_t_ticker(self, pair: str) -> TradingPairTicker: def get_t_ticker(self, pair: str) -> TradingPairTicker:
return serializers.TradingPairTicker.parse(*self._get(f"ticker/{pair}"), skip=["symbol"]) return serializers.TradingPairTicker.parse(*([pair] + self._get(f"ticker/{pair}")))
def get_f_ticker(self, currency: str) -> FundingCurrencyTicker: def get_f_ticker(self, currency: str) -> FundingCurrencyTicker:
return serializers.FundingCurrencyTicker.parse(*self._get(f"ticker/{currency}"), skip=["symbol"]) return serializers.FundingCurrencyTicker.parse(*([currency] + self._get(f"ticker/{currency}")))
def get_tickers_history(self, def get_tickers_history(self,
symbols: List[str], symbols: List[str],
@@ -183,7 +183,7 @@ class RestPublicEndpoints(Middleware):
limit: Optional[int] = None) -> List[DerivativesStatus]: limit: Optional[int] = None) -> List[DerivativesStatus]:
params = { "sort": sort, "start": start, "end": end, "limit": limit } params = { "sort": sort, "start": start, "end": end, "limit": limit }
data = self._get(f"status/{type}/{symbol}/hist", params=params) data = self._get(f"status/{type}/{symbol}/hist", params=params)
return [ serializers.DerivativesStatus.parse(*sub_data, skip=[ "key" ]) for sub_data in data ] return [ serializers.DerivativesStatus.parse(*([symbol] + sub_data)) for sub_data in data ]
def get_liquidations(self, def get_liquidations(self,
*, *,

View File

@@ -786,12 +786,15 @@ SymbolMarginInfo = generate_labeler_serializer(
name="SymbolMarginInfo", name="SymbolMarginInfo",
klass=types.SymbolMarginInfo, klass=types.SymbolMarginInfo,
labels=[ labels=[
"_PLACEHOLDER",
"symbol", "symbol",
"tradable_balance", "tradable_balance",
"gross_balance", "gross_balance",
"buy", "buy",
"sell" "sell"
] ],
flat=True
) )
BaseMarginInfo = generate_labeler_serializer( BaseMarginInfo = generate_labeler_serializer(
@@ -849,11 +852,15 @@ PositionIncreaseInfo = generate_labeler_serializer(
"_PLACEHOLDER", "_PLACEHOLDER",
"_PLACEHOLDER", "_PLACEHOLDER",
"funding_avail", "funding_avail",
"_PLACEHOLDER",
"_PLACEHOLDER",
"funding_value", "funding_value",
"funding_required", "funding_required",
"funding_value_currency", "funding_value_currency",
"funding_required_currency" "funding_required_currency"
] ],
flat=True
) )
PositionIncrease = generate_labeler_serializer( PositionIncrease = generate_labeler_serializer(

View File

@@ -20,7 +20,7 @@ class PlatformStatus(_Type):
@dataclass @dataclass
class TradingPairTicker(_Type): class TradingPairTicker(_Type):
symbol: Optional[str] symbol: str
bid: float bid: float
bid_size: float bid_size: float
ask: float ask: float
@@ -34,7 +34,7 @@ class TradingPairTicker(_Type):
@dataclass @dataclass
class FundingCurrencyTicker(_Type): class FundingCurrencyTicker(_Type):
symbol: Optional[str] symbol: str
frr: float frr: float
bid: float bid: float
bid_period: int bid_period: int
@@ -114,7 +114,7 @@ class Candle(_Type):
@dataclass @dataclass
class DerivativesStatus(_Type): class DerivativesStatus(_Type):
key: Optional[str] key: str
mts: int mts: int
deriv_price: float deriv_price: float
spot_price: float spot_price: float
@@ -466,7 +466,7 @@ class Movement(_Type):
@dataclass @dataclass
class SymbolMarginInfo(_Type): class SymbolMarginInfo(_Type):
symbol: Optional[str] symbol: str
tradable_balance: float tradable_balance: float
gross_balance: float gross_balance: float
buy: float buy: float

View File

@@ -21,9 +21,6 @@ class TestLabeler(unittest.TestCase):
self.assertEqual(serializer.parse(5, None, 65.0, None, "X"), Test(5, 65.0, "X"), self.assertEqual(serializer.parse(5, None, 65.0, None, "X"), Test(5, 65.0, "X"),
msg="_Serializer should produce the right result.") msg="_Serializer should produce the right result.")
self.assertEqual(serializer.parse(None, 65.0, None, "X", skip=[ "A" ]), Test(None, 65.0, "X"),
msg="_Serializer should produce the right result when skip parameter is given.")
self.assertListEqual(serializer.get_labels(), [ "A", "B", "C" ], self.assertListEqual(serializer.get_labels(), [ "A", "B", "C" ],
msg="_Serializer::get_labels() should return the right list of labels.") msg="_Serializer::get_labels() should return the right list of labels.")