(feat) refactor strategy with pydantic

This commit is contained in:
cardosofede
2023-07-20 15:48:35 +02:00
parent 43784ccd46
commit 2dc689b4a2

View File

@@ -1,32 +1,38 @@
import pandas as pd
import pandas_ta as ta
from pydantic import BaseModel, Field
from quants_lab.strategy.directional_strategy_base import DirectionalStrategyBase
from quants_lab.utils import data_management
class StatArbConfig(BaseModel):
exchange: str = Field(default="binance_perpetual")
trading_pair: str = Field(default="ETH-USDT")
target_trading_pair: str = Field(default="BTC-USDT")
interval: str = Field(default="1h")
lookback: int = Field(default=100, ge=2, le=10000)
z_score_long: float = Field(default=2, ge=0, le=5)
z_score_short: float = Field(default=-2, ge=-5, le=0)
class StatArb(DirectionalStrategyBase):
def __init__(self,
exchange="binance_perpetual",
trading_pair="DOGE-USDT",
target_trading_pair="BTC-USDT",
interval="1h",
periods=100,
deviation_threshold=1.1):
self.exchange = exchange
self.trading_pair = trading_pair
self.interval = interval
self.target_trading_pair = target_trading_pair
self.periods = periods
self.deviation_threshold = deviation_threshold
def __init__(self, config: StatArbConfig):
super().__init__(config)
self.exchange = config.exchange
self.trading_pair = config.trading_pair
self.target_trading_pair = config.target_trading_pair
self.interval = config.interval
self.lookback = config.lookback
self.z_score_long = config.z_score_long
self.z_score_short = config.z_score_short
def get_raw_data(self):
df = data_management.get_dataframe(
df = self.get_candles(
exchange=self.exchange,
trading_pair=self.trading_pair,
interval=self.interval,
)
df_target = data_management.get_dataframe(
df_target = self.get_candles(
exchange=self.exchange,
trading_pair=self.target_trading_pair,
interval=self.interval,
@@ -38,14 +44,14 @@ class StatArb(DirectionalStrategyBase):
df["pct_change_original"] = df["close"].pct_change()
df["pct_change_target"] = df["close_target"].pct_change()
df["spread"] = df["pct_change_target"] - df["pct_change_original"]
df["cum_spread"] = df["spread"].rolling(self.periods).sum()
df["z_score"] = ta.zscore(df["cum_spread"], length=self.periods)
df["cum_spread"] = df["spread"].rolling(self.lookback).sum()
df["z_score"] = ta.zscore(df["cum_spread"], length=self.lookback)
return df
def predict(self, df):
df["side"] = 0
short_condition = df["z_score"] < - self.deviation_threshold
long_condition = df["z_score"] > self.deviation_threshold
short_condition = df["z_score"] < - self.z_score_short
long_condition = df["z_score"] > self.z_score_long
df.loc[long_condition, "side"] = 1
df.loc[short_condition, "side"] = -1
return df