Remove skip parameter in _Serializer::parse, add flat argument in _Serializer::__init__. Add _Serializer::__flatten class method. Fix small bugs in bfxapi.rest.endpoints.rest_public_endpoints and bfxapi.rest.endpoints.rest_authenticated_endpoints.

This commit is contained in:
Davide Casale
2023-03-08 19:31:48 +01:00
parent 87ea765281
commit bd09cc4ae4
8 changed files with 70 additions and 46 deletions

View File

@@ -17,6 +17,10 @@ disable=
dangerous-default-value,
inconsistent-return-statements,
[SIMILARITIES]
min-similarity-lines=6
[VARIABLES]
allowed-redefined-builtins=type,dir,id,all,format,len

View File

@@ -1,4 +1,4 @@
from typing import Type, Generic, TypeVar, Iterable, Optional, Dict, List, Tuple, Any, cast
from typing import Type, Generic, TypeVar, Iterable, Dict, List, Tuple, Any, cast
from .exceptions import LabelerSerializerException
@@ -35,53 +35,63 @@ class _Type:
class _Serializer(Generic[T]):
def __init__(self, name: str, klass: Type[_Type], labels: List[str],
*, ignore: List[str] = [ "_PLACEHOLDER" ]):
self.name, self.klass, self.__labels, self.__ignore = name, klass, labels, ignore
*, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]):
self.name, self.klass, self.__labels, self.__flat, self.__ignore = name, klass, labels, flat, ignore
def _serialize(self, *args: Any, skip: Optional[List[str]] = None) -> Iterable[Tuple[str, Any]]:
labels, skips = [], []
def _serialize(self, *args: Any) -> Iterable[Tuple[str, Any]]:
if self.__flat:
args = tuple(_Serializer.__flatten(list(args)))
for label in self.__labels:
(labels, skips)[label in (skip or [])].append(label)
if len(labels) > len(args):
if len(self.__labels) > len(args):
raise LabelerSerializerException(f"{self.name} -> <labels> and <*args> " \
"arguments should contain the same amount of elements.")
for index, label in enumerate(labels):
for index, label in enumerate(self.__labels):
if label not in self.__ignore:
yield label, args[index]
for skip in skips:
yield skip, None
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T:
return cast(T, self.klass(**dict(self._serialize(*values, skip=skip))))
def parse(self, *values: Any) -> T:
return cast(T, self.klass(**dict(self._serialize(*values))))
def get_labels(self) -> List[str]:
return [ label for label in self.__labels if label not in self.__ignore ]
@classmethod
def __flatten(cls, array: List[Any]) -> List[Any]:
if len(array) == 0:
return array
if isinstance(array[0], list):
return cls.__flatten(array[0]) + cls.__flatten(array[1:])
return array[:1] + cls.__flatten(array[1:])
class _RecursiveSerializer(_Serializer, Generic[T]):
def __init__(self, name: str, klass: Type[_Type], labels: List[str],
*, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = ["_PLACEHOLDER"]):
super().__init__(name, klass, labels, ignore = ignore)
*, serializers: Dict[str, _Serializer[Any]],
flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]):
super().__init__(name, klass, labels, flat=flat, ignore=ignore)
self.serializers = serializers
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> T:
serialization = dict(self._serialize(*values, skip=skip))
def parse(self, *values: Any) -> T:
serialization = dict(self._serialize(*values))
for key in serialization:
if key in self.serializers.keys():
serialization[key] = self.serializers[key].parse(*serialization[key], skip=skip)
serialization[key] = self.serializers[key].parse(*serialization[key])
return cast(T, self.klass(**serialization))
def generate_labeler_serializer(name: str, klass: Type[T], labels: List[str],
*, ignore: List[str] = [ "_PLACEHOLDER" ]) -> _Serializer[T]:
return _Serializer[T](name, klass, labels, ignore=ignore)
*, flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]
) -> _Serializer[T]:
return _Serializer[T](name, klass, labels, \
flat=flat, ignore=ignore)
def generate_recursive_serializer(name: str, klass: Type[T], labels: List[str],
*, serializers: Dict[str, _Serializer[Any]], ignore: List[str] = [ "_PLACEHOLDER" ]
*, serializers: Dict[str, _Serializer[Any]],
flat: bool = False, ignore: List[str] = [ "_PLACEHOLDER" ]
) -> _RecursiveSerializer[T]:
return _RecursiveSerializer[T](name, klass, labels, serializers=serializers, ignore=ignore)
return _RecursiveSerializer[T](name, klass, labels, \
serializers=serializers, flat=flat, ignore=ignore)

View File

@@ -22,7 +22,7 @@ class _Notification(_Serializer, Generic[T]):
self.serializer, self.is_iterable = serializer, is_iterable
def parse(self, *values: Any, skip: Optional[List[str]] = None) -> Notification[T]:
def parse(self, *values: Any) -> Notification[T]:
notification = cast(Notification[T], Notification(**dict(self._serialize(*values))))
if isinstance(self.serializer, _Serializer):
@@ -32,7 +32,7 @@ class _Notification(_Serializer, Generic[T]):
if len(data) == 1 and isinstance(data[0], list):
data = data[0]
notification.data = self.serializer.parse(*data, skip=skip)
else: notification.data = cast(T, [ self.serializer.parse(*sub_data, skip=skip) for sub_data in data ])
notification.data = self.serializer.parse(*data)
else: notification.data = cast(T, [ self.serializer.parse(*sub_data) for sub_data in data ])
return notification

