diff --git a/utils/file_templates.py b/utils/file_templates.py index 8815006..ff8be84 100644 --- a/utils/file_templates.py +++ b/utils/file_templates.py @@ -1,60 +1,51 @@ from typing import Dict -def directional_strategy_template(strategy_name: str, - exchange: str, - trading_pair: str, - interval: str) -> str: +def directional_strategy_template(strategy_cls_name: str) -> str: + strategy_config_cls_name = f"{strategy_cls_name}Config" + sma_config_text = "{self.config.sma_length}" return f"""import pandas_ta as ta -import pandas as pd -import numpy as np +from pydantic import BaseModel, Field from quants_lab.strategy.directional_strategy_base import DirectionalStrategyBase +class {strategy_config_cls_name}(BaseModel): + exchange: str = Field(default="binance_perpetual") + trading_pair: str = Field(default="ETH-USDT") + interval: str = Field(default="1h") + sma_length: int = Field(default=20, ge=10, le=200) + # ... Add more fields here -class {strategy_name}(DirectionalStrategyBase): - # Define the attributes of the strategy - def __init__(self, - exchange="{exchange}", - trading_pair="{trading_pair}", - interval="{interval}"): - self.exchange = exchange - self.trading_pair = trading_pair - self.interval = interval + +class {strategy_cls_name}(DirectionalStrategyBase[{strategy_config_cls_name}]): def get_raw_data(self): # The method get candles will search for the data in the folder data/candles # If the data is not there, you can use the candles downloader to get the data df = self.get_candles( - exchange=self.exchange, - trading_pair=self.trading_pair, - interval=self.interval, + exchange=self.config.exchange, + trading_pair=self.config.trading_pair, + interval=self.config.interval, ) return df - def add_indicators(self, df): - df.ta.sma(length=20, append=True) + def preprocessing(self, df): + df.ta.sma(length=self.config.sma_length, append=True) # ... Add more indicators here # ... Check https://github.com/twopirllc/pandas-ta#indicators-by-category for more indicators # ... Use help(ta.indicator_name) to get more info return df - def add_signals(self, df): - # ... Do your own logic - random_series = pd.Series(np.random.randint(low=0, high=101, size=100)) - random_series_2 = pd.Series(np.random.randint(low=0, high=101, size=100)) - random_thold = np.random.randint(low=45, high=65) - random_thold_2 = np.random.randint(low=45, high=65) - + def predict(self, df): # Generate long and short conditions - macd_long_cond = (random_series > random_thold) & (random_series_2 > random_thold_2) - macd_short_cond = (random_series < random_thold) & (random_series_2 > random_thold_2) + long_cond = (df['close'] > df[f'SMA_{sma_config_text}']) + short_cond = (df['close'] < df[f'SMA_{sma_config_text}']) # Choose side df['side'] = 0 - df.loc[macd_long_cond, 'side'] = 1 - df.loc[macd_short_cond, 'side'] = -1 + df.loc[long_cond, 'side'] = 1 + df.loc[short_cond, 'side'] = -1 return df """ @@ -122,7 +113,6 @@ def objective(trial): trial.set_user_attr("avg_trading_time_in_hours", strategy_analysis.avg_trading_time_in_minutes() / 60) return strategy_analysis.net_profit_pct() except Exception as e: - # TODO: Log error traceback.print_exc() raise TrialPruned() """