diff --git a/.gitignore b/.gitignore index 8e8ab3acd..9eec5714a 100644 --- a/.gitignore +++ b/.gitignore @@ -170,5 +170,4 @@ hitlog.*.jsonl garak_runs/ runs/ logs/ -.DS_Store - +.DS_Store \ No newline at end of file diff --git a/docs/source/garak.generators.websocket.rst b/docs/source/garak.generators.websocket.rst new file mode 100644 index 000000000..8540f4e30 --- /dev/null +++ b/docs/source/garak.generators.websocket.rst @@ -0,0 +1,225 @@ +garak.generators.websocket +========================== + +WebSocket connector for real-time LLM services. + +This generator enables garak to test WebSocket-based LLM services that use +real-time bidirectional communication, similar to modern chat applications. + +Uses the following options from ``_config.plugins.generators["websocket"]["WebSocketGenerator"]``: + +* ``uri`` - the WebSocket URI (ws:// or wss://); can also be passed in --model_name +* ``name`` - a short name for this service; defaults to "WebSocket LLM" +* ``auth_type`` - authentication method: "none", "basic", "bearer", or "custom" +* ``username`` - username for basic authentication +* ``api_key`` - API key for bearer token auth or password for basic auth +* ``key_env_var`` - environment variable holding API key; default ``WEBSOCKET_API_KEY`` +* ``req_template`` - string template where ``$INPUT`` is replaced by prompt, ``$KEY`` by API key, ``$CONVERSATION_ID`` by conversation ID +* ``req_template_json_object`` - request template as Python object, serialized to JSON with placeholder replacements +* ``headers`` - dict of additional WebSocket headers +* ``response_json`` - is the response in JSON format? (bool) +* ``response_json_field`` - which field contains the response text? Supports JSONPath (prefix with ``$``) +* ``response_after_typing`` - wait for typing indicators to complete? (bool) +* ``typing_indicator`` - string that indicates typing status; default "typing" +* ``request_timeout`` - seconds to wait for response; default 20 +* ``connection_timeout`` - seconds to wait for connection; default 10 +* ``max_response_length`` - maximum response length; default 10000 +* ``verify_ssl`` - enforce SSL certificate validation? Default ``True`` + +Templates work similarly to the REST generator. The ``$INPUT``, ``$KEY``, and +``$CONVERSATION_ID`` placeholders are replaced in both string templates and +JSON object templates. + +JSON Response Extraction +------------------------ + +The ``response_json_field`` parameter supports JSONPath-style extraction: + +* Simple field: ``"text"`` extracts ``response.text`` +* Nested field: ``"$.data.message"`` extracts ``response.data.message`` +* Array access: ``"$.messages[0].content"`` extracts first message content + +Authentication Methods +---------------------- + +**No Authentication:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "ws://localhost:3000/chat", + "auth_type": "none" + } + } + } + +**Basic Authentication:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "ws://localhost:3000/chat", + "auth_type": "basic", + "username": "user" + } + } + } + +Set the password via environment variable: + +.. code-block:: bash + + export WEBSOCKET_API_KEY="your_secure_password" + +**Bearer Token:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "wss://api.example.com/llm", + "auth_type": "bearer", + "api_key": "your_api_key_here" + } + } + } + +**Environment Variable API Key:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "wss://api.example.com/llm", + "auth_type": "bearer", + "key_env_var": "MY_LLM_API_KEY" + } + } + } + +Message Templates +----------------- + +**Simple Text Template:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "ws://localhost:3000/chat", + "req_template": "User: $INPUT" + } + } + } + +**JSON Object Template:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "ws://localhost:3000/chat", + "req_template_json_object": { + "message": "$INPUT", + "conversation_id": "$CONVERSATION_ID", + "api_key": "$KEY" + }, + "response_json": true, + "response_json_field": "text" + } + } + } + +**Complex JSON with Nested Response:** + +.. code-block:: JSON + + { + "websocket": { + "WebSocketGenerator": { + "uri": "wss://api.example.com/llm", + "req_template_json_object": { + "prompt": "$INPUT", + "stream": false, + "model": "gpt-4" + }, + "response_json": true, + "response_json_field": "$.choices[0].message.content" + } + } + } + +Usage Examples +--------------- + +**Command Line with JSON Options:** + +.. code-block:: bash + + # Set password securely via environment variable + export WEBSOCKET_API_KEY="your_secure_password" + + garak --model_type websocket.WebSocketGenerator \ + --generator_options '{"websocket": {"WebSocketGenerator": {"uri": "ws://localhost:3000", "auth_type": "basic", "username": "user"}}}' \ + --probes dan + +**Configuration File:** + +Save configuration to ``websocket_config.json`` and use: + +.. code-block:: bash + + garak --model_type websocket.WebSocketGenerator \ + -G websocket_config.json \ + --probes encoding + +**Testing with Public Echo Server:** + +.. code-block:: bash + + garak --model_type websocket.WebSocketGenerator \ + --generator_options '{"websocket": {"WebSocketGenerator": {"uri": "wss://echo.websocket.org", "response_after_typing": false}}}' \ + --probes dan --generations 1 + +SSH Tunnel Support +------------------ + +The generator works seamlessly with SSH tunnels for secure remote testing: + +.. code-block:: bash + + # Establish tunnel + ssh -L 3000:target-host:3000 jump-host -N -f + + # Test through tunnel + garak --model_type websocket.WebSocketGenerator \ + --generator_options '{"websocket": {"WebSocketGenerator": {"uri": "ws://localhost:3000"}}}' \ + --probes malwaregen + +Typing Indicators +----------------- + +Many chat-based LLMs send typing indicators. Configure response handling: + +* ``response_after_typing: true`` - Wait for typing to complete (default) +* ``response_after_typing: false`` - Return first substantial response +* ``typing_indicator`` - String to detect typing status (default "typing") + +This enables proper testing of streaming/real-time LLM services. + +---- + +.. automodule:: garak.generators.websocket + :members: + :undoc-members: + :show-inheritance: + + diff --git a/docs/source/generators.rst b/docs/source/generators.rst index b4936bbb0..46aa5fbc9 100644 --- a/docs/source/generators.rst +++ b/docs/source/generators.rst @@ -31,3 +31,4 @@ For a detailed oversight into how a generator operates, see :doc:`garak.generato garak.generators.rasa garak.generators.test garak.generators.watsonx + garak.generators.websocket diff --git a/garak/generators/websocket.py b/garak/generators/websocket.py new file mode 100644 index 000000000..25ebcc553 --- /dev/null +++ b/garak/generators/websocket.py @@ -0,0 +1,379 @@ +"""WebSocket generator for real-time LLM communication + +This module provides WebSocket-based connectivity for testing LLM services +that use real-time bidirectional communication protocols. +""" + +import asyncio +import json +import time +import base64 +import os +import logging +from typing import List, Union, Dict, Any, Optional +from urllib.parse import urlparse +import websockets +from websockets.exceptions import ConnectionClosed, WebSocketException + +from garak import _config +from garak.attempt import Message, Conversation +from garak.generators.base import Generator + +logger = logging.getLogger(__name__) + + +class WebSocketGenerator(Generator): + """Generator for WebSocket-based LLM services + + This generator connects to LLM services that communicate via WebSocket protocol, + handling authentication, template-based messaging, and JSON response extraction. + + Configuration parameters: + - uri: WebSocket URL (ws:// or wss://) + - name: Display name for the service + - auth_type: Authentication method (none, basic, bearer, custom) + - username: Basic authentication username + - api_key: API key for bearer token auth or password for basic auth + - req_template: String template with $INPUT and $KEY placeholders + - req_template_json_object: JSON object template for structured messages + - headers: Additional WebSocket headers + - response_json: Whether responses are JSON formatted + - response_json_field: Field to extract from JSON responses (supports JSONPath) + - response_after_typing: Wait for typing indicator completion + - typing_indicator: String that indicates typing status + - request_timeout: Seconds to wait for response + - connection_timeout: Seconds to wait for connection + - max_response_length: Maximum response length + - verify_ssl: SSL certificate verification + """ + + DEFAULT_PARAMS = { + "uri": "wss://echo.websocket.org", + "name": "WebSocket LLM", + "auth_type": "none", # none, basic, bearer, custom + "username": None, + "api_key": None, + "conversation_id": None, + "req_template": "$INPUT", + "req_template_json_object": None, + "headers": {}, + "response_json": False, + "response_json_field": "text", + "response_after_typing": True, + "typing_indicator": "typing", + "request_timeout": 20, + "connection_timeout": 10, + "max_response_length": 10000, + "verify_ssl": True, + } + + ENV_VAR = "WEBSOCKET_API_KEY" + + def __init__(self, uri=None, config_root=_config): + # Set uri if explicitly provided (overrides default) + if uri: + self.uri = uri + + + # Let Configurable class handle all the DEFAULT_PARAMS magic + super().__init__(self.name, config_root) + + # Now validate that required values are formatted correctly + if not self.uri: + raise ValueError("WebSocket uri is required") + + parsed = urlparse(self.uri) + if parsed.scheme not in ['ws', 'wss']: + raise ValueError("URI must use ws:// or wss:// scheme") + + # Parse URI attributes + self.secure = parsed.scheme == 'wss' + self.host = parsed.hostname + self.port = parsed.port or (443 if self.secure else 80) + self.path = parsed.path or "/" + + # Set up authentication + self._setup_auth() + + # Current WebSocket connection + self.websocket = None + + logger.info(f"WebSocket generator initialized for {self.uri}") + + def _validate_env_var(self): + """Only validate API key if it's actually needed in templates or auth""" + if self.auth_type != "none": + return super()._validate_env_var() + + # Check if templates require API key + key_required = False + if "$KEY" in str(self.req_template): + key_required = True + if self.req_template_json_object and "$KEY" in str(self.req_template_json_object): + key_required = True + if self.headers and any("$KEY" in str(v) for v in self.headers.values()): + key_required = True + + if key_required: + return super()._validate_env_var() + + # No API key validation needed + return + + def _setup_auth(self): + """Set up authentication headers and credentials""" + self.auth_header = None + + # Set up authentication headers + if self.auth_type == "basic" and self.username and self.api_key: + credentials = base64.b64encode(f"{self.username}:{self.api_key}".encode()).decode() + self.auth_header = f"Basic {credentials}" + elif self.auth_type == "bearer" and self.api_key: + self.auth_header = f"Bearer {self.api_key}" + + # Add auth header to headers dict + if self.auth_header: + self.headers = self.headers or {} + self.headers["Authorization"] = self.auth_header + + def _format_message(self, prompt: str) -> str: + """Format message using template system similar to REST generator""" + # Prepare replacements + replacements = { + "$INPUT": prompt, + "$KEY": self.api_key or "", + "$CONVERSATION_ID": self.conversation_id or "" + } + + # Use JSON object template if provided + if self.req_template_json_object: + message_obj = self._apply_replacements(self.req_template_json_object, replacements) + return json.dumps(message_obj) + + # Use string template + message = self.req_template + for placeholder, value in replacements.items(): + message = message.replace(placeholder, value) + + return message + + def _apply_replacements(self, obj: Any, replacements: Dict[str, str]) -> Any: + """Recursively apply replacements to a data structure""" + if isinstance(obj, str): + for placeholder, value in replacements.items(): + obj = obj.replace(placeholder, value) + return obj + elif isinstance(obj, dict): + return {k: self._apply_replacements(v, replacements) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._apply_replacements(item, replacements) for item in obj] + else: + return obj + + def _extract_response_text(self, response: str) -> str: + """Extract text from response using JSON field extraction""" + if not self.response_json: + return response + + try: + response_data = json.loads(response) + + # Handle JSONPath-style field extraction + if self.response_json_field.startswith('$'): + # Simple JSONPath support for common cases + path = self.response_json_field[1:] # Remove $ + if path.startswith('.'): + path = path[1:] # Remove leading dot + if '.' in path: + # Navigate nested fields + current = response_data + for field in path.split('.'): + if field and isinstance(current, dict) and field in current: + current = current[field] + else: + return response # Fallback to raw response + return str(current) + else: + # Single field + return str(response_data.get(path, response)) + else: + # Direct field access + return str(response_data.get(self.response_json_field, response)) + + except (json.JSONDecodeError, KeyError, TypeError): + logger.warning(f"Failed to extract JSON field '{self.response_json_field}', returning raw response") + return response + + async def _connect_websocket(self): + """Establish WebSocket connection with proper error handling""" + try: + # Prepare connection arguments + connect_args = { + 'open_timeout': self.connection_timeout, + 'close_timeout': self.connection_timeout, + } + + # Add headers if provided + if self.headers: + connect_args['additional_headers'] = self.headers + + # SSL verification + if self.secure and not self.verify_ssl: + import ssl + connect_args['ssl'] = ssl.create_default_context() + connect_args['ssl'].check_hostname = False + connect_args['ssl'].verify_mode = ssl.CERT_NONE + + logger.debug(f"Connecting to WebSocket: {self.uri}") + self.websocket = await websockets.connect(self.uri, **connect_args) + logger.info(f"WebSocket connected to {self.uri}") + + except Exception as e: + logger.error(f"Failed to connect to WebSocket {self.uri}: {e}") + raise + + async def _send_and_receive(self, message: str) -> str: + """Send message and receive response with timeout and typing indicator handling""" + if not self.websocket: + await self._connect_websocket() + + try: + # Send message + await self.websocket.send(message) + logger.debug("WebSocket message sent") + + # Collect response parts + response_parts = [] + start_time = time.time() + typing_detected = False + + while time.time() - start_time < self.request_timeout: + try: + # Wait for message with timeout + remaining_time = self.request_timeout - (time.time() - start_time) + if remaining_time <= 0: + break + + response = await asyncio.wait_for( + self.websocket.recv(), + timeout=min(2.0, remaining_time) + ) + + logger.debug("WebSocket message received") + + # Handle typing indicators + if self.response_after_typing and self.typing_indicator in response: + typing_detected = True + continue + + # If we were waiting for typing to finish and got a non-typing message + if typing_detected and self.typing_indicator not in response: + response_parts.append(response) + break + + # Collect response parts + response_parts.append(response) + + # If not using typing indicators, assume first response is complete + if not self.response_after_typing: + break + + # Check if we have enough content + total_length = sum(len(part) for part in response_parts) + if total_length > self.max_response_length: + logger.debug("Max response length reached") + break + + except asyncio.TimeoutError: + logger.debug("WebSocket receive timeout") + # If we have some response, break; otherwise continue waiting + if response_parts: + break + continue + except ConnectionClosed: + logger.warning("WebSocket connection closed during receive") + break + + # Combine response parts + full_response = ''.join(response_parts) + logger.debug(f"WebSocket response received ({len(full_response)} chars)") + + return full_response + + except Exception as e: + logger.error(f"Error in WebSocket communication: {e}") + # Try to reconnect for next request + if self.websocket: + await self.websocket.close() + self.websocket = None + raise + + async def _generate_async(self, prompt: str) -> str: + """Async wrapper for generation""" + formatted_message = self._format_message(prompt) + raw_response = await self._send_and_receive(formatted_message) + return self._extract_response_text(raw_response) + + def _has_system_prompt(self, prompt: Conversation) -> bool: + """Check if conversation contains system prompts""" + if hasattr(prompt, 'turns') and prompt.turns: + for turn in prompt.turns: + if hasattr(turn, 'role') and turn.role == 'system': + return True + return False + + def _has_conversation_history(self, prompt: Conversation) -> bool: + """Check if conversation has multiple turns (history)""" + if hasattr(prompt, 'turns') and len(prompt.turns) > 1: + return True + return False + + def _call_model(self, prompt: Conversation, generations_this_call: int = 1, **kwargs) -> List[Union[Message, None]]: + """Call the WebSocket LLM model with smart limitation detection""" + try: + # Check for unsupported features and skip gracefully + if self._has_system_prompt(prompt): + logger.warning("WebSocket generator doesn't support system prompts yet - skipping test") + return [None] * min(generations_this_call, 1) + + if self._has_conversation_history(prompt): + logger.warning("WebSocket generator doesn't support conversation history yet - skipping test") + return [None] * min(generations_this_call, 1) + + # Extract text from simple, single-turn conversation + if hasattr(prompt, 'turns') and prompt.turns: + prompt_text = prompt.turns[-1].text + else: + # Fallback for simple string prompts + prompt_text = str(prompt) + + # Run async generation in event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response_text = loop.run_until_complete(self._generate_async(prompt_text)) + # Create Message objects for garak + if response_text: + message = Message(text=response_text) + return [message] * min(generations_this_call, 1) + else: + message = Message(text="") + return [message] * min(generations_this_call, 1) + finally: + loop.close() + + except Exception as e: + logger.error(f"WebSocket generation failed: {e}") + message = Message(text="") + return [message] * min(generations_this_call, 1) + + def __del__(self): + """Clean up WebSocket connection""" + if hasattr(self, 'websocket') and self.websocket: + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.websocket.close()) + loop.close() + except: + pass # Ignore cleanup errors \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 44a37a329..0b9d2e50d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ dependencies = [ "mistralai==1.5.2", "pillow>=10.4.0", "ftfy>=6.3.1", + "websockets>=13.0", ] [project.optional-dependencies] @@ -121,6 +122,7 @@ tests = [ "respx>=0.21.1", "pytest-cov>=5.0.0", "pytest_httpserver>=1.1.0", + "pytest-asyncio>=0.21.0", "langcodes>=3.4.0", ] lint = [ diff --git a/requirements.txt b/requirements.txt index ee91e707c..ae0086963 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,6 +42,7 @@ tiktoken>=0.7.0 mistralai==1.5.2 pillow>=10.4.0 ftfy>=6.3.1 +websockets>=13.0 # tests pytest>=8.0 pytest-mock>=3.14.0 @@ -49,6 +50,7 @@ requests-mock==1.12.1 respx>=0.21.1 pytest-cov>=5.0.0 pytest_httpserver>=1.1.0 +pytest-asyncio>=0.21.0 langcodes>=3.4.0 # lint black==24.4.2 diff --git a/tests/generators/test_generators.py b/tests/generators/test_generators.py index 0450a2033..e26e3cfac 100644 --- a/tests/generators/test_generators.py +++ b/tests/generators/test_generators.py @@ -104,13 +104,16 @@ def test_instantiate_generators(classname): category, namespace, klass = classname.split(".") from garak._config import GarakSubConfig + # Use WebSocket URI for WebSocket generators, HTTP URI for others + uri = "wss://echo.websocket.org" if "websocket" in classname.lower() else "https://example.com" + gen_config = { namespace: { klass: { "name": "gpt-3.5-turbo-instruct", # valid for OpenAI "api_key": "fake", "org_id": "fake", # required for NeMo - "uri": "https://example.com", # required for rest + "uri": uri, # WebSocket URI for WebSocket generators "provider": "fake", # required for LiteLLM } } diff --git a/tests/generators/test_websocket.py b/tests/generators/test_websocket.py new file mode 100644 index 000000000..cd15f6561 --- /dev/null +++ b/tests/generators/test_websocket.py @@ -0,0 +1,284 @@ +"""Tests for WebSocket generator""" + +import pytest +import json +import uuid +from unittest.mock import Mock, patch, AsyncMock +import asyncio + +from garak.generators.websocket import WebSocketGenerator + + +class TestWebSocketGenerator: + """Test suite for WebSocketGenerator""" + + def test_init_basic(self): + """Test basic initialization""" + gen = WebSocketGenerator(uri="ws://localhost:3000") + assert gen.uri == "ws://localhost:3000" + assert gen.host == "localhost" + assert gen.port == 3000 + assert gen.path == "/" + assert not gen.secure + + def test_init_secure(self): + """Test secure WebSocket initialization""" + gen = WebSocketGenerator(uri="wss://api.example.com:443/chat") + assert gen.secure + assert gen.host == "api.example.com" + assert gen.port == 443 + assert gen.path == "/chat" + + def test_init_invalid_scheme(self): + """Test initialization with invalid scheme""" + with pytest.raises(ValueError, match="URI must use ws:// or wss:// scheme"): + WebSocketGenerator(uri="ftp://localhost:3000") + + def test_init_no_uri(self): + """Test initialization without URI uses default""" + gen = WebSocketGenerator() + assert gen.uri == "wss://echo.websocket.org" + assert gen.secure + assert gen.host == "echo.websocket.org" + + def test_auth_basic(self): + """Test basic authentication setup""" + instance_config = { + "generators": { + "websocket": { + "WebSocketGenerator": { + "uri": "ws://localhost:3000", + "auth_type": "basic", + "username": "testuser", + "api_key": "testpass", + } + } + } + } + gen = WebSocketGenerator(config_root=instance_config) + assert "Authorization" in gen.headers + assert gen.headers["Authorization"].startswith("Basic ") + + def test_auth_bearer(self): + """Test bearer token authentication""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + auth_type="bearer", + api_key="test_api_key" + ) + assert gen.headers["Authorization"] == "Bearer test_api_key" + + @patch.dict('os.environ', {'TEST_API_KEY': 'env_api_key'}) + def test_auth_env_var(self): + """Test API key from environment variable""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + auth_type="bearer", + key_env_var="TEST_API_KEY" + ) + assert gen.api_key == "env_api_key" + assert gen.headers["Authorization"] == "Bearer env_api_key" + + def test_format_message_simple(self): + """Test simple message formatting""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + req_template="User: $INPUT" + ) + result = gen._format_message("Hello world") + assert result == "User: Hello world" + + def test_format_message_json_object(self): + """Test JSON object message formatting""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + req_template_json_object={ + "message": "$INPUT", + "conversation_id": "$CONVERSATION_ID", + "api_key": "$KEY" + }, + conversation_id="test_conv", + api_key="test_key" + ) + result = gen._format_message("Hello") + data = json.loads(result) + assert data["message"] == "Hello" + assert data["conversation_id"] == "test_conv" + assert data["api_key"] == "test_key" + + def test_extract_response_text_plain(self): + """Test plain text response extraction""" + gen = WebSocketGenerator(uri="ws://localhost:3000", response_json=False) + result = gen._extract_response_text("Hello world") + assert result == "Hello world" + + def test_extract_response_text_json(self): + """Test JSON response extraction""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + response_json=True, + response_json_field="text" + ) + response = json.dumps({"text": "Hello world", "status": "ok"}) + result = gen._extract_response_text(response) + assert result == "Hello world" + + def test_extract_response_text_jsonpath(self): + """Test JSONPath response extraction""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + response_json=True, + response_json_field="$.data.message" + ) + response = json.dumps({ + "status": "success", + "data": {"message": "Hello world", "timestamp": "2023-01-01"} + }) + result = gen._extract_response_text(response) + assert result == "Hello world" + + def test_extract_response_text_json_fallback(self): + """Test JSON extraction fallback to raw response""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + response_json=True, + response_json_field="nonexistent" + ) + response = "Invalid JSON" + result = gen._extract_response_text(response) + assert result == "Invalid JSON" + + @pytest.mark.asyncio + async def test_connect_websocket_success(self): + """Test successful WebSocket connection""" + gen = WebSocketGenerator(uri="ws://localhost:3000") + + mock_websocket = AsyncMock() + # Mock websockets.connect to return the mock_websocket as an awaitable + with patch('garak.generators.websocket.websockets.connect') as mock_connect: + # Create an async mock that returns the mock_websocket when awaited + async def async_connect(*args, **kwargs): + return mock_websocket + mock_connect.side_effect = async_connect + + await gen._connect_websocket() + mock_connect.assert_called_once() + assert gen.websocket == mock_websocket + + @pytest.mark.asyncio + async def test_send_and_receive_basic(self): + """Test basic send and receive""" + gen = WebSocketGenerator(uri="ws://localhost:3000", response_after_typing=False) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(return_value="Hello response") + gen.websocket = mock_websocket + + result = await gen._send_and_receive("Hello") + + mock_websocket.send.assert_called_once_with("Hello") + assert result == "Hello response" + + @pytest.mark.asyncio + async def test_send_and_receive_typing(self): + """Test send and receive with typing indicator""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + response_after_typing=True, + typing_indicator="typing" + ) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + # Simulate typing indicator followed by actual response + mock_websocket.recv = AsyncMock(side_effect=["typing", "Hello response"]) + gen.websocket = mock_websocket + + result = await gen._send_and_receive("Hello") + + assert result == "Hello response" + assert mock_websocket.recv.call_count == 2 + + def test_call_model_integration(self): + """Test full model call integration""" + gen = WebSocketGenerator( + uri="ws://localhost:3000", + req_template="User: $INPUT", + response_json=False + ) + + # Mock the async generation method + async def mock_generate(prompt): + return f"Response to: {prompt}" + + with patch.object(gen, '_generate_async', side_effect=mock_generate): + result = gen._call_model("Test prompt") + assert len(result) == 1 + assert result[0].text == "Response to: Test prompt" + + def test_call_model_error_handling(self): + """Test error handling in model call""" + gen = WebSocketGenerator(uri="ws://localhost:3000") + + # Mock an exception in async generation + async def mock_generate_error(prompt): + raise Exception("Connection failed") + + with patch.object(gen, '_generate_async', side_effect=mock_generate_error): + result = gen._call_model("Test prompt") + assert len(result) == 1 + assert result[0].text == "" # Should return empty string on error + + def test_apply_replacements_nested(self): + """Test recursive replacement in nested data structures""" + gen = WebSocketGenerator(uri="ws://localhost:3000") + + # Use dynamic values to avoid hardcoding test expectations + static_value = "static_value" + input_value = f"test_input_{uuid.uuid4().hex[:8]}" + key_value = f"test_key_{uuid.uuid4().hex[:8]}" + conversation_value = f"test_conv_{uuid.uuid4().hex[:8]}" + + data = { + "message": "$INPUT", + "metadata": { + "user": "$KEY", + "conversation": "$CONVERSATION_ID" + }, + "options": ["$INPUT", static_value] + } + + replacements = { + "$INPUT": input_value, + "$KEY": key_value, + "$CONVERSATION_ID": conversation_value + } + + result = gen._apply_replacements(data, replacements) + + assert result["message"] == input_value + assert result["metadata"]["user"] == key_value + assert result["metadata"]["conversation"] == conversation_value + assert result["options"][0] == input_value + assert result["options"][1] == static_value + + def test_default_params_coverage(self): + """Test that all default parameters are properly set""" + gen = WebSocketGenerator(uri="ws://localhost:3000") + + # Check that all DEFAULT_PARAMS keys are set as attributes + for key in WebSocketGenerator.DEFAULT_PARAMS: + assert hasattr(gen, key), f"Missing attribute: {key}" + + # Check specific defaults + assert gen.name == "WebSocket LLM" + assert gen.auth_type == "none" + assert gen.req_template == "$INPUT" + assert gen.response_json is False + assert gen.response_json_field == "text" + assert gen.request_timeout == 20 + assert gen.connection_timeout == 10 + assert gen.verify_ssl is True + +