View File

@@ -199,11 +199,10 @@ class RestAuthenticatedEndpoints(Middleware):
def get_symbol_margin_info(self, symbol: str) -> SymbolMarginInfo:
return serializers.SymbolMarginInfo \
.parse(*(self._post(f"auth/r/info/margin/{symbol}")[2]), \
skip=["symbol"])
.parse(*self._post(f"auth/r/info/margin/{symbol}"))
def get_all_symbols_margin_info(self) -> List[SymbolMarginInfo]:
return [ serializers.SymbolMarginInfo.parse(*([sub_data[1]] + sub_data[2])) \
return [ serializers.SymbolMarginInfo.parse(*sub_data) \
for sub_data in self._post("auth/r/info/margin/sym_all") ]
def get_positions(self) -> List[Position]:
@@ -225,6 +224,13 @@ class RestAuthenticatedEndpoints(Middleware):
.parse(*self._post("auth/w/position/increase", \
body={ "symbol": symbol, "amount": amount }))
def get_increase_position_info(self,
symbol: str,
amount: Union[Decimal, float, str]) -> PositionIncreaseInfo:
return serializers.PositionIncreaseInfo \
.parse(*self._post("auth/r/position/increase/info", \
body={ "symbol": symbol, "amount": amount }))
def get_positions_history(self,
*,
start: Optional[str] = None,

View File

@@ -49,10 +49,10 @@ class RestPublicEndpoints(Middleware):
return cast(List[FundingCurrencyTicker], data)
def get_t_ticker(self, pair: str) -> TradingPairTicker:
return serializers.TradingPairTicker.parse(*self._get(f"ticker/{pair}"), skip=["symbol"])
return serializers.TradingPairTicker.parse(*([pair] + self._get(f"ticker/{pair}")))
def get_f_ticker(self, currency: str) -> FundingCurrencyTicker:
return serializers.FundingCurrencyTicker.parse(*self._get(f"ticker/{currency}"), skip=["symbol"])
return serializers.FundingCurrencyTicker.parse(*([currency] + self._get(f"ticker/{currency}")))
def get_tickers_history(self,
symbols: List[str],
@@ -183,7 +183,7 @@ class RestPublicEndpoints(Middleware):
limit: Optional[int] = None) -> List[DerivativesStatus]:
params = { "sort": sort, "start": start, "end": end, "limit": limit }
data = self._get(f"status/{type}/{symbol}/hist", params=params)
return [ serializers.DerivativesStatus.parse(*sub_data, skip=[ "key" ]) for sub_data in data ]
return [ serializers.DerivativesStatus.parse(*([symbol] + sub_data)) for sub_data in data ]
def get_liquidations(self,
*,

View File

@@ -786,12 +786,15 @@ SymbolMarginInfo = generate_labeler_serializer(
name="SymbolMarginInfo",
klass=types.SymbolMarginInfo,
labels=[
"_PLACEHOLDER",
"symbol",
"tradable_balance",
"gross_balance",
"buy",
"sell"
]
],
flat=True
)
BaseMarginInfo = generate_labeler_serializer(
@@ -849,11 +852,15 @@ PositionIncreaseInfo = generate_labeler_serializer(
"_PLACEHOLDER",
"_PLACEHOLDER",
"funding_avail",
"_PLACEHOLDER",
"_PLACEHOLDER",
"funding_value",
"funding_required",
"funding_value_currency",
"funding_required_currency"
]
],
flat=True
)
PositionIncrease = generate_labeler_serializer(

View File

@@ -20,7 +20,7 @@ class PlatformStatus(_Type):
@dataclass
class TradingPairTicker(_Type):
symbol: Optional[str]
symbol: str
bid: float
bid_size: float
ask: float
@@ -34,7 +34,7 @@ class TradingPairTicker(_Type):
@dataclass
class FundingCurrencyTicker(_Type):
symbol: Optional[str]
symbol: str
frr: float
bid: float
bid_period: int
@@ -114,7 +114,7 @@ class Candle(_Type):
@dataclass
class DerivativesStatus(_Type):
key: Optional[str]
key: str
mts: int
deriv_price: float
spot_price: float
@@ -466,7 +466,7 @@ class Movement(_Type):
@dataclass
class SymbolMarginInfo(_Type):
symbol: Optional[str]
symbol: str
tradable_balance: float
gross_balance: float
buy: float

View File

@@ -21,9 +21,6 @@ class TestLabeler(unittest.TestCase):
self.assertEqual(serializer.parse(5, None, 65.0, None, "X"), Test(5, 65.0, "X"),
msg="_Serializer should produce the right result.")
self.assertEqual(serializer.parse(None, 65.0, None, "X", skip=[ "A" ]), Test(None, 65.0, "X"),
msg="_Serializer should produce the right result when skip parameter is given.")
self.assertListEqual(serializer.get_labels(), [ "A", "B", "C" ],
msg="_Serializer::get_labels() should return the right list of labels.")