diff --git a/bots/controllers/directional_trading/bollinger_v1.py b/bots/controllers/directional_trading/bollinger_v1.py index 8f1e92e..edaf51a 100644 --- a/bots/controllers/directional_trading/bollinger_v1.py +++ b/bots/controllers/directional_trading/bollinger_v1.py @@ -8,6 +8,7 @@ DirectionalTradingControllerConfigBase, ) from pydantic import Field, validator +import pandas as pd class BollingerV1ControllerConfig(DirectionalTradingControllerConfigBase): @@ -68,23 +69,37 @@ def __init__(self, config: BollingerV1ControllerConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) async def update_processed_data(self): + print("Starting update_processed_data in BollingerV1Controller") df = self.market_data_provider.get_candles_df(connector_name=self.config.candles_connector, trading_pair=self.config.candles_trading_pair, interval=self.config.interval, max_records=self.max_records) - # Add indicators - df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std, append=True) - bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] + print(f"Got candles DataFrame with shape: {df.shape}") + print(f"DataFrame columns: {df.columns.tolist()}") + + # Add indicators using pandas_ta + print(f"Calculating Bollinger Bands with length={self.config.bb_length}, std={self.config.bb_std}") + bbands = df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std) + print(f"Bollinger Bands columns: {bbands.columns.tolist() if bbands is not None else 'None'}") + + df = pd.concat([df, bbands], axis=1) + print(f"Combined DataFrame columns: {df.columns.tolist()}") # Generate signal - long_condition = bbp < self.config.bb_long_threshold - short_condition = bbp > self.config.bb_short_threshold + bbp_col = f"BBP_{self.config.bb_length}_{self.config.bb_std}" + print(f"Looking for BBP column: {bbp_col}") + print(f"BBP values: {df[bbp_col].head() if bbp_col in df.columns else 'Column not found'}") + + long_condition = df[bbp_col] < self.config.bb_long_threshold + 
short_condition = df[bbp_col] > self.config.bb_short_threshold # Generate signal df["signal"] = 0 df.loc[long_condition, "signal"] = 1 df.loc[short_condition, "signal"] = -1 + print(f"Generated signals: {df['signal'].value_counts()}") # Update processed data self.processed_data["signal"] = df["signal"].iloc[-1] self.processed_data["features"] = df + print("Finished update_processed_data") diff --git a/environment.yml b/environment.yml index c1d6a4b..f5c8f3f 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - cython - pip - pip: - - robotter-hummingbot==20241016 + - hummingbot==20241227 - numpy==1.26.4 - git+https://github.com/felixfontein/docker-py - python-dotenv diff --git a/routers/backtest.py b/routers/backtest.py index c2ab034..dceb3ac 100644 --- a/routers/backtest.py +++ b/routers/backtest.py @@ -1,33 +1,9 @@ from fastapi import APIRouter, HTTPException, status -from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory -from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase -from hummingbot.strategy_v2.backtesting.controllers_backtesting.directional_trading_backtesting import ( - DirectionalTradingBacktesting, -) -from hummingbot.strategy_v2.backtesting.controllers_backtesting.market_making_backtesting import MarketMakingBacktesting - -from config import CONTROLLERS_MODULE, CONTROLLERS_PATH -from routers.backtest_models import BacktestResponse, BacktestResults, BacktestingConfig, ExecutorInfo, ProcessedData -from routers.strategies_models import StrategyError +from routers.backtest_models import BacktestResponse, BacktestingConfig +from services.backtesting_service import BacktestingService, BacktestConfigError, BacktestEngineError, BacktestError router = APIRouter(tags=["Market Backtesting"]) -candles_factory = CandlesFactory() -directional_trading_backtesting = DirectionalTradingBacktesting() -market_making_backtesting = MarketMakingBacktesting() - -BACKTESTING_ENGINES 
= { - "directional_trading": directional_trading_backtesting, - "market_making": market_making_backtesting -} - -class BacktestError(StrategyError): - """Base class for backtesting-related errors""" - -class BacktestConfigError(BacktestError): - """Raised when there's an error in the backtesting configuration""" - -class BacktestEngineError(BacktestError): - """Raised when there's an error during backtesting execution""" +backtesting_service = BacktestingService() responses = { 400: { @@ -46,6 +22,10 @@ class BacktestEngineError(BacktestError): "invalid_engine": { "summary": "Invalid Engine Type", "value": {"detail": "Backtesting engine for controller type 'unknown' not found. Available types: ['directional_trading', 'market_making']"} + }, + "invalid_strategy": { + "summary": "Invalid Strategy", + "value": {"detail": "Strategy 'unknown_strategy' not found. Use GET /strategies to see available strategies."} } } } @@ -118,20 +98,10 @@ class BacktestEngineError(BacktestError): 3. Simulates trading with the strategy 4. 
Analyzes performance and generates statistics - Supports two types of backtesting engines: - - Directional Trading: For trend-following and momentum strategies - - Market Making: For liquidity provision strategies - - Returns comprehensive results including: - - Executor statistics (trades, win rate, P&L) - - Processed market data and indicators - - Overall performance metrics: - - Total trades executed - - Win rate - - Profit/Loss - - Sharpe ratio - - Maximum drawdown - - Return on Investment (ROI) + Required Configuration: + - strategy_id: ID of the strategy to backtest (get available strategies from GET /strategies) + - trading_pair: The trading pair to backtest on (e.g., "BTC-USDT") + - Other parameters specific to the chosen strategy Time range requirements: - start_time must be before end_time @@ -141,64 +111,7 @@ class BacktestEngineError(BacktestError): ) async def run_backtesting(backtesting_config: BacktestingConfig) -> BacktestResponse: try: - # Load and validate controller config - try: - if isinstance(backtesting_config.config, str): - controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( - config_path=backtesting_config.config, - controllers_conf_dir_path=CONTROLLERS_PATH, - controllers_module=CONTROLLERS_MODULE - ) - else: - controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( - config_data=backtesting_config.config, - controllers_module=CONTROLLERS_MODULE - ) - except Exception as e: - raise BacktestConfigError(f"Invalid controller configuration: {str(e)}") - - # Get and validate backtesting engine - backtesting_engine = BACKTESTING_ENGINES.get(controller_config.controller_type) - if not backtesting_engine: - raise BacktestConfigError( - f"Backtesting engine for controller type {controller_config.controller_type} not found. 
" - f"Available types: {list(BACKTESTING_ENGINES.keys())}" - ) - - # Validate time range - if backtesting_config.end_time <= backtesting_config.start_time: - raise BacktestConfigError( - f"Invalid time range: end_time ({backtesting_config.end_time}) must be greater than " - f"start_time ({backtesting_config.start_time})" - ) - - try: - # Run backtesting - backtesting_results = await backtesting_engine.run_backtesting( - controller_config=controller_config, - trade_cost=backtesting_config.trade_cost, - start=int(backtesting_config.start_time), - end=int(backtesting_config.end_time), - backtesting_resolution=backtesting_config.backtesting_resolution - ) - except Exception as e: - raise BacktestEngineError(f"Error during backtesting execution: {str(e)}") - - try: - # Process results - processed_data = backtesting_results["processed_data"]["features"].fillna(0).to_dict() - executors_info = [ExecutorInfo(**e.to_dict()) for e in backtesting_results["executors"]] - results = backtesting_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - - return BacktestResponse( - executors=executors_info, - processed_data=ProcessedData(features=processed_data), - results=BacktestResults(**results) - ) - except Exception as e: - raise BacktestError(f"Error processing backtesting results: {str(e)}") - + return await backtesting_service.run_backtesting(backtesting_config) except BacktestConfigError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except BacktestEngineError as e: @@ -208,5 +121,29 @@ async def run_backtesting(backtesting_config: BacktestingConfig) -> BacktestResp except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Unexpected error during backtesting: {str(e)}" - ) \ No newline at end of file + detail=f"An unexpected error occurred during backtesting: {str(e)}" + ) + +@router.get( + "/backtest/engines", + 
response_model=dict, + summary="Get Available Backtesting Engines", + description="Returns a list of available backtesting engines and their types." +) +def get_available_engines(): + return backtesting_service.get_available_engines() + +@router.get( + "/backtest/engines/{engine_type}/config", + response_model=dict, + summary="Get Engine Configuration Schema", + description="Returns the configuration schema for a specific backtesting engine type." +) +def get_engine_config_schema(engine_type: str): + schema = backtesting_service.get_engine_config_schema(engine_type) + if not schema: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Engine type '{engine_type}' not found" + ) + return schema \ No newline at end of file diff --git a/routers/backtest_models.py b/routers/backtest_models.py index 08af219..555c11a 100644 --- a/routers/backtest_models.py +++ b/routers/backtest_models.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Union +from typing import List, Dict, Union from pydantic import BaseModel, Field from decimal import Decimal @@ -10,6 +10,7 @@ class BacktestingConfig(BaseModel): config: Union[Dict, str] class ExecutorInfo(BaseModel): + id: str level_id: str timestamp: int connector_name: str @@ -19,6 +20,9 @@ class ExecutorInfo(BaseModel): side: str leverage: int position_mode: str + trades: int + win_rate: float + profit_loss: Decimal class ProcessedData(BaseModel): features: Dict[str, List[Union[float, int, str]]] diff --git a/routers/strategies.py b/routers/strategies.py index 2719f6f..e221ed5 100644 --- a/routers/strategies.py +++ b/routers/strategies.py @@ -1,8 +1,7 @@ -from typing import Dict, Any +from typing import Dict, Any, Optional from fastapi import APIRouter, HTTPException, status from fastapi import FastAPI from contextlib import asynccontextmanager -from fastapi.responses import JSONResponse from services.libert_ai_service import LibertAIService from routers.strategies_models import ( @@ -140,11 +139,17 @@ async 
def lifespan(app: FastAPI): - Directional Trading: Strategies that follow market trends - Market Making: Strategies that provide market liquidity - Generic: Other types of strategies (e.g., arbitrage) + + Optional query parameter: + - strategy_type: Filter strategies by type (directional_trading, market_making, generic) """ ) -async def get_strategies() -> Dict[str, StrategyConfig]: +async def get_strategies(strategy_type: Optional[StrategyType] = None) -> Dict[str, StrategyConfig]: """Get all available strategies and their configurations.""" try: + if strategy_type: + strategies = StrategyRegistry.get_strategies_by_type(strategy_type) + return {s.id: s for s in strategies} return StrategyRegistry.get_all_strategies() except StrategyError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) @@ -154,6 +159,77 @@ async def get_strategies() -> Dict[str, StrategyConfig]: detail=f"Internal server error while fetching strategies: {str(e)}" ) +@router.get( + "/strategies/{strategy_id}", + response_model=StrategyConfig, + responses={ + 200: { + "description": "Successfully retrieved strategy details", + "content": { + "application/json": { + "example": { + "mapping": { + "id": "bollinger_v1", + "config_class": "BollingerConfig", + "module_path": "bots.controllers.directional_trading.bollinger_v1", + "strategy_type": "directional_trading", + "display_name": "Bollinger Bands Strategy", + "description": "Buys when price is low and sells when price is high based on Bollinger Bands." 
+ }, + "parameters": { + "stop_loss": { + "name": "stop_loss", + "type": "Decimal", + "required": True, + "default": "0.03", + "display_name": "Stop Loss", + "description": "Stop loss percentage", + "group": "Risk Management", + "is_advanced": False, + "constraints": { + "min_value": 0, + "max_value": 0.1 + } + } + } + } + } + } + }, + **responses + }, + summary="Get Strategy Details", + description=""" + Returns detailed information about a specific strategy, including all its parameters and configuration options. + + Use this endpoint to: + 1. Get the complete list of parameters needed for the strategy + 2. Understand parameter constraints (min/max values, valid options) + 3. See default values and parameter descriptions + 4. Determine which parameters are required vs optional + + This information is essential for: + - Configuring a strategy for backtesting + - Understanding parameter relationships + - Setting up proper risk management + - Optimizing strategy performance + """ +) +async def get_strategy_details(strategy_id: str) -> StrategyConfig: + """Get detailed information about a specific strategy""" + try: + return StrategyRegistry.get_strategy(strategy_id) + except StrategyNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Strategy not found: {str(e)}" + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error fetching strategy details: {str(e)}" + ) + @router.post( "/strategies/suggest-parameters", response_model=ParameterSuggestionResponse, diff --git a/routers/strategies_models.py b/routers/strategies_models.py index 070ad7a..ea936ad 100644 --- a/routers/strategies_models.py +++ b/routers/strategies_models.py @@ -88,15 +88,6 @@ class StrategyParameter(BaseModel): # Type flags (for backward compatibility and specific handling) parameter_type: Optional[ParameterType] = None -class Strategy(BaseModel): - id: str - name: str - description: str - type: 
StrategyType - module_path: str - config_class: str - parameters: Dict[str, StrategyParameter] - class StrategyMapping(BaseModel): """Maps a strategy ID to its implementation details""" id: str # e.g., "supertrend_v1" @@ -105,6 +96,7 @@ class StrategyMapping(BaseModel): strategy_type: StrategyType display_name: str # e.g., "Supertrend V1" description: str = "" + parameters: Dict[str, StrategyParameter] = {} class StrategyConfig(BaseModel): """Complete strategy configuration including metadata and parameters""" @@ -181,7 +173,7 @@ def get_strategy_display_info() -> Dict[str, Dict[str, str]]: class StrategyRegistry: """Central registry for all trading strategies""" - _cache: Dict[str, Strategy] = {} + _cache: Dict[str, StrategyMapping] = {} @classmethod def _ensure_cache_loaded(cls): @@ -189,13 +181,13 @@ def _ensure_cache_loaded(cls): cls._cache = discover_strategies() @classmethod - def get_all_strategies(cls) -> Dict[str, Strategy]: + def get_all_strategies(cls) -> Dict[str, StrategyMapping]: """Get all available strategies with their configurations""" cls._ensure_cache_loaded() return cls._cache @classmethod - def get_strategy(cls, strategy_id: str) -> Optional[Strategy]: + def get_strategy(cls, strategy_id: str) -> Optional[StrategyMapping]: """Get a specific strategy by ID""" cls._ensure_cache_loaded() strategy = cls._cache.get(strategy_id) @@ -204,10 +196,10 @@ def get_strategy(cls, strategy_id: str) -> Optional[Strategy]: return strategy @classmethod - def get_strategies_by_type(cls, strategy_type: StrategyType) -> List[Strategy]: + def get_strategies_by_type(cls, strategy_type: StrategyType) -> List[StrategyMapping]: """Get all strategies of a specific type""" cls._ensure_cache_loaded() - return [s for s in cls._cache.values() if s.type == strategy_type] + return [s for s in cls._cache.values() if s.strategy_type == strategy_type] def convert_to_strategy_parameter(name: str, field: ModelField) -> StrategyParameter: """Convert a model field to a strategy 
parameter""" @@ -332,7 +324,7 @@ def generate_strategy_mapping(module_path: str, config_class: Any) -> StrategyMa ) @functools.lru_cache(maxsize=1) -def discover_strategies() -> Dict[str, Strategy]: +def discover_strategies() -> Dict[str, StrategyMapping]: """Discover and load all available strategies""" strategies = {} controllers_dir = "bots/controllers" @@ -372,11 +364,11 @@ def discover_strategies() -> Dict[str, Strategy]: parameters[field_name] = param # Create strategy - strategies[strategy_id] = Strategy( + strategies[strategy_id] = StrategyMapping( id=strategy_id, - name=display_info.get("pretty_name", " ".join(word.capitalize() for word in strategy_id.split("_"))), + display_name=display_info.get("pretty_name", " ".join(word.capitalize() for word in strategy_id.split("_"))), description=display_info.get("description", obj.__doc__ or ""), - type=strategy_type, + strategy_type=strategy_type, module_path=module_path, config_class=obj.__name__, parameters=parameters diff --git a/services/backtesting_service.py b/services/backtesting_service.py new file mode 100644 index 0000000..793edd5 --- /dev/null +++ b/services/backtesting_service.py @@ -0,0 +1,281 @@ +from typing import Dict, Any, Optional, List +from decimal import Decimal +from routers.strategies_models import StrategyRegistry, StrategyNotFoundError +from routers.backtest_models import BacktestingConfig, BacktestResponse +from hummingbot.strategy_v2.backtesting.controllers_backtesting.directional_trading_backtesting 
import DirectionalTradingBacktesting +from hummingbot.strategy_v2.backtesting.controllers_backtesting.market_making_backtesting import MarketMakingBacktesting + +class BacktestError(Exception): + """Base class for backtesting errors""" + pass + +class BacktestingService: + """Service for running backtesting operations""" + + def __init__(self): + self.directional_trading_backtesting = DirectionalTradingBacktesting() + self.market_making_backtesting = MarketMakingBacktesting() + self.backtesting_engines = { + "directional_trading": self.directional_trading_backtesting, + "market_making": self.market_making_backtesting + } + + def validate_time_range(self, start_time: int, end_time: int) -> None: + """Validate the time range for backtesting""" + if start_time >= end_time: + raise BacktestError("Invalid time range: start time must be before end time") + + def transform_strategy_config(self, config: Dict[str, Any]) -> Any: + """Transform API strategy config to controller config""" + controller_name = config.get("controller_name") + if not controller_name: + raise BacktestError("Missing controller_name in configuration") + + # Get strategy info from registry + strategy = StrategyRegistry.get_strategy(controller_name) + + # Keep all configuration parameters except controller_name + filtered_config = {k: v for k, v in config.items() if k != "controller_name"} + + # Add required controller configuration + controller_config = { + "controller_name": controller_name, + "controller_type": strategy.strategy_type.value, + "id": None, # Required by base class + "connector_name": filtered_config.get("connector_name"), + "trading_pair": filtered_config.get("trading_pair"), + "leverage": filtered_config.get("leverage", 1), + "position_mode": filtered_config.get("position_mode", "ONEWAY"), + "stop_loss": filtered_config.get("stop_loss"), + "take_profit": filtered_config.get("take_profit"), + "time_limit": filtered_config.get("time_limit", 60 * 60 * 24 * 7), # 1 week default + 
"trailing_stop": filtered_config.get("trailing_stop"), + **filtered_config + } + + # Import and configure the strategy + module_path = f"bots.controllers.{strategy.strategy_type.value}.{controller_name}" + try: + module = __import__(module_path, fromlist=["*"]) + config_class = next( + (getattr(module, name) for name in dir(module) + if name.endswith(("ControllerConfig", "Config"))), + None + ) + if not config_class: + raise BacktestError(f"Could not find config class in module {module_path}") + + # Add candles config if needed + if "candles_connector" in filtered_config and "candles_trading_pair" in filtered_config: + from hummingbot.data_feed.candles_feed.data_types import CandlesConfig + controller_config["candles_config"] = [CandlesConfig( + connector=filtered_config["candles_connector"], + trading_pair=filtered_config["candles_trading_pair"], + interval=filtered_config.get("interval", "1m"), + max_records=500 + )] + + return config_class(**controller_config) + + except ImportError: + raise BacktestError(f"Could not import strategy module: {module_path}") + except Exception as e: + raise BacktestError(f"Error creating strategy config: {str(e)}") + + def get_available_engines(self) -> Dict[str, str]: + """Get available backtesting engines""" + return { + engine_type: engine.__class__.__name__ + for engine_type, engine in self.backtesting_engines.items() + } + + def get_engine_config_schema(self, engine_type: str) -> Optional[Dict[str, Any]]: + """Get configuration schema for a specific backtesting engine type""" + engine = self.backtesting_engines.get(engine_type) + if not engine: + return None + + strategies = StrategyRegistry.get_strategies_by_type(engine_type) + if not strategies: + return None + + schema = { + "type": "object", + "properties": { + "controller_name": { + "type": "string", + "description": "Name of the strategy controller to use", + "enum": [s.id for s in strategies] + } + }, + "required": ["controller_name"], + "additionalProperties": True + 
} + + try: + # Import the base config classes + from hummingbot.strategy_v2.controllers.controller_base import ControllerConfigBase + from hummingbot.strategy_v2.controllers.directional_trading_controller_base import DirectionalTradingControllerConfigBase + from hummingbot.strategy_v2.controllers.market_making_controller_base import MarketMakingControllerConfigBase + + # Get the appropriate base class for this engine type + base_config_class = { + "directional_trading": DirectionalTradingControllerConfigBase, + "market_making": MarketMakingControllerConfigBase, + "generic": ControllerConfigBase + }.get(engine_type) + + if not base_config_class: + print(f"No base config class found for engine type: {engine_type}") + return schema + + print(f"\nBase config class for {engine_type}: {base_config_class.__name__}") + + # Add fields from base classes first + def add_fields_from_class(cls): + print(f"\nProcessing class: {cls.__name__}") + if not hasattr(cls, "__fields__"): + print(f"No __fields__ in {cls.__name__}") + return + + fields = cls.__fields__ + print(f"Fields in {cls.__name__}: {list(fields.keys())}") + + for field_name, field in fields.items(): + if field_name in ["controller_type", "id"]: # Allow controller_name through + continue + + field_info = field.field_info + field_schema = { + "type": self._get_json_schema_type(field.type_), + "description": field_info.description or field_name + } + + if field.default is not None and not callable(field.default): + field_schema["default"] = field.default + + if hasattr(field_info, "gt"): + field_schema["minimum"] = field_info.gt + if hasattr(field_info, "lt"): + field_schema["maximum"] = field_info.lt + + schema["properties"][field_name] = field_schema + if field.required: + if field_name not in schema["required"]: + schema["required"].append(field_name) + + # Process all classes in the MRO chain except object + for cls in reversed(base_config_class.__mro__[:-1]): + add_fields_from_class(cls) + + # Get example strategy 
to determine additional fields + example_strategy = strategies[0] + module_path = f"bots.controllers.{engine_type}.{example_strategy.id}" + + print(f"\nTrying to import module: {module_path}") + module = __import__(module_path, fromlist=["*"]) + + # Find the strategy-specific config class + strategy_specific_pattern = f"{example_strategy.id.title().replace('_', '')}Config" + print(f"Looking for strategy-specific config class: {strategy_specific_pattern}") + + config_class = None + for name in dir(module): + if name.endswith(("ControllerConfig", "Config")): + cls = getattr(module, name) + print(f"Found potential config class: {name}") + # Skip CandlesConfig and classes not inheriting from ControllerConfigBase + if name == "CandlesConfig" or not issubclass(cls, ControllerConfigBase): + print(f"Skipping {name} - not a valid controller config") + continue + # Prefer strategy-specific config if found + if name == strategy_specific_pattern: + config_class = cls + break + # Otherwise take the first valid config class + if not config_class: + config_class = cls + + if config_class: + print(f"\nFound strategy config class: {config_class.__name__}") + # Process all classes in the strategy config's MRO chain except object + for cls in reversed(config_class.__mro__[:-1]): + if cls not in base_config_class.__mro__: # Skip classes we've already processed + add_fields_from_class(cls) + + print(f"\nFinal schema properties: {list(schema['properties'].keys())}") + print(f"Final required fields: {schema['required']}") + return schema + + except Exception as e: + print(f"Error generating schema: {str(e)}") + import traceback + traceback.print_exc() + return schema + + def _get_json_schema_type(self, python_type: type) -> str: + """Convert Python type to JSON schema type""" + type_map = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + list: "array", + dict: "object", + Decimal: "number" + } + return type_map.get(python_type, "string") + + async def 
run_backtesting(self, config: BacktestingConfig) -> BacktestResponse: + """Run backtesting with the given configuration""" + self.validate_time_range(config.start_time, config.end_time) + + strategy = StrategyRegistry.get_strategy(config.config["controller_name"]) + transformed_config = self.transform_strategy_config(config.config) + + engine_type = strategy.strategy_type.value + engine = self.backtesting_engines.get(engine_type) + if not engine: + raise BacktestError(f"Backtesting engine '{engine_type}' not found") + + results = await engine.run_backtesting( + controller_config=transformed_config, + start=config.start_time, + end=config.end_time, + backtesting_resolution=config.backtesting_resolution, + trade_cost=config.trade_cost + ) + + if not isinstance(results, dict): + raise BacktestError("Invalid results format returned from backtesting engine") + + # Process results + processed_data = results.get("processed_data", {}) + features = processed_data.get("features", {}) + + if hasattr(features, "to_dict"): + features = {col: features[col].tolist() for col in features.columns} + + # Prepare results with defaults + results_data = results.get("results", {}) + default_results = { + "total_pnl": Decimal("0"), + "total_trades": 0, + "win_rate": 0.0, + "profit_loss_ratio": 0.0, + "max_drawdown": 0.0, + "start_timestamp": config.start_time, + "end_timestamp": config.end_time + } + + if results_data: + for key, default in default_results.items(): + results_data.setdefault(key, default) + else: + results_data = default_results + + return BacktestResponse( + executors=results.get("executors", []), + results=results_data, + processed_data={"features": features} + ) \ No newline at end of file diff --git a/services/libert_ai_service.py b/services/libert_ai_service.py index cbe2a8d..92a00c7 100644 --- a/services/libert_ai_service.py +++ b/services/libert_ai_service.py @@ -3,7 +3,7 @@ import aiohttp import inspect import importlib -from typing import Dict, Any, List, Optional 
+from typing import Dict, Any, List, Optional, Protocol from routers.strategies_models import ( ParameterSuggestion, StrategyConfig, @@ -13,76 +13,95 @@ logger = logging.getLogger(__name__) +class AIClient(Protocol): + """Protocol for AI client implementations""" + async def initialize_system_context(self, prompt: str) -> Dict[str, Any]: + """Initialize system context with the given prompt""" + ... + + async def initialize_strategy_context(self, prompt: str, slot_id: int) -> Dict[str, Any]: + """Initialize strategy context with the given prompt""" + ... + + async def get_suggestions(self, prompt: str, slot_id: int) -> Dict[str, Any]: + """Get suggestions from the AI model""" + ... + +class LibertAIClient: + """Default implementation of AIClient using Libert API""" + def __init__(self, api_url: Optional[str] = None): + self.api_url = api_url or "https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion" + self.session: Optional[aiohttp.ClientSession] = None + + async def _ensure_session(self): + """Ensure we have an active session""" + if self.session is None: + self.session = aiohttp.ClientSession() + return self.session + + async def _make_api_request(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + """Make the actual API request. 
This is the only method that should make HTTP calls.""" + session = await self._ensure_session() + async with session.post( + url, + headers={"Content-Type": "application/json"}, + json=payload + ) as response: + if response.status != 200: + raise ValueError(f"API request failed with status {response.status}") + return await response.json() + + def _build_request_payload( + self, + prompt: str, + slot_id: Optional[int] = None, + parent_slot_id: Optional[int] = None + ) -> Dict[str, Any]: + """Build the request payload with all necessary parameters.""" + payload = { + "prompt": prompt, + "temperature": 0.9, + "top_p": 1, + "top_k": 40, + "n": 1, + "n_predict": 100, + "stop": ["<|im_end|>"] + } + if slot_id is not None: + payload["slot_id"] = slot_id + if parent_slot_id is not None: + payload["parent_slot_id"] = parent_slot_id + return payload + + async def _make_request(self, prompt: str, slot_id: Optional[int] = None, parent_slot_id: Optional[int] = None) -> Dict[str, Any]: + """Make a request to the AI API""" + payload = self._build_request_payload(prompt, slot_id, parent_slot_id) + return await self._make_api_request(self.api_url, payload) + + async def initialize_system_context(self, prompt: str) -> Dict[str, Any]: + """Initialize system context with the given prompt""" + return await self._make_request(prompt) + + async def initialize_strategy_context(self, prompt: str, slot_id: int) -> Dict[str, Any]: + """Initialize strategy context with the given prompt""" + return await self._make_request(prompt, slot_id=slot_id, parent_slot_id=-1) + + async def get_suggestions(self, prompt: str, slot_id: int) -> Dict[str, Any]: + """Get suggestions from the AI model""" + return await self._make_request(prompt, slot_id=slot_id) + + async def close(self): + """Close the client session""" + if self.session: + await self.session.close() + self.session = None + class LibertAIService: - def __init__(self): - # Hermes 2 pro - self.api_url = 
"https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion" - + def __init__(self, ai_client: Optional[AIClient] = None): + self.ai_client = ai_client or LibertAIClient() self.strategy_slot_map: Dict[str, int] = {} # Maps strategy IDs to their slot IDs self.next_slot_id = 0 - - async def initialize_contexts(self, strategies: Dict[str, StrategyConfig]): - """Initialize context slots for system prompt and each strategy.""" - try: - logger.info("Starting context initialization...") - - # Initialize system prompt in slot -1 - logger.info("Initializing system context...") - await self._initialize_system_context() - - # Initialize each strategy's context - for strategy_id, strategy_config in strategies.items(): - logger.info(f"Initializing context for strategy: {strategy_id}") - slot_id = self.next_slot_id - self.strategy_slot_map[strategy_id] = slot_id - self.next_slot_id += 1 - - # Load strategy implementation code - strategy_code = await self._load_strategy_code(strategy_config.mapping) - logger.info(f"Loaded strategy code for {strategy_id}, code length: {len(strategy_code)}") - - await self._initialize_strategy_context( - strategy_mapping=strategy_config.mapping, - strategy_config=strategy_config.parameters, - strategy_code=strategy_code, - slot_id=slot_id - ) - - logger.info(f"Context initialization complete. 
Strategy slot map: {self.strategy_slot_map}") - - except Exception as e: - logger.error(f"Error initializing contexts: {str(e)}") - raise - - async def _load_strategy_code(self, mapping: StrategyMapping) -> str: - """Load the strategy implementation code using the strategy mapping.""" - try: - # Import the module using the mapping's module path - module = importlib.import_module(mapping.module_path) - - # Get all classes in the module - strategy_classes = inspect.getmembers( - module, - lambda member: ( - inspect.isclass(member) - and member.__module__ == module.__name__ - and not member.__name__.endswith('Config') - ) - ) - - if not strategy_classes: - raise ValueError(f"No strategy class found in {mapping.module_path}") - - # Get the source code of the strategy class - strategy_class = strategy_classes[0][1] # Take the first class - source_code = inspect.getsource(strategy_class) - - return source_code - - except Exception as e: - logger.error(f"Error loading strategy code for {mapping.id}: {str(e)}") - return f"# Strategy implementation code not found for {mapping.id}" - + async def _initialize_system_context(self): """Initialize the system prompt in slot -1.""" system_prompt = """<|im_start|>system @@ -101,28 +120,14 @@ async def _initialize_system_context(self): <|im_end|>""" try: - async with aiohttp.ClientSession() as session: - await session.post( - self.api_url, - headers={"Content-Type": "application/json"}, - json={ - "prompt": system_prompt, - "temperature": 0.9, - "top_p": 1, - "top_k": 40, - "n": 1, - "n_predict": 100, - "stop": ["<|im_end|>"] - } - ) + await self.ai_client.initialize_system_context(system_prompt) except Exception as e: - print(f"ERROR: Error initializing system context: {str(e)}") + logger.error(f"Error initializing system context: {str(e)}") raise - + async def _initialize_strategy_context( self, - strategy_mapping: StrategyMapping, - strategy_config: Dict[str, Any], + strategy: StrategyMapping, strategy_code: str, slot_id: int ): 
@@ -133,21 +138,20 @@ async def _initialize_strategy_context( "name": param.name, "group": param.group, "type": param.type, - "prompt": param.prompt, "default": str(param.default) if param.default is not None else None, "required": param.required, - "min_value": str(param.min_value) if param.min_value is not None else None, - "max_value": str(param.max_value) if param.max_value is not None else None, + "min_value": str(param.constraints.min_value) if param.constraints and param.constraints.min_value is not None else None, + "max_value": str(param.constraints.max_value) if param.constraints and param.constraints.max_value is not None else None, "is_advanced": param.is_advanced, "display_type": param.display_type } - for name, param in strategy_config.items() + for name, param in strategy.parameters.items() } strategy_context = f"""<|im_start|>user -Trading Strategy: {strategy_mapping.display_name} -Type: {strategy_mapping.strategy_type.value} -Description: {strategy_mapping.description} +Trading Strategy: {strategy.display_name} +Type: {strategy.strategy_type.value} +Description: {strategy.description} Strategy Configuration Schema: {json.dumps(serializable_config, indent=2)} @@ -161,25 +165,71 @@ async def _initialize_strategy_context( <|im_end|>""" try: - async with aiohttp.ClientSession() as session: - await session.post( - self.api_url, - headers={"Content-Type": "application/json"}, - json={ - "prompt": strategy_context, - "temperature": 0.9, - "top_p": 1, - "top_k": 40, - "n": 1, - "n_predict": 100, - "stop": ["<|im_end|>"], - "slot_id": slot_id, - "parent_slot_id": -1, - } + await self.ai_client.initialize_strategy_context(strategy_context, slot_id) + except Exception as e: + logger.error(f"Error initializing strategy context for {strategy.id}: {str(e)}") + raise + + async def initialize_contexts(self, strategies: Dict[str, StrategyMapping]): + """Initialize context slots for system prompt and each strategy.""" + try: + logger.info("Starting context 
initialization...") + + # Initialize system prompt in slot -1 + logger.info("Initializing system context...") + await self._initialize_system_context() + + # Initialize each strategy's context + for strategy_id, strategy in strategies.items(): + logger.info(f"Initializing context for strategy: {strategy_id}") + slot_id = self.next_slot_id + self.strategy_slot_map[strategy_id] = slot_id + self.next_slot_id += 1 + + # Load strategy implementation code + strategy_code = await self._load_strategy_code(strategy) + logger.info(f"Loaded strategy code for {strategy_id}, code length: {len(strategy_code)}") + + await self._initialize_strategy_context( + strategy=strategy, + strategy_code=strategy_code, + slot_id=slot_id ) + + logger.info(f"Context initialization complete. Strategy slot map: {self.strategy_slot_map}") + except Exception as e: - print(f"ERROR: Error initializing strategy context for {strategy_mapping.id}: {str(e)}") + logger.error(f"Error initializing contexts: {str(e)}") raise + + async def _load_strategy_code(self, strategy: StrategyMapping) -> str: + """Load the strategy implementation code using the strategy mapping.""" + try: + # Import the module using the mapping's module path + module = importlib.import_module(strategy.module_path) + + # Get all classes in the module + strategy_classes = inspect.getmembers( + module, + lambda member: ( + inspect.isclass(member) + and member.__module__ == module.__name__ + and not member.__name__.endswith('Config') + ) + ) + + if not strategy_classes: + raise ValueError(f"No strategy class found in {strategy.module_path}") + + # Get the source code of the strategy class + strategy_class = strategy_classes[0][1] # Take the first class + source_code = inspect.getsource(strategy_class) + + return source_code + + except Exception as e: + logger.error(f"Error loading strategy code for {strategy.id}: {str(e)}") + return f"# Strategy implementation code not found for {strategy.id}" async def get_parameter_suggestions( self, @@ 
-196,16 +246,16 @@ async def get_parameter_suggestions( provided_params: Parameters already provided by the user requested_params: Optional list of specific parameters to get suggestions for """ - print("\n=== Getting Parameter Suggestions ===") - print(f"Strategy ID: {strategy_id}") - print(f"Provided parameters: {json.dumps(provided_params, indent=2)}") - print(f"Requested parameters: {requested_params}") + logger.info("\n=== Getting Parameter Suggestions ===") + logger.info(f"Strategy ID: {strategy_id}") + logger.info(f"Provided parameters: {json.dumps(provided_params, indent=2)}") + logger.info(f"Requested parameters: {requested_params}") # Get strategy configuration strategies = discover_strategies() strategy = strategies.get(strategy_id) if not strategy: - print(f"ERROR: No strategy found with ID {strategy_id}") + logger.error(f"No strategy found with ID {strategy_id}") return [] # Identify missing required parameters and optional parameters @@ -224,13 +274,13 @@ async def get_parameter_suggestions( else: optional_params.append(param_name) - print(f"Missing required parameters: {missing_required}") - print(f"Optional parameters: {optional_params}") + logger.info(f"Missing required parameters: {missing_required}") + logger.info(f"Optional parameters: {optional_params}") # Get the strategy's slot ID slot_id = self.strategy_slot_map.get(strategy_id) if slot_id is None: - print(f"ERROR: No cached context found for strategy {strategy_id}") + logger.error(f"No cached context found for strategy {strategy_id}") return [] # Convert parameters to a serializable format @@ -241,181 +291,122 @@ async def get_parameter_suggestions( # Update the prompt to be more explicit about the format and requested parameters optional_params_text = f"Optional Parameters That Could Be Set:\n{', '.join(optional_params) if optional_params else 'None'}" if not requested_params else "" - + request_prompt = f"""<|im_start|>user -Strategy: {strategy.mapping.display_name} -Type: 
{strategy.mapping.strategy_type.value} - -Currently Provided Parameters: -{json.dumps(serializable_params, indent=2)} - -{"Parameters to Suggest:" if requested_params else "Missing Required Parameters:"} -{', '.join(requested_params) if requested_params else ', '.join(missing_required) if missing_required else 'None'} - -{optional_params_text} - -Please suggest optimal values for {"the requested" if requested_params else "the missing"} parameters using exactly this format for each parameter: - -PARAMETER: [parameter_name] -VALUE: [suggested_value] -REASONING: [detailed explanation of why this value is appropriate] - -End with a summary: -SUMMARY: [overall explanation of the suggested configuration] - -Do not include code blocks or other formats. Use only the PARAMETER/VALUE/REASONING structure. -<|im_end|>""" +> Strategy: {strategy.display_name} + Type: {strategy.strategy_type.value} + + Currently Provided Parameters: + {json.dumps(serializable_params, indent=2)} + + Parameters Requiring Suggestions: + {', '.join(missing_required) if missing_required else 'None'} + + {optional_params_text} + + Please suggest values for the missing required parameters and any optional parameters that would improve the strategy's performance. 
+ <|im_end|>""" try: - async with aiohttp.ClientSession() as session: - print(f"\nSending request to LibertAI API...") - print(f"Request prompt:\n{request_prompt}") - - request_payload = { - "slot_id": self.next_slot_id, - "parent_slot_id": slot_id, - "prompt": request_prompt, - "temperature": 0.9, - "top_p": 1, - "top_k": 40, - "n": 1, - "n_predict": 1500, - "stop": ["<|im_end|>"] - } - - async with session.post( - self.api_url, - headers={"Content-Type": "application/json"}, - json=request_payload - ) as response: - if response.status != 200: - print(f"ERROR: API returned status {response.status}") - response_text = await response.text() - print(f"Response body: {response_text}") - return [] - - result = await response.json() - print(f"\nReceived response from API: {json.dumps(result, indent=2)}") - return self._parse_ai_response( - {"choices": [{"message": {"content": result["content"]}}]}, - strategy_config=strategy_config, - provided_params=provided_params - ) - + response = await self.ai_client.get_suggestions(request_prompt, slot_id) + suggestions = self._parse_ai_response(response, strategy_config, provided_params) + return suggestions + except Exception as e: - print(f"ERROR: Exception during API call: {str(e)}") + logger.error(f"Error getting parameter suggestions: {str(e)}") return [] - def _parse_ai_response(self, ai_response: Dict[str, Any], strategy_config: Dict[str, Any], provided_params: Dict[str, Any]) -> List[ParameterSuggestion]: - print("\n=== Parsing AI Response ===") + def _parse_ai_response( + self, + ai_response: Dict[str, Any], + strategy_config: Dict[str, Any], + provided_params: Dict[str, Any] + ) -> List[ParameterSuggestion]: + """Parse the AI response and extract parameter suggestions.""" + suggestions = [] + summary_suggestion = None + try: - content = ai_response["choices"][0]["message"]["content"] - print(f"Response content preview: {content[:200]}...") + # Extract the response content + content = ai_response.get("choices", 
[{}])[0].get("message", {}).get("content", "") + if not content: + return [] - suggestions = [] - seen_params = set() - summary = None + # Split the content into sections + sections = content.strip().split("\n\n") - # Create a map of default values and provided values for comparison - default_values = { - name: str(param.default) if param.default is not None else None - for name, param in strategy_config.items() - } + # Process each section + current_param = None + current_value = None + current_reasoning = None - provided_values = { - name: str(value) if hasattr(value, "__str__") else str(value) - for name, value in provided_params.items() - } - - if "PARAMETER:" in content: - print("Found structured format with PARAMETER/VALUE/REASONING") - parameter_sections = content.split("PARAMETER:") - - for section in parameter_sections[1:]: - lines = section.strip().split("\n") - param_name = lines[0].strip() - - # Initialize collectors for multi-line values - value_lines = [] - reasoning_lines = [] - collecting_value = False - collecting_reasoning = False - - # Process remaining lines - for line in lines[1:]: - line = line.strip() + for section in sections: + lines = section.strip().split("\n") + for line in lines: + if line.startswith("PARAMETER:"): + # If we have a complete suggestion, add it + if current_param and current_value is not None: + param_config = strategy_config.get(current_param) + if param_config: + suggestion = ParameterSuggestion( + parameter_name=current_param, + suggested_value=current_value, + reasoning=current_reasoning or "", + differs_from_default=str(current_value) != str(param_config.default), + differs_from_provided=current_param in provided_params and str(current_value) != str(provided_params[current_param]) + ) + suggestions.append(suggestion) - if line.startswith("VALUE:"): - collecting_value = True - collecting_reasoning = False - value_lines.append(line.replace("VALUE:", "").strip()) - elif line.startswith("REASONING:"): - collecting_value 
= False - collecting_reasoning = True - reasoning_lines.append(line.replace("REASONING:", "").strip()) - elif line.startswith("SUMMARY:"): - collecting_value = False - collecting_reasoning = False - summary = line.replace("SUMMARY:", "").strip() - else: - # Continue collecting multi-line values - if collecting_value and line: - value_lines.append(line) - elif collecting_reasoning and line: - reasoning_lines.append(line) - - # Process collected values - if param_name and value_lines and param_name not in seen_params: - seen_params.add(param_name) + # Start a new suggestion + current_param = line.replace("PARAMETER:", "").strip() + current_value = None + current_reasoning = None - # Join multi-line values and try to parse as JSON if it looks like a JSON structure - value = "\n".join(value_lines) - if value.strip().startswith("{") and value.strip().endswith("}"): - try: - parsed_value = json.loads(value) - value = json.dumps(parsed_value) - except json.JSONDecodeError: - pass + elif line.startswith("VALUE:"): + current_value = line.replace("VALUE:", "").strip() - # Compare with default and provided values - differs_from_default = ( - param_name in default_values and - default_values[param_name] is not None and - value != default_values[param_name] - ) - differs_from_provided = ( - param_name in provided_values and - value != provided_values[param_name] - ) + elif line.startswith("REASONING:"): + current_reasoning = line.replace("REASONING:", "").strip() - suggestions.append(ParameterSuggestion( - parameter_name=param_name, - suggested_value=value, - reasoning="\n".join(reasoning_lines) if reasoning_lines else "No reasoning provided", - differs_from_default=differs_from_default, - differs_from_provided=differs_from_provided - )) - - if summary: - suggestions.append(ParameterSuggestion( - parameter_name="summary", - suggested_value=summary, - reasoning="Summary of the suggested configuration", - differs_from_default=False, - differs_from_provided=False - )) + elif 
line.startswith("SUMMARY:"): + # Create a summary suggestion + summary_value = line.replace("SUMMARY:", "").strip() + if summary_value: + summary_suggestion = ParameterSuggestion( + parameter_name="summary", + suggested_value=summary_value, + reasoning="Overall summary of suggestions", + differs_from_default=False, + differs_from_provided=False + ) + + elif current_reasoning is not None: + # Append additional reasoning lines + current_reasoning += " " + line.strip() - print(f"\nTotal suggestions parsed: {len(suggestions)}") - for s in suggestions: - print(f"- {s.parameter_name}: {s.suggested_value}") - if s.differs_from_default: - print(f" (differs from default: {s.differs_from_default})") - if s.differs_from_provided: - print(f" (differs from provided: {s.differs_from_provided})") + # Add the last suggestion if complete + if current_param and current_value is not None: + param_config = strategy_config.get(current_param) + if param_config: + suggestion = ParameterSuggestion( + parameter_name=current_param, + suggested_value=current_value, + reasoning=current_reasoning or "", + differs_from_default=str(current_value) != str(param_config.default), + differs_from_provided=current_param in provided_params and str(current_value) != str(provided_params[current_param]) + ) + suggestions.append(suggestion) - return suggestions + # Add the summary suggestion at the end if we have one + if summary_suggestion: + suggestions.append(summary_suggestion) except Exception as e: - print(f"ERROR: Failed to parse AI response: {str(e)}") - print(f"Raw response: {json.dumps(ai_response, indent=2)}") - return [] \ No newline at end of file + logger.error(f"Error parsing AI response: {str(e)}") + + return suggestions + + async def close(self): + """Close the AI client.""" + if isinstance(self.ai_client, LibertAIClient): + await self.ai_client.close() \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/integration/test_backtesting_service.py b/tests/integration/test_backtesting_service.py new file mode 100644 index 0000000..677f378 --- /dev/null +++ b/tests/integration/test_backtesting_service.py @@ -0,0 +1,313 @@ +import pytest +from decimal import Decimal +from typing import Dict, Any +from datetime import datetime, timedelta +from services.backtesting_service import BacktestingService, BacktestError +from routers.strategies_models import ( + StrategyMapping, + StrategyType, + StrategyParameter, + ParameterConstraints, + ParameterGroup, + DisplayType, StrategyNotFoundError +) +from routers.backtest_models import BacktestingConfig, BacktestResponse + +@pytest.fixture +def backtesting_service(): + """Create a real backtesting service instance for integration tests""" + return BacktestingService() + +@pytest.fixture +def recent_timestamps(): + """Get recent timestamps for testing""" + now = datetime.now() + start = now - timedelta(hours=1) + return { + "start": int(start.timestamp()), + "end": int(now.timestamp()) + } + +@pytest.fixture +def bollinger_config(recent_timestamps): + """Valid Bollinger Bands strategy configuration""" + return BacktestingConfig( + start_time=recent_timestamps["start"], + end_time=recent_timestamps["end"], + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "controller_name": "bollinger_v1", + "bb_length": 100, + "bb_std": 2.0, + "bb_long_threshold": 0.0, + "bb_short_threshold": 1.0, + "trading_pair": "BTC-USDT", + "leverage": 1, + "interval": "1m", + "stop_loss": 0.03, + "take_profit": 0.02, + "connector_name": "binance_perpetual", + "candles_connector": "binance_perpetual", + "candles_trading_pair": "BTC-USDT" + } + ) + +@pytest.fixture +def pmm_config(recent_timestamps): + """Valid Pure Market Making strategy configuration""" + return BacktestingConfig( + start_time=recent_timestamps["start"], + end_time=recent_timestamps["end"], + backtesting_resolution="1m", + trade_cost=0.001, + 
config={ + "controller_name": "pmm_simple", + "trading_pair": "BTC-USDT", + "leverage": 1, + "interval": "1m", + "bid_spread": 0.002, + "ask_spread": 0.002, + "order_amount": 0.01, + "order_refresh_time": 60, + "max_order_age": 1800, + "connector_name": "binance_perpetual", + "candles_connector": "binance_perpetual", + "candles_trading_pair": "BTC-USDT" + } + ) + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_directional_trading_workflow(backtesting_service, bollinger_config): + """Test complete backtesting workflow with Bollinger Bands strategy""" + response = await backtesting_service.run_backtesting(bollinger_config) + + # Verify response structure + assert isinstance(response, BacktestResponse) + assert response.processed_data is not None + assert response.results is not None + + # Verify processed data contains required features + features = response.processed_data.features + assert isinstance(features, dict) + assert len(features) > 0 + assert "BBL_100_2.0" in features # Lower Bollinger Band + assert "BBM_100_2.0" in features # Middle Bollinger Band + assert "BBU_100_2.0" in features # Upper Bollinger Band + assert "BBP_100_2.0" in features # Bollinger Band Position + + # Verify executors + assert isinstance(response.executors, list) + if response.executors: + executor = response.executors[0] + assert executor.level_id is not None + assert executor.timestamp is not None + assert executor.connector_name is not None + assert executor.trading_pair == "BTC-USDT" + assert isinstance(executor.entry_price, Decimal) + assert isinstance(executor.amount, Decimal) + assert executor.side in ["BUY", "SELL"] + assert executor.leverage == 1 + assert executor.position_mode == "ONEWAY" + + # Verify results + results = response.results + assert isinstance(results.total_trades, int) + assert isinstance(results.win_rate, float) + assert isinstance(results.total_pnl, Decimal) + assert isinstance(results.sharpe_ratio, float) + assert 
isinstance(results.max_drawdown, float) + assert isinstance(results.profit_loss_ratio, float) + assert results.start_timestamp == bollinger_config.start_time + assert results.end_timestamp == bollinger_config.end_time + assert results.win_rate >= 0 and results.win_rate <= 1 + assert results.max_drawdown >= 0 and results.max_drawdown <= 1 + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_market_making_workflow(backtesting_service, pmm_config): + """Test complete backtesting workflow with Pure Market Making strategy""" + response = await backtesting_service.run_backtesting(pmm_config) + + # Verify response structure + assert isinstance(response, BacktestResponse) + assert response.processed_data is not None + assert response.results is not None + + # Verify processed data contains required features + features = response.processed_data.features + assert isinstance(features, dict) + assert len(features) > 0 + assert "price" in features + assert "volume" in features + + # Verify executors + assert isinstance(response.executors, list) + if response.executors: + executor = response.executors[0] + assert executor.level_id is not None + assert executor.timestamp is not None + assert executor.connector_name is not None + assert executor.trading_pair == "BTC-USDT" + assert isinstance(executor.entry_price, Decimal) + assert isinstance(executor.amount, Decimal) + assert executor.side in ["BUY", "SELL"] + assert executor.leverage == 1 + assert executor.position_mode == "ONEWAY" + + # Verify results + results = response.results + assert isinstance(results.total_trades, int) + assert isinstance(results.win_rate, float) + assert isinstance(results.total_pnl, Decimal) + assert isinstance(results.sharpe_ratio, float) + assert isinstance(results.max_drawdown, float) + assert isinstance(results.profit_loss_ratio, float) + assert results.start_timestamp == pmm_config.start_time + assert results.end_timestamp == pmm_config.end_time + assert results.win_rate >= 0 and 
results.win_rate <= 1 + assert results.max_drawdown >= 0 and results.max_drawdown <= 1 + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_different_time_ranges(backtesting_service, bollinger_config): + """Test backtesting with different time ranges and resolutions""" + # Test with 1-hour data + bollinger_config.backtesting_resolution = "1h" + bollinger_config.config["interval"] = "1h" + response = await backtesting_service.run_backtesting(bollinger_config) + assert response.results.end_timestamp - response.results.start_timestamp == bollinger_config.end_time - bollinger_config.start_time + + # Test with 15-minute data + bollinger_config.backtesting_resolution = "15m" + bollinger_config.config["interval"] = "15m" + response = await backtesting_service.run_backtesting(bollinger_config) + assert response.results.end_timestamp - response.results.start_timestamp == bollinger_config.end_time - bollinger_config.start_time + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_parameter_validation(backtesting_service, bollinger_config): + """Test parameter validation for strategy configurations""" + # Test invalid BB length + invalid_config = bollinger_config.copy() + invalid_config.config["bb_length"] = 0 + with pytest.raises(BacktestError, match="Invalid parameter"): + await backtesting_service.run_backtesting(invalid_config) + + # Test invalid BB std + invalid_config = bollinger_config.copy() + invalid_config.config["bb_std"] = -1.0 + with pytest.raises(BacktestError, match="Invalid parameter"): + await backtesting_service.run_backtesting(invalid_config) + + # Test invalid leverage + invalid_config = bollinger_config.copy() + invalid_config.config["leverage"] = 0 + with pytest.raises(BacktestError, match="Invalid parameter"): + await backtesting_service.run_backtesting(invalid_config) + + # Test invalid trading pair format + invalid_config = bollinger_config.copy() + invalid_config.config["trading_pair"] = "BTCUSDT" # Missing hyphen + with 
pytest.raises(BacktestError, match="Invalid trading pair format"): + await backtesting_service.run_backtesting(invalid_config) + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_invalid_controller_name(backtesting_service, recent_timestamps): + """Test backtesting with non-existent strategy ID""" + config = BacktestingConfig( + start_time=recent_timestamps["start"], + end_time=recent_timestamps["end"], + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "controller_name": "nonexistent_strategy" + } + ) + + with pytest.raises(StrategyNotFoundError, match="Strategy not found"): + await backtesting_service.run_backtesting(config) + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_invalid_time_range(backtesting_service, recent_timestamps): + """Test backtesting with invalid time range""" + config = BacktestingConfig( + start_time=recent_timestamps["end"], + end_time=recent_timestamps["start"], + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "controller_name": "bollinger_v1", + "bb_length": 100, + "bb_std": 2.0, + "bb_long_threshold": 0.0, + "bb_short_threshold": 1.0, + "trading_pair": "BTC-USDT", + "leverage": 1, + "interval": "1m" + } + ) + + with pytest.raises(BacktestError, match="Invalid time range"): + await backtesting_service.run_backtesting(config) + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_missing_required_parameters(backtesting_service, recent_timestamps): + """Test backtesting with missing required strategy parameters""" + config = BacktestingConfig( + start_time=recent_timestamps["start"], + end_time=recent_timestamps["end"], + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "controller_name": "bollinger_v1", + # Missing required parameters + } + ) + + with pytest.raises(BacktestError): + await backtesting_service.run_backtesting(config) + +@pytest.mark.integration +def test_get_available_engines_integration(backtesting_service): + """Test getting available 
backtesting engines""" + engines = backtesting_service.get_available_engines() + + assert isinstance(engines, dict) + assert len(engines) > 0 + assert "directional_trading" in engines + assert "market_making" in engines + +@pytest.mark.integration +def test_get_engine_config_schema_integration(backtesting_service): + """Test getting engine configuration schema""" + # Test directional trading schema + dt_schema = backtesting_service.get_engine_config_schema("directional_trading") + assert isinstance(dt_schema, dict) + assert dt_schema["type"] == "object" + assert "properties" in dt_schema + assert "controller_name" in dt_schema["properties"] + assert "trading_pair" in dt_schema["properties"] + assert "leverage" in dt_schema["properties"] + assert "interval" in dt_schema["properties"] + assert "stop_loss" in dt_schema["properties"] + assert "take_profit" in dt_schema["properties"] + assert "required" in dt_schema + print(dt_schema["required"]) + assert "trading_pair" in dt_schema["required"] + + # Test market making schema + mm_schema = backtesting_service.get_engine_config_schema("market_making") + assert isinstance(mm_schema, dict) + assert mm_schema["type"] == "object" + assert "properties" in mm_schema + assert "trading_pair" in mm_schema["properties"] + assert "bid_spread" in mm_schema["properties"] + assert "ask_spread" in mm_schema["properties"] + assert "order_amount" in mm_schema["properties"] + assert "required" in mm_schema + assert "trading_pair" in mm_schema["required"] \ No newline at end of file diff --git a/tests/integration/test_libert_ai_service.py b/tests/integration/test_libert_ai_service.py new file mode 100644 index 0000000..a40d74a --- /dev/null +++ b/tests/integration/test_libert_ai_service.py @@ -0,0 +1,75 @@ +import pytest +from services.libert_ai_service import LibertAIService +from routers.strategies_models import discover_strategies + +@pytest.fixture +def libert_ai_service(): + service = LibertAIService() + return service + 
+@pytest.fixture +def strategy_configs(): + """Load all available strategies""" + return discover_strategies() + +@pytest.fixture +def bollinger_strategy(strategy_configs): + """Get the Bollinger strategy configuration""" + return strategy_configs["bollinger_v1"] + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_initialize_system_context(libert_ai_service): + """Integration test: Test system context initialization""" + await libert_ai_service._initialize_system_context() + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_initialize_strategy_context(libert_ai_service, bollinger_strategy): + """Integration test: Test strategy context initialization""" + with open(f"bots/controllers/{bollinger_strategy.module_path.split('.')[-1]}.py", "r") as f: + strategy_code = f.read() + + await libert_ai_service._initialize_strategy_context( + strategy=bollinger_strategy, + strategy_code=strategy_code, + slot_id=0 + ) + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_get_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Integration test: Test parameter suggestion generation""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + suggestions = await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0} + ) + + # Verify suggestions are part of the bollinger_strategy.parameters + for suggestion in suggestions: + assert suggestion.parameter_name in bollinger_strategy.parameters or suggestion.parameter_name == "summary" + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_get_specific_parameter_suggestions(libert_ai_service, bollinger_strategy): + """Integration test: Test getting suggestions for specific parameters""" + + libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 + + # Request suggestions for specific parameters + requested_params = ["bb_length", "bb_long_threshold"] + suggestions = 
await libert_ai_service.get_parameter_suggestions( + strategy_id="bollinger_v1", + strategy_config=bollinger_strategy.parameters, + provided_params={"bb_std": 2.0}, + requested_params=requested_params + ) + + # Verify we only got suggestions for the requested parameters (plus summary) + assert len(suggestions) == 3 # 2 requested parameters + summary + suggestion_params = {s.parameter_name for s in suggestions if s.parameter_name != "summary"} + assert suggestion_params == set(requested_params) \ No newline at end of file diff --git a/tests/test_libert_ai_service.py b/tests/test_libert_ai_service.py deleted file mode 100644 index b8c38c6..0000000 --- a/tests/test_libert_ai_service.py +++ /dev/null @@ -1,164 +0,0 @@ -import pytest -from services.libert_ai_service import LibertAIService -from routers.strategies_models import ( - ParameterSuggestion, - discover_strategies, -) -from typing import Any -from dataclasses import dataclass - -@pytest.fixture -def libert_ai_service(): - service = LibertAIService() - return service - -@pytest.fixture -def strategy_configs(): - """Load all available strategies""" - return discover_strategies() - -@pytest.fixture -def bollinger_strategy(strategy_configs): - """Get the Bollinger strategy configuration""" - return strategy_configs["bollinger_v1"] - -@pytest.mark.asyncio -async def test_initialize_system_context(libert_ai_service): - """Test system context initialization""" - await libert_ai_service._initialize_system_context() - -@pytest.mark.asyncio -async def test_initialize_strategy_context(libert_ai_service, bollinger_strategy): - """Test strategy context initialization""" - with open(f"bots/{bollinger_strategy.mapping.module_path.split('bots.')[-1].replace('.', '/')}.py", "r") as f: - strategy_code = f.read() - - await libert_ai_service._initialize_strategy_context( - strategy_mapping=bollinger_strategy.mapping, - strategy_config=bollinger_strategy.parameters, - strategy_code=strategy_code, - slot_id=0 - ) - 
-@pytest.mark.asyncio -async def test_get_parameter_suggestions(libert_ai_service, bollinger_strategy): - """Test parameter suggestion generation""" - - libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 - - suggestions = await libert_ai_service.get_parameter_suggestions( - strategy_id="bollinger_v1", - strategy_config=bollinger_strategy.parameters, - provided_params={"bb_std": 2.0} - ) - - # Verify suggestions are part of the bollinger_strategy.parameters - for suggestion in suggestions: - assert suggestion.parameter_name in bollinger_strategy.parameters or suggestion.parameter_name == "summary" - -@pytest.mark.asyncio -async def test_parse_ai_response(libert_ai_service): - """Test AI response parsing""" - ai_response = { - "choices": [{ - "message": { - "content": """ -PARAMETER: bb_length -VALUE: 100 -REASONING: Standard length for Bollinger Bands calculation provides reliable signals while filtering out noise. - -PARAMETER: bb_long_threshold -VALUE: 0.2 -REASONING: Enter long positions when price is 20% below the middle band, indicating oversold conditions. - -SUMMARY: These parameters are optimized for mean reversion trading using Bollinger Bands. 
-""" - } - }] - } - - # Mock strategy config with proper parameter objects - @dataclass - class MockParameter: - name: str - default: Any - required: bool = True - type: str = "float" - - strategy_config = { - "bb_length": MockParameter( - name="BB Length", - default=20, - type="int" - ), - "bb_long_threshold": MockParameter( - name="BB Long Threshold", - default=0.1, - type="float" - ) - } - - provided_params = { - "bb_length": 20 - } - - suggestions = libert_ai_service._parse_ai_response( - ai_response, - strategy_config=strategy_config, - provided_params=provided_params - ) - - assert len(suggestions) == 3 # 2 parameters + summary - assert all(isinstance(s, ParameterSuggestion) for s in suggestions) - assert suggestions[0].parameter_name == "bb_length" - assert suggestions[0].suggested_value == "100" - assert suggestions[0].differs_from_default is True # 100 vs default 20 - assert suggestions[0].differs_from_provided is True # 100 vs provided 20 - assert suggestions[1].parameter_name == "bb_long_threshold" - assert suggestions[1].suggested_value == "0.2" - assert suggestions[1].differs_from_default is True # 0.2 vs default 0.1 - assert suggestions[1].differs_from_provided is False # Not provided - assert suggestions[2].parameter_name == "summary" - assert suggestions[2].suggested_value == "These parameters are optimized for mean reversion trading using Bollinger Bands." 
- -@pytest.mark.asyncio -async def test_parse_ai_response_handles_invalid_format(libert_ai_service): - """Test handling of invalid AI response format""" - invalid_response = { - "choices": [{ - "message": { - "content": "Invalid format response" - } - }] - } - - # Mock empty config with proper parameter objects - strategy_config = {} - provided_params = {} - - suggestions = libert_ai_service._parse_ai_response( - invalid_response, - strategy_config=strategy_config, - provided_params=provided_params - ) - assert suggestions == [] - -@pytest.mark.asyncio -async def test_get_specific_parameter_suggestions(libert_ai_service, bollinger_strategy): - """Test getting suggestions for specific parameters""" - - libert_ai_service.strategy_slot_map["bollinger_v1"] = 0 - - # Request suggestions for specific parameters - requested_params = ["bb_length", "bb_long_threshold"] - suggestions = await libert_ai_service.get_parameter_suggestions( - strategy_id="bollinger_v1", - strategy_config=bollinger_strategy.parameters, - provided_params={"bb_std": 2.0}, - requested_params=requested_params - ) - - # Verify we only got suggestions for the requested parameters (plus summary) - assert len(suggestions) == 3 # 2 requested parameters + summary - suggestion_params = {s.parameter_name for s in suggestions if s.parameter_name != "summary"} - assert suggestion_params == set(requested_params) \ No newline at end of file diff --git a/tests/test_strategies.py b/tests/test_strategies.py deleted file mode 100644 index 91f3b16..0000000 --- a/tests/test_strategies.py +++ /dev/null @@ -1,234 +0,0 @@ -import pytest -from unittest.mock import Mock, patch, MagicMock -from decimal import Decimal -from typing import Dict, Any -from pydantic import BaseModel, Field -from hummingbot.strategy_v2.controllers import ControllerConfigBase - -from routers.strategies_models import ( - StrategyType, - StrategyMapping, - StrategyParameter, - StrategyConfig, - discover_strategies, - generate_strategy_mapping, - 
convert_to_strategy_parameter, - infer_strategy_type -) - -# Mock strategy config class for testing -class MockStrategyConfig(ControllerConfigBase): - """Test strategy for unit testing""" - controller_name = "test_strategy_v1" - - stop_loss: Decimal = Field( - default=Decimal("0.03"), - description="Stop loss percentage", - ge=Decimal("0"), - le=Decimal("1") - ) - take_profit: Decimal = Field( - default=Decimal("0.02"), - description="Take profit percentage", - ge=Decimal("0"), - le=Decimal("1") - ) - time_limit: int = Field( - default=2700, - description="Time limit in seconds", - gt=0 - ) - leverage: int = Field( - default=20, - description="Leverage multiplier", - gt=0 - ) - trading_pair: str = Field( - default="BTC-USDT", - description="Trading pair to use" - ) - -# Test data -MOCK_MODULE_PATH = "bots.controllers.directional_trading.test_strategy_v1" - -@pytest.fixture -def mock_strategy_config(): - return MockStrategyConfig - -@pytest.fixture(autouse=True) -def mock_importlib(): - with patch("importlib.import_module") as mock: - mock.return_value = MagicMock( - __name__="test_module", - MockStrategyConfig=MockStrategyConfig - ) - yield mock - -@pytest.fixture(autouse=True) -def mock_os_walk(): - with patch("os.walk") as mock: - mock.return_value = [ - ("bots/controllers/directional_trading", [], ["test_strategy_v1.py"]), - ] - yield mock - -@pytest.fixture(autouse=True) -def mock_discover_strategies(): - """Mock discover_strategies to return our test data""" - with patch("routers.strategies_models.discover_strategies", autospec=True) as mock: - mock.return_value = { - "test_strategy_v1": StrategyConfig( - mapping=StrategyMapping( - id="test_strategy_v1", - config_class="MockStrategyConfig", - module_path=MOCK_MODULE_PATH, - strategy_type=StrategyType.DIRECTIONAL_TRADING, - display_name="Test Strategy V1", - description="Test strategy for unit testing" - ), - parameters={ - "stop_loss": StrategyParameter( - name="stop_loss", - pretty_name="Stop Loss", - 
description="Stop loss percentage", - group="Risk Management", - type="Decimal", - prompt="Enter stop loss value", - default=Decimal("0.03"), - required=True, - min_value=Decimal("0"), - max_value=Decimal("1") - ), - "take_profit": StrategyParameter( - name="take_profit", - pretty_name="Take Profit", - description="Take profit percentage", - group="Risk Management", - type="Decimal", - prompt="Enter take profit value", - default=Decimal("0.02"), - required=True, - min_value=Decimal("0"), - max_value=Decimal("1") - ), - "time_limit": StrategyParameter( - name="time_limit", - pretty_name="Time Limit", - description="Time limit in seconds", - group="General Settings", - type="int", - prompt="Enter time limit in seconds", - default=2700, - required=True, - min_value=0 - ), - "leverage": StrategyParameter( - name="leverage", - pretty_name="Leverage", - description="Leverage multiplier", - group="Risk Management", - type="int", - prompt="Enter leverage multiplier", - default=20, - required=True, - min_value=1, - is_advanced=True - ), - "trading_pair": StrategyParameter( - name="trading_pair", - pretty_name="Trading Pair", - description="Trading pair to use", - group="General Settings", - type="str", - prompt="Enter trading pair", - default="BTC-USDT", - required=True, - is_trading_pair=True - ) - } - ) - } - yield mock - -def test_infer_strategy_type(): - """Test strategy type inference from module path""" - assert infer_strategy_type("bots.controllers.directional_trading.test", None) == StrategyType.DIRECTIONAL_TRADING - assert infer_strategy_type("bots.controllers.market_making.test", None) == StrategyType.MARKET_MAKING - assert infer_strategy_type("bots.controllers.generic.test", None) == StrategyType.GENERIC - -def test_generate_strategy_mapping(): - """Test strategy mapping generation""" - mapping = generate_strategy_mapping(MOCK_MODULE_PATH, MockStrategyConfig) - - assert mapping.id == "test_strategy_v1" - assert mapping.config_class == "MockStrategyConfig" - 
assert mapping.module_path == MOCK_MODULE_PATH - assert mapping.strategy_type == StrategyType.DIRECTIONAL_TRADING - assert mapping.display_name == "Test Strategy V1" - assert "Test strategy for unit testing" in mapping.description - -def test_convert_to_strategy_parameter(): - """Test parameter conversion from config field""" - # Get a field from the mock config - field = MockStrategyConfig.__fields__["stop_loss"] - param = convert_to_strategy_parameter("stop_loss", field) - - assert param.pretty_name == "Stop Loss" - assert param.group == "Risk Management" - assert param.type == "ConstrainedDecimalValue" # We want the base type, not the constrained type - assert param.default == Decimal("0.03") - assert param.required is True - assert param.min_value == Decimal("0") - assert param.max_value == Decimal("1") - assert param.display_type == "slider" - -@pytest.mark.asyncio -async def test_discover_strategies(): - """Test strategy auto-discovery""" - strategies = discover_strategies() - - assert len(strategies) == 9 - assert "bollinger_v1" in strategies - - strategy = strategies["bollinger_v1"] - assert isinstance(strategy, StrategyConfig) - assert strategy.mapping.id == "bollinger_v1" - - # Check some parameters - assert "stop_loss" in strategy.parameters - assert "take_profit" in strategy.parameters - assert strategy.parameters["leverage"].is_advanced is True - assert strategy.parameters["trading_pair"].is_trading_pair is True - - -def test_parameter_validation(): - """Test parameter validation in strategy config""" - # Test required parameters - with pytest.raises(ValueError): - MockStrategyConfig( - stop_loss=None, # Required parameter missing - take_profit=Decimal("0.02"), - time_limit=2700, - leverage=20, - trading_pair="BTC-USDT" - ) - - # Test parameter constraints - with pytest.raises(ValueError): - MockStrategyConfig( - stop_loss=Decimal("-0.03"), # Negative value not allowed - take_profit=Decimal("0.02"), - time_limit=2700, - leverage=20, - 
trading_pair="BTC-USDT" - ) - -def test_strategy_type_enum(): - """Test StrategyType enum values""" - assert StrategyType.DIRECTIONAL_TRADING == "directional_trading" - assert StrategyType.MARKET_MAKING == "market_making" - assert StrategyType.GENERIC == "generic" - - # Test that invalid types are not allowed - with pytest.raises(ValueError): - StrategyType("invalid_type") \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/fixtures/backtest_results.json b/tests/unit/fixtures/backtest_results.json new file mode 100644 index 0000000..305fbf1 --- /dev/null +++ b/tests/unit/fixtures/backtest_results.json @@ -0,0 +1,39 @@ +{ + "executors": [ + { + "id": "executor_1", + "level_id": "level_1", + "timestamp": 1735457150, + "connector_name": "binance", + "trading_pair": "BTC-USDT", + "entry_price": 50000.0, + "amount": 0.1, + "side": "BUY", + "leverage": 1, + "position_mode": "ONEWAY", + "trades": 10, + "win_rate": 0.6, + "profit_loss": 150.25 + } + ], + "processed_data": { + "features": { + "price": [50100.0, 50200.0, 50300.0], + "volume": [10.5, 12.3, 8.7], + "BBL_100_2.0": [49800.0, 49900.0, 50000.0], + "BBM_100_2.0": [50000.0, 50100.0, 50200.0], + "BBU_100_2.0": [50200.0, 50300.0, 50400.0], + "BBP_100_2.0": [0.75, 0.80, 0.85] + } + }, + "results": { + "total_pnl": 150.25, + "total_trades": 10, + "win_rate": 0.6, + "profit_loss_ratio": 1.5, + "sharpe_ratio": 1.5, + "max_drawdown": 0.05, + "start_timestamp": 1735457150, + "end_timestamp": 1735460750 + } +} \ No newline at end of file diff --git a/tests/unit/fixtures/market_data.json b/tests/unit/fixtures/market_data.json new file mode 100644 index 0000000..dd8492f --- /dev/null +++ b/tests/unit/fixtures/market_data.json @@ -0,0 +1,12 @@ +{ + "timestamp": [1735457150, 1735457210, 1735457270], + "open": [50000.0, 50100.0, 50200.0], + "high": [50200.0, 50300.0, 50400.0], + "low": [49800.0, 49900.0, 50000.0], + "close": 
[50100.0, 50200.0, 50300.0], + "volume": [10.5, 12.3, 8.7], + "quote_asset_volume": [525000.0, 615000.0, 435000.0], + "n_trades": [100, 120, 80], + "taker_buy_base_volume": [6.3, 7.4, 5.2], + "taker_buy_quote_volume": [315000.0, 369000.0, 261000.0] +} \ No newline at end of file diff --git a/tests/unit/fixtures/strategy_config.json b/tests/unit/fixtures/strategy_config.json new file mode 100644 index 0000000..b998d41 --- /dev/null +++ b/tests/unit/fixtures/strategy_config.json @@ -0,0 +1,12 @@ +{ + "strategy_id": "bollinger_v1", + "stop_loss": 0.03, + "take_profit": 0.02, + "trading_pair": "BTC-USDT", + "leverage": 1, + "bb_length": 100, + "bb_std": 2.0, + "bb_long_threshold": 0.0, + "bb_short_threshold": 1.0, + "interval": "1m" +} \ No newline at end of file diff --git a/tests/unit/test_backtest.py b/tests/unit/test_backtest.py new file mode 100644 index 0000000..17f8f6b --- /dev/null +++ b/tests/unit/test_backtest.py @@ -0,0 +1,265 @@ +import json +from pathlib import Path + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from fastapi import HTTPException +import pandas as pd + +from routers.backtest import ( + run_backtesting, + get_available_engines, + get_engine_config_schema, +) +from routers.backtest_models import BacktestResponse, BacktestResults, BacktestingConfig, ExecutorInfo, ProcessedData +from services.backtesting_service import ( + BacktestConfigError, + BacktestEngineError, +) +from routers.strategies_models import ( + StrategyType, + StrategyMapping, + StrategyConfig, + StrategyParameter, + ParameterGroup, +) + +# Load test fixtures +FIXTURES_DIR = Path(__file__).parent / "fixtures" + +def load_fixture(filename: str) -> dict: + with open(FIXTURES_DIR / filename) as f: + return json.load(f) + +MOCK_STRATEGY_CONFIG = load_fixture("strategy_config.json") +MOCK_BACKTEST_RESULTS = load_fixture("backtest_results.json") +MOCK_MARKET_DATA = load_fixture("market_data.json") + +@pytest.fixture +def mock_market_data_provider(): + """Mock the 
market data provider to return test data""" + mock = Mock() + df = pd.DataFrame(MOCK_MARKET_DATA) + mock.get_candles_df.return_value = df + return mock + +@pytest.fixture +def mock_strategy(): + """Mock strategy configuration from registry""" + return StrategyConfig( + id="bollinger_v1", + name="Bollinger Bands V1", + description="Trading strategy based on Bollinger Bands", + mapping=StrategyMapping( + id="bollinger_v1", + config_class="BollingerV1ControllerConfig", + module_path="bots.controllers.directional_trading.bollinger_v1", + strategy_type=StrategyType.DIRECTIONAL_TRADING, + display_name="Bollinger Bands Strategy", + description="Trading strategy based on Bollinger Bands" + ), + parameters={ + "bb_length": StrategyParameter( + name="bb_length", + display_name="BB Length", + description="Length of Bollinger Bands", + type="integer", + default=100, + min_value=20, + max_value=200, + group=ParameterGroup.INDICATORS, + is_advanced=False, + required=True + ), + "bb_std": StrategyParameter( + name="bb_std", + display_name="BB Standard Deviation", + description="Standard deviation multiplier", + type="float", + default=2.0, + min_value=1.0, + max_value=3.0, + group=ParameterGroup.INDICATORS, + is_advanced=False, + required=True + ), + "bb_long_threshold": StrategyParameter( + name="bb_long_threshold", + display_name="BB Long Threshold", + description="Long entry threshold", + type="float", + default=0.0, + min_value=0.0, + max_value=1.0, + group=ParameterGroup.INDICATORS, + is_advanced=False, + required=True + ), + "bb_short_threshold": StrategyParameter( + name="bb_short_threshold", + display_name="BB Short Threshold", + description="Short entry threshold", + type="float", + default=1.0, + min_value=0.0, + max_value=1.0, + group=ParameterGroup.INDICATORS, + is_advanced=False, + required=True + ) + } + ) + +@pytest.fixture +def mock_registry(mock_strategy): + """Mock strategy registry""" + with patch("services.backtesting_service.StrategyRegistry") as 
mock_registry: + mock_registry.get_strategy.return_value = mock_strategy + yield mock_registry + +@pytest.fixture +def mock_backtesting_service(mock_registry): + """Mock backtesting service""" + with patch("routers.backtest.backtesting_service") as mock_service: + mock_response = BacktestResponse( + executors=[ + ExecutorInfo( + id="executor_1", + level_id="level_1", + timestamp=1735457150, + connector_name="binance", + trading_pair="BTC-USDT", + entry_price=50000.0, + amount=0.1, + side="BUY", + leverage=1, + position_mode="ONEWAY", + trades=10, + win_rate=0.6, + profit_loss=150.25 + ) + ], + processed_data=ProcessedData( + features={ + "price": [50100.0, 50200.0, 50300.0], + "volume": [10.5, 12.3, 8.7], + "BBL_100_2.0": [49800.0, 49900.0, 50000.0], + "BBM_100_2.0": [50000.0, 50100.0, 50200.0], + "BBU_100_2.0": [50200.0, 50300.0, 50400.0], + "BBP_100_2.0": [0.75, 0.80, 0.85] + } + ), + results=BacktestResults( + total_trades=10, + win_rate=0.6, + total_pnl=150.25, + sharpe_ratio=1.5, + profit_loss_ratio=1.5, + max_drawdown=0.05, + start_timestamp=1735457150, + end_timestamp=1735457450 + ) + ) + mock_service.run_backtesting = AsyncMock() + mock_service.run_backtesting.return_value = mock_response + yield mock_service + +@pytest.fixture +def valid_backtest_config(): + """Valid backtest configuration""" + return BacktestingConfig( + start_time=1735458769, + end_time=1735462369, + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "strategy_id": "bollinger_v1", + "trading_pair": "BTC-USDT", + "leverage": 1, + "bb_length": 100, + "bb_std": 2.0, + "bb_long_threshold": 0.0, + "bb_short_threshold": 1.0, + "interval": "1m" + } + ) + +# Unit Tests + +@pytest.mark.asyncio +async def test_successful_backtest(valid_backtest_config, mock_backtesting_service): + """Test successful backtesting execution and results processing""" + response = await run_backtesting(valid_backtest_config) + + assert isinstance(response, BacktestResponse) + assert len(response.executors) 
== 1 + assert response.executors[0].trades == 10 + assert response.executors[0].win_rate == 0.6 + assert response.executors[0].profit_loss == 150.25 + + assert isinstance(response.processed_data, ProcessedData) + assert "price" in response.processed_data.features + assert "volume" in response.processed_data.features + assert "BBL_100_2.0" in response.processed_data.features + + assert isinstance(response.results, BacktestResults) + assert response.results.total_trades == 10 + assert response.results.win_rate == 0.6 + assert response.results.total_pnl == 150.25 + assert response.results.sharpe_ratio == 1.5 + assert response.results.profit_loss_ratio == 1.5 + +@pytest.mark.asyncio +async def test_config_error(valid_backtest_config, mock_backtesting_service): + """Test handling of configuration errors""" + mock_backtesting_service.run_backtesting.side_effect = BacktestConfigError("Invalid config") + + with pytest.raises(HTTPException) as exc_info: + await run_backtesting(valid_backtest_config) + assert exc_info.value.status_code == 400 + assert str(exc_info.value.detail) == "Invalid config" + +@pytest.mark.asyncio +async def test_engine_error(valid_backtest_config, mock_backtesting_service): + """Test handling of engine errors""" + mock_backtesting_service.run_backtesting.side_effect = BacktestEngineError("Engine error") + + with pytest.raises(HTTPException) as exc_info: + await run_backtesting(valid_backtest_config) + assert exc_info.value.status_code == 500 + assert str(exc_info.value.detail) == "Engine error" + +def test_get_available_engines(mock_backtesting_service): + """Test getting available engines""" + mock_backtesting_service.get_available_engines.return_value = { + "directional_trading": "DirectionalTradingEngine", + "market_making": "MarketMakingEngine" + } + engines = get_available_engines() + assert isinstance(engines, dict) + assert "directional_trading" in engines.keys() + assert "market_making" in engines.keys() + assert 
engines["directional_trading"] == "DirectionalTradingEngine" + assert engines["market_making"] == "MarketMakingEngine" + +def test_get_engine_config_schema(mock_backtesting_service): + """Test getting engine configuration schema""" + mock_backtesting_service.get_engine_config_schema.return_value = { + "type": "object", + "properties": { + "stop_loss": {"type": "number"}, + "take_profit": {"type": "number"} + } + } + schema = get_engine_config_schema("directional_trading") + assert "type" in schema + assert "properties" in schema + +def test_get_engine_config_schema_not_found(mock_backtesting_service): + """Test getting configuration schema for non-existent engine""" + mock_backtesting_service.get_engine_config_schema.return_value = None + + with pytest.raises(HTTPException) as exc_info: + get_engine_config_schema("invalid_engine") + assert exc_info.value.status_code == 404 + assert "Engine type 'invalid_engine' not found" in str(exc_info.value.detail) \ No newline at end of file diff --git a/tests/unit/test_backtesting_service.py b/tests/unit/test_backtesting_service.py new file mode 100644 index 0000000..d50f835 --- /dev/null +++ b/tests/unit/test_backtesting_service.py @@ -0,0 +1,188 @@ +import pytest +from unittest.mock import Mock, patch, AsyncMock +from decimal import Decimal + +from services.backtesting_service import ( + BacktestingService, + BacktestError +) +from routers.backtest_models import BacktestingConfig, BacktestResponse +from routers.strategies_models import ( + StrategyRegistry, + StrategyNotFoundError, + StrategyType, + StrategyMapping, + StrategyParameter, + ParameterGroup, + DisplayType, + StrategyConfig +) + +@pytest.fixture +def mock_strategy_registry(monkeypatch): + """Mock strategy registry for testing""" + strategy = StrategyMapping( + id="bollinger_v1", + config_class="BollingerV1ControllerConfig", + module_path="bots.controllers.directional_trading.bollinger_v1", + strategy_type=StrategyType.DIRECTIONAL_TRADING, + 
display_name="Bollinger Bands Strategy", + description="A strategy that uses Bollinger Bands for trading decisions", + parameters={ + "bb_length": StrategyParameter( + name="bb_length", + type="int", + required=True, + default=100, + display_name="BB Length", + description="Length of the Bollinger Bands period", + group=ParameterGroup.INDICATORS, + is_advanced=False, + display_type=DisplayType.INPUT + ), + "bb_std": StrategyParameter( + name="bb_std", + type="float", + required=True, + default=2.0, + display_name="BB Standard Deviation", + description="Number of standard deviations for the bands", + group=ParameterGroup.INDICATORS, + is_advanced=False, + display_type=DisplayType.INPUT + ) + } + ) + + def mock_get_strategy(strategy_id: str) -> StrategyMapping: + if strategy_id != "bollinger_v1": + raise StrategyNotFoundError(f"Strategy {strategy_id} not found") + return strategy + + monkeypatch.setattr(StrategyRegistry, "get_strategy", mock_get_strategy) + return strategy + +@pytest.fixture +def backtesting_service(): + """Create a backtesting service instance for testing""" + return BacktestingService() + +@pytest.mark.asyncio +async def test_validate_time_range(backtesting_service): + """Test time range validation""" + # Valid time range + backtesting_service.validate_time_range(1000, 2000) + + # Invalid time range + with pytest.raises(BacktestError, match="Invalid time range"): + backtesting_service.validate_time_range(2000, 1000) + +@pytest.mark.asyncio +async def test_transform_strategy_config_success(backtesting_service, mock_strategy_registry): + """Test successful strategy config transformation""" + config = { + "strategy_id": "bollinger_v1", + "bb_length": 100, + "bb_std": 2.0 + } + + result = backtesting_service.transform_strategy_config(config) + + assert result["controller_type"] == StrategyType.DIRECTIONAL_TRADING.value + assert result["controller_name"] == "bollinger_v1" + assert result["bb_length"] == 100 + assert result["bb_std"] == 2.0 + 
+@pytest.mark.asyncio +async def test_transform_strategy_config_missing_id(backtesting_service): + """Test strategy config transformation with missing ID""" + config = { + "bb_length": 100, + "bb_std": 2.0 + } + + with pytest.raises(BacktestError, match="Missing strategy_id"): + backtesting_service.transform_strategy_config(config) + +@pytest.mark.asyncio +async def test_transform_strategy_config_not_found(backtesting_service, mock_strategy_registry): + """Test strategy config transformation with non-existent strategy""" + config = { + "strategy_id": "nonexistent_strategy", + "bb_length": 100, + "bb_std": 2.0 + } + + with pytest.raises(BacktestError, match="Strategy not found"): + backtesting_service.transform_strategy_config(config) + +@pytest.mark.asyncio +async def test_run_backtesting_success(backtesting_service, mock_strategy_registry): + """Test successful backtesting run""" + config = BacktestingConfig( + start_time=1000, + end_time=2000, + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "strategy_id": "bollinger_v1", + "bb_length": 100, + "bb_std": 2.0 + } + ) + + # Mock the backtesting engine + mock_engine = AsyncMock() + mock_engine.run_backtesting.return_value = { + "executors": [], + "processed_data": {"features": {}}, + "results": { + "total_trades": 10, + "win_rate": 0.6, + "total_pnl": Decimal("100.0"), + "sharpe_ratio": 1.5, + "max_drawdown": 0.1, + "profit_loss_ratio": 2.0, + "start_timestamp": 1000, + "end_timestamp": 2000 + } + } + backtesting_service.backtesting_engines["directional_trading"] = mock_engine + + result = await backtesting_service.run_backtesting(config) + + assert isinstance(result, BacktestResponse) + assert result.results.total_trades == 10 + assert result.results.win_rate == 0.6 + assert result.results.total_pnl == Decimal("100.0") + +@pytest.mark.asyncio +async def test_run_backtesting_engine_error(backtesting_service, mock_strategy_registry): + """Test backtesting run with engine error""" + config = 
BacktestingConfig( + start_time=1000, + end_time=2000, + backtesting_resolution="1m", + trade_cost=0.001, + config={ + "strategy_id": "bollinger_v1", + "bb_length": 100, + "bb_std": 2.0 + } + ) + + # Mock the backtesting engine to raise an error + mock_engine = AsyncMock() + mock_engine.run_backtesting.side_effect = Exception("Engine error") + backtesting_service.backtesting_engines["directional_trading"] = mock_engine + + with pytest.raises(BacktestError, match="Error during backtesting execution"): + await backtesting_service.run_backtesting(config) + +def test_get_available_engines(backtesting_service): + """Test getting available backtesting engines""" + engines = backtesting_service.get_available_engines() + + assert isinstance(engines, dict) + assert "directional_trading" in engines + assert "market_making" in engines \ No newline at end of file diff --git a/tests/unit/test_libert_ai_client.py b/tests/unit/test_libert_ai_client.py new file mode 100644 index 0000000..0ffa1a6 --- /dev/null +++ b/tests/unit/test_libert_ai_client.py @@ -0,0 +1,112 @@ +import pytest +from unittest.mock import AsyncMock, patch +from services.libert_ai_service import LibertAIClient +from typing import Dict, Any + +pytestmark = pytest.mark.asyncio + +@pytest.fixture +def mock_api_response(): + return {"content": "Test response"} + +@pytest.fixture +def client(): + return LibertAIClient(api_url="http://test.api") + +async def test_build_request_payload(client): + """Test payload construction with different parameters""" + # Test basic payload + payload = client._build_request_payload("test prompt") + assert payload["prompt"] == "test prompt" + assert payload["temperature"] == 0.9 + assert payload["top_p"] == 1 + assert payload["top_k"] == 40 + assert payload["n"] == 1 + assert payload["n_predict"] == 100 + assert payload["stop"] == ["<|im_end|>"] + assert "slot_id" not in payload + assert "parent_slot_id" not in payload + + # Test with slot_id + payload = 
client._build_request_payload("test prompt", slot_id=1) + assert payload["slot_id"] == 1 + assert "parent_slot_id" not in payload + + # Test with both slot_id and parent_slot_id + payload = client._build_request_payload("test prompt", slot_id=1, parent_slot_id=0) + assert payload["slot_id"] == 1 + assert payload["parent_slot_id"] == 0 + +async def test_initialize_system_context(client, mock_api_response): + """Test system context initialization""" + with patch.object(client, '_make_api_request', new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_api_response + + response = await client.initialize_system_context("test prompt") + + assert response == mock_api_response + mock_request.assert_called_once() + call_args = mock_request.call_args[0] + assert call_args[0] == "http://test.api" # URL + assert call_args[1]["prompt"] == "test prompt" # Payload + assert "slot_id" not in call_args[1] + assert "parent_slot_id" not in call_args[1] + +async def test_initialize_strategy_context(client, mock_api_response): + """Test strategy context initialization""" + with patch.object(client, '_make_api_request', new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_api_response + + response = await client.initialize_strategy_context("test prompt", slot_id=1) + + assert response == mock_api_response + mock_request.assert_called_once() + call_args = mock_request.call_args[0] + assert call_args[0] == "http://test.api" # URL + assert call_args[1]["prompt"] == "test prompt" # Payload + assert call_args[1]["slot_id"] == 1 + assert call_args[1]["parent_slot_id"] == -1 + +async def test_get_suggestions(client, mock_api_response): + """Test getting suggestions""" + with patch.object(client, '_make_api_request', new_callable=AsyncMock) as mock_request: + mock_request.return_value = mock_api_response + + response = await client.get_suggestions("test prompt", slot_id=1) + + assert response == mock_api_response + mock_request.assert_called_once() + 
call_args = mock_request.call_args[0] + assert call_args[0] == "http://test.api" # URL + assert call_args[1]["prompt"] == "test prompt" # Payload + assert call_args[1]["slot_id"] == 1 + assert "parent_slot_id" not in call_args[1] + +async def test_session_management(client): + """Test session creation and cleanup""" + # Test session creation + session = await client._ensure_session() + assert session is not None + assert client.session is session + + # Test session reuse + session2 = await client._ensure_session() + assert session2 is session + + # Test session cleanup + await client.close() + assert client.session is None + +async def test_api_request_error_handling(client): + """Test error handling in API requests""" + with patch.object(client, '_make_api_request', new_callable=AsyncMock) as mock_request: + mock_request.side_effect = ValueError("API request failed with status 404") + + with pytest.raises(ValueError, match="API request failed with status 404"): + await client.get_suggestions("test prompt", slot_id=1) + + mock_request.assert_called_once() + call_args = mock_request.call_args[0] + assert call_args[0] == "http://test.api" # URL + assert call_args[1]["prompt"] == "test prompt" # Payload + assert call_args[1]["slot_id"] == 1 \ No newline at end of file diff --git a/tests/unit/test_libert_ai_service.py b/tests/unit/test_libert_ai_service.py new file mode 100644 index 0000000..e778154 --- /dev/null +++ b/tests/unit/test_libert_ai_service.py @@ -0,0 +1,192 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from services.libert_ai_service import LibertAIService, LibertAIClient +from routers.strategies_models import ( + StrategyMapping, + StrategyType, + StrategyParameter, + ParameterConstraints, + ParameterGroup, + DisplayType +) + +pytestmark = pytest.mark.asyncio + +@pytest.fixture +def mock_ai_client(): + return AsyncMock(spec=LibertAIClient) + +@pytest.fixture +def service(mock_ai_client): + return LibertAIService(mock_ai_client) 
def _strategy_mapping(module_path):
    """Build a minimal directional-trading StrategyMapping for code-loading tests."""
    return StrategyMapping(
        id="test_strategy",
        display_name="Test Strategy",
        module_path=module_path,
        strategy_type=StrategyType.DIRECTIONAL_TRADING,
        config_class="TestConfig",
    )


async def test_initialize_system_context(service, mock_ai_client):
    """System-context setup is delegated to the AI client exactly once."""
    await service._initialize_system_context()
    mock_ai_client.initialize_system_context.assert_called_once()


async def test_initialize_system_context_error(service, mock_ai_client):
    """A client failure during system-context setup propagates unchanged."""
    mock_ai_client.initialize_system_context.side_effect = ValueError("API Error")

    with pytest.raises(ValueError, match="API Error"):
        await service._initialize_system_context()


async def test_load_strategy_code_success(service):
    """The strategy class's source is returned when its module resolves cleanly."""
    mapping = _strategy_mapping("tests.unit.fixtures.test_strategy")
    expected_source = "class TestStrategy:\n    pass"

    fake_module = MagicMock(spec=['__name__'])
    fake_module.__name__ = mapping.module_path
    fake_class = MagicMock(spec=['__module__'])
    fake_class.__module__ = mapping.module_path

    with patch('importlib.import_module', return_value=fake_module), \
         patch('inspect.getmembers', return_value=[('TestStrategy', fake_class)]), \
         patch('inspect.getsource', return_value=expected_source):
        assert await service._load_strategy_code(mapping) == expected_source


async def test_load_strategy_code_import_error(service):
    """An unimportable module yields the not-found placeholder instead of raising."""
    mapping = _strategy_mapping("nonexistent.module")

    with patch('importlib.import_module', side_effect=ImportError("Module not found")):
        loaded = await service._load_strategy_code(mapping)

    assert "Strategy implementation code not found" in loaded


async def test_load_strategy_code_no_class(service):
    """A module exposing no strategy class yields the not-found placeholder."""
    mapping = _strategy_mapping("tests.unit.fixtures.test_strategy")

    fake_module = MagicMock(spec=['__name__'])
    fake_module.__name__ = mapping.module_path

    with patch('importlib.import_module', return_value=fake_module), \
         patch('inspect.getmembers', return_value=[]):
        loaded = await service._load_strategy_code(mapping)

    assert "Strategy implementation code not found" in loaded


async def test_initialize_strategy_context_error(service, mock_ai_client):
    """A client failure while seeding a strategy slot propagates unchanged."""
    parameter = StrategyParameter(
        name="param1",
        type="int",
        default=1,
        required=True,
        display_name="Parameter 1",
        description="Test parameter",
        group=ParameterGroup.GENERAL,
        is_advanced=False,
        constraints=ParameterConstraints(),
        display_type=DisplayType.INPUT,
    )
    strategy = StrategyMapping(
        id="test_strategy",
        display_name="Test Strategy",
        module_path="test.module",
        strategy_type=StrategyType.DIRECTIONAL_TRADING,
        config_class="TestConfig",
        description="Test strategy",
        parameters={"param1": parameter},
    )

    mock_ai_client.initialize_strategy_context.side_effect = ValueError("API Error")

    with pytest.raises(ValueError, match="API Error"):
        await service._initialize_strategy_context(
            strategy=strategy,
            strategy_code="class TestStrategy: pass",
            slot_id=1,
        )


async def test_initialize_contexts_error(service, mock_ai_client):
    """A system-context failure aborts bulk context initialization."""
    strategies = {
        "test_strategy": MagicMock(
            mapping=StrategyMapping(
                id="test_strategy",
                display_name="Test Strategy",
                module_path="test.module",
                strategy_type=StrategyType.DIRECTIONAL_TRADING,
                config_class="TestConfig",
            ),
            parameters={},
        )
    }

    mock_ai_client.initialize_system_context.side_effect = ValueError("API Error")

    with pytest.raises(ValueError, match="API Error"):
        await service.initialize_contexts(strategies)


async def test_get_parameter_suggestions_no_strategy(service):
    """An unknown strategy id yields an empty suggestion list rather than an error."""
    suggestions = await service.get_parameter_suggestions(
        strategy_id="nonexistent",
        strategy_config={},
        provided_params={},
    )
    assert suggestions == []


async def test_get_parameter_suggestions_no_slot(service):
    """A known strategy with no allocated context slot yields no suggestions."""
    # NOTE(review): patching routers.strategies_models.discover_strategies assumes
    # the service resolves the function through that module at call time — confirm
    # it is not imported directly into the service's own namespace.
    with patch('routers.strategies_models.discover_strategies') as fake_discover:
        fake_discover.return_value = {"test_strategy": MagicMock()}

        suggestions = await service.get_parameter_suggestions(
            strategy_id="test_strategy",
            strategy_config={},
            provided_params={},
        )
    assert suggestions == []


async def test_get_parameter_suggestions_api_error(service, mock_ai_client):
    """AI-client failures are surfaced as an empty suggestion list, not an exception."""
    service.strategy_slot_map = {"test_strategy": 1}

    with patch('routers.strategies_models.discover_strategies') as fake_discover:
        fake_discover.return_value = {
            "test_strategy": MagicMock(
                mapping=MagicMock(
                    display_name="Test Strategy",
                    strategy_type=StrategyType.DIRECTIONAL_TRADING,
                )
            )
        }
        mock_ai_client.get_suggestions.side_effect = ValueError("API Error")

        suggestions = await service.get_parameter_suggestions(
            strategy_id="test_strategy",
            strategy_config={},
            provided_params={},
        )
    assert suggestions == []
# --- tests/unit/test_strategies.py ---
import pytest
from unittest.mock import patch, MagicMock

from routers.strategies_models import (
    StrategyType,
    StrategyParameter,
    ParameterConstraints,
    ParameterGroup,
    DisplayType,
    discover_strategies,
    StrategyMapping,
)


@pytest.fixture
def mock_strategy():
    """Create a mock strategy for testing.

    NOTE(review): no test in this module references this fixture yet — confirm it
    is needed, or move it to a shared conftest.py once it gains users.
    """
    return StrategyMapping(
        id="test_strategy",
        display_name="Test Strategy",
        description="Test description",
        strategy_type=StrategyType.DIRECTIONAL_TRADING,
        module_path="test.module",
        config_class="TestConfig",
        parameters={
            "test_param": StrategyParameter(
                name="test_param",
                type="int",
                required=True,
                default=1,
                display_name="Test Parameter",
                description="Test description",
                group=ParameterGroup.GENERAL,
                is_advanced=False,
                constraints=ParameterConstraints(),
                display_type=DisplayType.INPUT,
            )
        },
    )


def test_discover_strategies():
    """Strategy auto-discovery returns at least one fully-populated mapping.

    Fix: the original declared this test ``async`` and marked it with
    ``pytest.mark.asyncio`` even though ``discover_strategies()`` is called
    synchronously and nothing is awaited, so it now runs as a plain test
    (same assertions, no event-loop overhead or asyncio-plugin dependency).
    """
    strategies = discover_strategies()

    # Verify we found some strategies.
    assert len(strategies) > 0

    # Verify each strategy has the required fields.
    for strategy_id, strategy in strategies.items():
        assert strategy.id == strategy_id
        assert strategy.display_name
        assert strategy.description
        assert strategy.strategy_type in [
            StrategyType.DIRECTIONAL_TRADING,
            StrategyType.MARKET_MAKING,
            StrategyType.GENERIC,
        ]
        assert strategy.module_path
        assert strategy.config_class
        assert strategy.parameters

        # Verify each parameter has the required fields.
        for param_name, param in strategy.parameters.items():
            assert param.name == param_name
            assert param.type
            assert isinstance(param.required, bool)
            assert param.display_name
            assert param.description is not None  # Can be empty but should exist
            assert param.group in ParameterGroup
            assert isinstance(param.is_advanced, bool)
            assert param.display_type in DisplayType