diff --git a/src/strands_tools/generate_image.py b/src/strands_tools/generate_image.py index ea8c2a16..0d014386 100644 --- a/src/strands_tools/generate_image.py +++ b/src/strands_tools/generate_image.py @@ -1,18 +1,18 @@ """ -Image generation tool for Strands Agent using Stable Diffusion. +Image generation tool for Strands Agent using Amazon Bedrock. This module provides functionality to generate high-quality images using Amazon Bedrock's -Stable Diffusion models based on text prompts. It handles the entire image generation +text-to-image models based on text prompts. It handles the entire image generation process including API integration, parameter management, response processing, and local storage of results. Key Features: 1. Image Generation: - • Text-to-image conversion using Stable Diffusion - • Support for multiple model variants (primarily stable-diffusion-xl-v1) - • Customizable generation parameters (seed, steps, cfg_scale) - • Style preset selection for consistent aesthetics + • Text-to-image conversion using various Bedrock models + • Support for multiple model families (Amazon Titan, Stability AI) + • Dynamic model selection when no model is specified + • Customizable generation parameters (seed, cfg_scale) 2. Output Management: • Automatic local saving with intelligent filename generation @@ -39,11 +39,10 @@ # Advanced usage with custom parameters agent.tool.generate_image( prompt="A futuristic city with flying cars", - model_id="stability.stable-diffusion-xl-v1", + model_id="amazon.titan-image-generator-v2:0", + region="us-west-2", seed=42, - steps=50, - cfg_scale=12, - style_preset="cinematic" + cfg_scale=12 ) ``` @@ -55,14 +54,15 @@ import os import random import re -from typing import Any +from typing import Any, List, Tuple import boto3 +from botocore.exceptions import ClientError from strands.types.tools import ToolResult, ToolUse TOOL_SPEC = { "name": "generate_image", - "description": "Generates an image using Stable Diffusion based on a given prompt", + "description": "Generates an image using Amazon Bedrock models based on a given prompt", "inputSchema": { "json": { "type": "object", @@ -73,24 +73,21 @@ }, "model_id": { "type": "string", - "description": "Model id for image model, stability.stable-diffusion-xl-v1.", + "description": "Model ID for image generation (e.g., amazon.titan-image-generator-v2:0, " + "stability.stable-image-core-v1:1)", }, - "seed": { - "type": "integer", - "description": "Optional: Seed for random number generation (default: random)", + "region": { + "type": "string", + "description": "Optional: AWS region to use (default: from AWS_REGION env variable or us-west-2)", }, - "steps": { + "seed": { "type": "integer", - "description": "Optional: Number of steps for image generation (default: 30)", + "description": "Optional: Seed for deterministic generation (default: random)", }, "cfg_scale": { "type": "number", "description": "Optional: CFG scale for image generation (default: 10)", }, - "style_preset": { - "type": "string", - "description": "Optional: Style preset for image generation (default: 'photographic')", - }, }, "required": ["prompt"], } @@ -98,50 +95,147 @@ } +def validate_model_in_region(model_id: str, region: str) -> Tuple[bool, List[str]]: + """ + Validate if the specified model is available in the given region using the Bedrock API. + + This function checks if the model supports text-to-image generation capability in the specified region + by examining the model's inputModalities and outputModalities. 
It also checks if the model + is in LEGACY status or doesn't support ON_DEMAND inference, in which case it raises an exception. + + Args: + model_id: The model ID to validate + region: The AWS region to check against + + Returns: + tuple: (is_valid, available_models) + - is_valid: True if the model is supported in the region, False otherwise + - available_models: List of available text-to-image models in the region + + Raises: + ValueError: If the model is in LEGACY status or doesn't support ON_DEMAND inference + """ + try: + # Create a Bedrock client to list available models + bedrock_client = boto3.client("bedrock", region_name=region) + + # Get list of foundation models available in the region + response = bedrock_client.list_foundation_models() + + # Filter for text-to-image models based on input and output modalities + available_models = [] + legacy_models = [] + non_on_demand_models = [] + + for model in response.get("modelSummaries", []): + model_id_from_api = model.get("modelId", "") + input_modalities = model.get("inputModalities", []) + output_modalities = model.get("outputModalities", []) + + # Check if this model supports text-to-image generation + # It should take TEXT as input and produce IMAGE as output + if "TEXT" in input_modalities and "IMAGE" in output_modalities: + # Check if the model is in LEGACY status + model_lifecycle = model.get("modelLifecycle", {}) + status = model_lifecycle.get("status", "") + + # Check if the model supports ON_DEMAND inference + inference_types = model.get("inferenceTypesSupported", []) + supports_on_demand = "ON_DEMAND" in inference_types + + if status == "LEGACY": + legacy_models.append(model_id_from_api) + elif not supports_on_demand: + non_on_demand_models.append(model_id_from_api) + else: + available_models.append(model_id_from_api) + + # Check if the requested model is in the list of legacy models + if any(model_id == legacy_model for legacy_model in legacy_models): + raise ValueError( + f"Model '{model_id}' is in LEGACY status and no longer recommended for use. " + f"Please use one of the active models instead: {', '.join(available_models)}" + ) + + # Check if the requested model is in the list of non-on-demand models + if any(model_id == non_on_demand_model for non_on_demand_model in non_on_demand_models): + raise ValueError( + f"Model '{model_id}' does not support on-demand throughput. " + f"Please use one of these models that support on-demand inference: {', '.join(available_models)}" + ) + + # Check if the requested model is in the list of available models + is_valid = any(model_id == available_model for available_model in available_models) + + return is_valid, available_models + + except ValueError as e: + # Re-raise ValueError for legacy models or non-on-demand models + raise e + except Exception as e: + # If we can't access the API, return False and empty list + print(f"Error checking model availability: {str(e)}") + return False, [] + + +def create_filename(prompt: str) -> str: + """ + Generate a filename from the prompt text. + + Args: + prompt: The text prompt used for image generation + + Returns: + A sanitized filename based on the first few words of the prompt + """ + words = re.findall(r"\w+", prompt.lower())[:5] + filename = "_".join(words) + filename = re.sub(r"[^\w\-_\.]", "_", filename) + return filename[:100] # Limit filename length + + def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult: """ - Generate images from text prompts using Stable Diffusion via Amazon Bedrock. 
+    Generate images from text prompts using Amazon Bedrock models.
 
     This function transforms textual descriptions into high-quality images using
-    Stable Diffusion models available through Amazon Bedrock. It provides extensive
+    various text-to-image models available through Amazon Bedrock. It provides extensive
     customization options and handles the complete process from API interaction to
     image storage and result formatting.
 
     How It Works:
     ------------
     1. Extracts and validates parameters from the tool input
-    2. Configures the request payload with appropriate parameters
-    3. Invokes the Bedrock image generation model through AWS SDK
-    4. Processes the response to extract the base64-encoded image
-    5. Creates an appropriate filename based on the prompt content
-    6. Saves the image to a local output directory
-    7. Returns a success response with both text description and rendered image
+    2. Dynamically selects an available model if none is specified
+    3. Configures the request payload with appropriate parameters for the model family
+    4. Invokes the Bedrock image generation model through AWS SDK
+    5. Processes the response to extract the base64-encoded image
+    6. Creates an appropriate filename based on the prompt content
+    7. Saves the image to a local output directory
+    8. Returns a success response with both text description and rendered image
 
     Generation Parameters:
     --------------------
     - prompt: The textual description of the desired image
-    - model_id: Specific model to use (defaults to stable-diffusion-xl-v1)
+    - model_id: Specific model to use (if not provided, automatically selects one)
+    - region: AWS region to use (defaults to AWS_REGION env variable or us-west-2)
     - seed: Controls randomness for reproducible results
-    - style_preset: Artistic style to apply (e.g., photographic, cinematic)
     - cfg_scale: Controls how closely the image follows the prompt
-    - steps: Number of diffusion steps (higher = more refined but slower)
 
-    Common Usage Scenarios:
-    ---------------------
-    - Creating illustrations for documents or presentations
-    - Generating visual concepts for design projects
-    - Visualizing scenes or characters for creative writing
-    - Producing custom artwork based on specific descriptions
-    - Testing visual ideas before commissioning real artwork
+    Supported Model Families:
+    ----------------------
+    - Amazon Titan Image Generator (v1 and v2)
+    - Amazon Nova Canvas
+    - Stability AI Stable Image (Core and Ultra)
+    - Stability AI SD3
 
     Args:
         tool: ToolUse object containing the parameters for image generation.
             - prompt: The text prompt describing the desired image.
-            - model_id: Optional model identifier (default: "stability.stable-diffusion-xl-v1").
-            - seed: Optional random seed (default: random integer).
-            - style_preset: Optional style preset name (default: "photographic").
+            - model_id: Optional model identifier.
+            - region: Optional AWS region (default: from AWS_REGION env variable or us-west-2).
+            - seed: Optional seed value (default: random integer between 0 and 2147483646).
             - cfg_scale: Optional CFG scale value (default: 10).
-            - steps: Optional number of diffusion steps (default: 30).
         **kwargs: Additional keyword arguments (unused).
Returns: @@ -163,42 +257,182 @@ def generate_image(tool: ToolUse, **kwargs: Any) -> ToolResult: # Extract input parameters prompt = tool_input.get("prompt", "A stylized picture of a cute old steampunk robot.") - model_id = tool_input.get("model_id", "stability.stable-diffusion-xl-v1") - seed = tool_input.get("seed", random.randint(0, 4294967295)) - style_preset = tool_input.get("style_preset", "photographic") + + # Get region from input, environment variable, or default to us-west-2 + region = tool_input.get("region", os.environ.get("AWS_REGION", "us-west-2")) + + # Check if model_id is explicitly provided + if "model_id" in tool_input: + model_id = tool_input["model_id"] + else: + # Find valid models in the region + try: + _, available_models = validate_model_in_region("", region) + + if not available_models: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"No text-to-image models available in region '{region}'. " + f"Please try a different region." + } + ], + } + + # Simply use the first available model + model_id = available_models[0] + print(f"No model_id provided. Using automatically selected model: {model_id}") + except Exception as e: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Error determining available models in region '{region}': {str(e)}. " + f"Please specify a model_id explicitly." + } + ], + } + + # Get seed from input or use a default value that works for all models + # Keeping range's end to 2147483646 as this is the max seed value supported by Nova models + seed = tool_input.get("seed", random.randint(0, 2147483646)) cfg_scale = tool_input.get("cfg_scale", 10) - steps = tool_input.get("steps", 30) - - # Create a Bedrock Runtime client - client = boto3.client("bedrock-runtime", region_name="us-west-2") - - # Format the request payload - native_request = { - "text_prompts": [{"text": prompt}], - "style_preset": style_preset, - "seed": seed, - "cfg_scale": cfg_scale, - "steps": steps, - } + + # Validate if the model is available in the specified region + try: + is_valid, available_models = validate_model_in_region(model_id, region) + if not is_valid: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Model '{model_id}' is not available in region '{region}'. " + f"Available text-to-image models in this region include: " + f"{', '.join(available_models)}" + } + ], + } + except ValueError as e: + # Handle legacy model error + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": str(e)}], + } + except Exception: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Could not validate model availability in region '{region}'. " + f"Please check your AWS credentials and permissions." 
+ } + ], + } + + # Create a Bedrock Runtime client with the specified region + try: + client = boto3.client("bedrock-runtime", region_name=region) + except ClientError as e: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Failed to create Bedrock client in region '{region}': {str(e)}"}], + } + + # Format the request payload based on the model family + if "stability" in model_id: + # Format for Stability AI models (stable-image, sd3) + native_request = {"prompt": prompt, "seed": seed} + elif "amazon" in model_id: + # Format for Amazon models (Titan, Nova Canvas) + native_request = { + "taskType": "TEXT_IMAGE", + "textToImageParams": {"text": prompt}, + "imageGenerationConfig": { + "numberOfImages": 1, + "cfgScale": cfg_scale, + "seed": seed, + }, + } + else: + # This should not happen due to the validation above, but keeping as a fallback + raise ValueError(f"Unsupported model: {model_id}. Please use one of the supported image generation models.") + request = json.dumps(native_request) # Invoke the model - response = client.invoke_model(modelId=model_id, body=request) + try: + response = client.invoke_model(modelId=model_id, body=request) + except ClientError as e: + if "AccessDeniedException" in str(e): + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Access denied for model '{model_id}' in region '{region}'. " + f"Please check your AWS credentials and permissions." + } + ], + } + elif "ValidationException" in str(e) and "not found" in str(e).lower(): + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Model '{model_id}' not found in region '{region}'. " + f"Please verify the model ID and region." + } + ], + } + else: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error invoking model: {str(e)}"}], + } # Decode the response body model_response = json.loads(response["body"].read()) - # Extract the image data - base64_image_data = model_response["artifacts"][0]["base64"] + # Extract the image data - handle different response formats + try: + if "stability" in model_id: + # For Stability AI models + base64_image_data = model_response["images"][0] + elif "amazon" in model_id: + # For Amazon Titan and Nova Canvas models + if "images" in model_response and isinstance(model_response["images"], list): + if isinstance(model_response["images"][0], dict) and "imageBase64" in model_response["images"][0]: + base64_image_data = model_response["images"][0]["imageBase64"] + elif isinstance(model_response["images"][0], str): + base64_image_data = model_response["images"][0] + else: + raise ValueError("Unexpected Amazon model response format") + else: + raise ValueError("Unexpected Amazon model response structure") + else: + raise ValueError(f"Unsupported model family: {model_id}") + except (KeyError, IndexError) as e: + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [ + { + "text": f"Failed to extract image data from model response: {str(e)}. " + f"Response structure: {json.dumps(model_response, indent=2)[:500]}..." 
+ } + ], + } # Create a filename based on the prompt - def create_filename(prompt: str) -> str: - """Generate a filename from the prompt text.""" - words = re.findall(r"\w+", prompt.lower())[:5] - filename = "_".join(words) - filename = re.sub(r"[^\w\-_\.]", "_", filename) - return filename[:100] # Limit filename length - filename = create_filename(prompt) # Save the generated image to a local folder @@ -206,6 +440,7 @@ def create_filename(prompt: str) -> str: if not os.path.exists(output_dir): os.makedirs(output_dir) + # Handle duplicate filenames i = 1 base_image_path = os.path.join(output_dir, f"{filename}.png") image_path = base_image_path @@ -213,9 +448,11 @@ def create_filename(prompt: str) -> str: image_path = os.path.join(output_dir, f"{filename}_{i}.png") i += 1 + # Save the image to disk with open(image_path, "wb") as file: file.write(base64.b64decode(base64_image_data)) + # Return success response with image return { "toolUseId": tool_use_id, "status": "success", diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index d9849814..5840deaa 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -140,7 +140,7 @@ "description": "Optional metadata to store with the memory", }, }, - "required": ["action"] + "required": ["action"], } }, } diff --git a/tests/test_generate_image.py b/tests/test_generate_image.py index dc2132f6..3962d4a5 100644 --- a/tests/test_generate_image.py +++ b/tests/test_generate_image.py @@ -2,11 +2,13 @@ Tests for the generate_image tool. """ -import base64 import json +import os +from contextlib import contextmanager from unittest.mock import MagicMock, patch import pytest +from botocore.exceptions import ClientError from strands import Agent from strands_tools import generate_image @@ -24,20 +26,97 @@ def extract_result_text(result): return str(result) +# Helper functions for common verification patterns +def verify_error_response_structure(result, expected_tool_use_id="test-tool-use-id"): + """Verify the structure of an error response.""" + assert result["toolUseId"] == expected_tool_use_id + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + return result["content"][0]["text"] + + +def verify_success_response_structure(result, expected_tool_use_id="test-tool-use-id"): + """Verify the structure of a success response.""" + assert result["toolUseId"] == expected_tool_use_id + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 # Text and image content + + # Verify text content + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + assert "The generated image has been saved locally" in result["content"][0]["text"] + + # Verify image content + assert isinstance(result["content"][1], dict) + assert "image" in result["content"][1] + assert result["content"][1]["image"]["format"] == "png" + assert isinstance(result["content"][1]["image"]["source"]["bytes"], bytes) + + return result["content"][0]["text"] + + +def verify_aws_service_calls(mock_boto3_client, mock_validate_model, model_id, region="us-west-2"): + """Verify standard AWS service calls were made correctly.""" + # Verify model validation was called + mock_validate_model.assert_called_once_with(model_id, region) + + # Verify boto3 client creation + 
mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name=region) + + # Verify invoke_model was called + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.assert_called_once() + + return mock_client_instance + + +def verify_file_operations(mock_os_makedirs, mock_file_open): + """Verify file system operations were performed correctly.""" + mock_os_makedirs.assert_called_once() + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + +def verify_request_body_amazon_model(mock_client_instance, expected_prompt, expected_seed=123, expected_cfg_scale=10): + """Verify request body structure for Amazon Titan models.""" + args, kwargs = mock_client_instance.invoke_model.call_args + request_body = json.loads(kwargs["body"]) + assert request_body["taskType"] == "TEXT_IMAGE" + assert request_body["textToImageParams"]["text"] == expected_prompt + assert request_body["imageGenerationConfig"]["seed"] == expected_seed + assert request_body["imageGenerationConfig"]["cfgScale"] == expected_cfg_scale + assert request_body["imageGenerationConfig"]["numberOfImages"] == 1 + return request_body + + +def verify_request_body_stability_model(mock_client_instance, expected_prompt, expected_seed=123): + """Verify request body structure for Stability AI models.""" + args, kwargs = mock_client_instance.invoke_model.call_args + request_body = json.loads(kwargs["body"]) + assert request_body["prompt"] == expected_prompt + assert request_body["seed"] == expected_seed + return request_body + + +@pytest.fixture +def mock_validate_model(): + """Mock the validate_model_in_region function.""" + with patch("strands_tools.generate_image.validate_model_in_region") as mock_validate: + # Default to valid model and provide a list of available models + mock_validate.return_value = (True, ["amazon.titan-image-generator-v2:0", "stability.stable-image-ultra-v1:1"]) + yield mock_validate + + @pytest.fixture def mock_boto3_client(): """Mock boto3 client for testing.""" with patch("boto3.client") as mock_client: - # Set up mock response - mock_body = MagicMock() - mock_body.read.return_value = json.dumps( - {"artifacts": [{"base64": base64.b64encode(b"mock_image_data").decode("utf-8")}]} - ).encode("utf-8") - mock_client_instance = MagicMock() - mock_client_instance.invoke_model.return_value = {"body": mock_body} mock_client.return_value = mock_client_instance - yield mock_client @@ -68,37 +147,519 @@ def mock_file_open(): yield mock_open, mock_file -def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_makedirs, mock_file_open): - """Test direct invocation of the generate_image tool.""" - # Create a tool use dictionary similar to how the agent would call it - tool_use = { +@pytest.fixture +def stability_model_response(): + """Mock response for Stability AI models.""" + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": ["base64_encoded_image_data"]}).encode("utf-8") + return mock_body + + +@pytest.fixture +def amazon_model_v1_response(): + """Mock response for Amazon Titan v1 models.""" + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": ["base64_encoded_image_data"]}).encode("utf-8") + return mock_body + + +@pytest.fixture +def amazon_model_v2_response(): + """Mock response for Amazon Titan v2 models.""" + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": [{"imageBase64": "base64_encoded_image_data"}]}).encode("utf-8") + return mock_body + + 
+@pytest.fixture +def basic_tool_use(): + """Basic tool use object for testing.""" + return { "toolUseId": "test-tool-use-id", "input": { "prompt": "A cute robot", "seed": 123, - "steps": 30, "cfg_scale": 10, - "style_preset": "photographic", }, } - # Call the generate_image function directly + +@contextmanager +def mock_base64_decode(): + """Context manager for mocking base64.b64decode.""" + with patch("base64.b64decode", return_value=b"decoded_image_data"): + yield + + +def test_generate_image_stability_model( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + stability_model_response, + basic_tool_use, +): + """Test direct invocation of the generate_image tool with a Stability AI model.""" + # Update tool use with stability model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "stability.stable-image-ultra-v1:1" + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + # Call the generate_image function directly + result = generate_image.generate_image(tool=tool_use) + + # Verify AWS service calls + mock_client_instance = verify_aws_service_calls( + mock_boto3_client, mock_validate_model, "stability.stable-image-ultra-v1:1" + ) + + # Verify request body for Stability model + verify_request_body_stability_model(mock_client_instance, "A cute robot", 123) + + # Verify file operations + verify_file_operations(mock_os_makedirs, mock_file_open) + + # Verify success response structure + verify_success_response_structure(result) + + +def test_generate_image_amazon_model( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + amazon_model_v2_response, + basic_tool_use, +): + """Test direct invocation of the generate_image tool with an Amazon Titan model.""" + # Update tool use with Amazon model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "amazon.titan-image-generator-v2:0" + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": amazon_model_v2_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + # Call the generate_image function directly + result = generate_image.generate_image(tool=tool_use) + + # Verify AWS service calls + mock_client_instance = verify_aws_service_calls( + mock_boto3_client, mock_validate_model, "amazon.titan-image-generator-v2:0" + ) + + # Verify request body for Amazon model + verify_request_body_amazon_model(mock_client_instance, "A cute robot", 123, 10) + + # Verify file operations + verify_file_operations(mock_os_makedirs, mock_file_open) + + # Verify success response structure + verify_success_response_structure(result) + + +def test_generate_image_auto_model_selection( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + amazon_model_v2_response, + basic_tool_use, +): + """Test automatic model selection when no model_id is provided.""" + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": amazon_model_v2_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + # Call 
the generate_image function directly + result = generate_image.generate_image(tool=basic_tool_use) + + # Verify model validation was called for auto-selection (empty model_id) + mock_validate_model.assert_any_call("", "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called + mock_client_instance.invoke_model.assert_called_once() + + # Verify request body for Amazon model (auto-selected) + verify_request_body_amazon_model(mock_client_instance, "A cute robot", 123, 10) + + # Verify file operations + verify_file_operations(mock_os_makedirs, mock_file_open) + + # Verify success response structure + verify_success_response_structure(result) + + +def test_generate_image_model_validation_error(mock_validate_model, basic_tool_use): + """Test error handling when model validation fails.""" + # Setup validation to fail + mock_validate_model.return_value = (False, ["amazon.titan-image-generator-v2:0"]) + + # Update tool use with invalid model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "stability.invalid-model" + result = generate_image.generate_image(tool=tool_use) - # Verify the function was called with correct parameters - mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + # Verify model validation was called with the invalid model + mock_validate_model.assert_called_once_with("stability.invalid-model", "us-west-2") + + # Verify error response structure and content + error_text = verify_error_response_structure(result) + assert "not available in region" in error_text + assert "stability.invalid-model" in error_text + assert "us-west-2" in error_text + assert "amazon.titan-image-generator-v2:0" in error_text # Available models should be listed + + +def test_generate_image_legacy_model_error(mock_validate_model, basic_tool_use): + """Test error handling when a legacy model is requested.""" + # Setup validation to raise ValueError for legacy model + mock_validate_model.side_effect = ValueError("Model is in LEGACY status") + + # Update tool use with legacy model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "stability.legacy-model" + + result = generate_image.generate_image(tool=tool_use) + + # Verify model validation was called with the legacy model + mock_validate_model.assert_called_once_with("stability.legacy-model", "us-west-2") + + # Verify error response structure and content + error_text = verify_error_response_structure(result) + assert "Model is in LEGACY status" in error_text + + +def test_generate_image_access_denied(mock_boto3_client, mock_validate_model, basic_tool_use): + """Test error handling when access is denied to the model.""" + # Setup boto3 client to raise an AccessDeniedException mock_client_instance = mock_boto3_client.return_value - mock_client_instance.invoke_model.assert_called_once() + access_denied_exception = Exception("AccessDeniedException: Access denied") + mock_client_instance.invoke_model.side_effect = access_denied_exception - # Check the parameters passed to invoke_model - args, kwargs = mock_client_instance.invoke_model.call_args - request_body = json.loads(kwargs["body"]) + # Update tool use with model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "amazon.titan-image-generator-v2:0" + + result = 
generate_image.generate_image(tool=tool_use) + + # Verify AWS service calls were attempted + verify_aws_service_calls(mock_boto3_client, mock_validate_model, "amazon.titan-image-generator-v2:0") + + # Verify error response structure and content + error_text = verify_error_response_structure(result) + assert "Error generating image" in error_text + + +def test_validate_model_in_region(): + """Test the validate_model_in_region function.""" + with patch("boto3.client") as mock_client: + # Mock the list_foundation_models response + mock_client_instance = MagicMock() + mock_client_instance.list_foundation_models.return_value = { + "modelSummaries": [ + { + "modelId": "amazon.titan-image-generator-v2:0", + "inputModalities": ["TEXT"], + "outputModalities": ["IMAGE"], + "modelLifecycle": {"status": "ACTIVE"}, + "inferenceTypesSupported": ["ON_DEMAND"], + }, + { + "modelId": "stability.stable-image-ultra-v1:1", + "inputModalities": ["TEXT"], + "outputModalities": ["IMAGE"], + "modelLifecycle": {"status": "ACTIVE"}, + "inferenceTypesSupported": ["ON_DEMAND"], + }, + { + "modelId": "stability.stable-diffusion-xl-v1", + "inputModalities": ["TEXT"], + "outputModalities": ["IMAGE"], + "modelLifecycle": {"status": "LEGACY"}, + "inferenceTypesSupported": ["ON_DEMAND"], + }, + ] + } + mock_client.return_value = mock_client_instance + + # Test valid model + is_valid, available_models = generate_image.validate_model_in_region( + "amazon.titan-image-generator-v2:0", "us-west-2" + ) + assert is_valid is True + assert "amazon.titan-image-generator-v2:0" in available_models + + # Test invalid model + is_valid, available_models = generate_image.validate_model_in_region("invalid.model", "us-west-2") + assert is_valid is False + + # Test legacy model + with pytest.raises(ValueError) as excinfo: + generate_image.validate_model_in_region("stability.stable-diffusion-xl-v1", "us-west-2") + assert "LEGACY status" in str(excinfo.value) + + +def test_create_filename(): + """Test the create_filename function.""" + # Test the actual function from the module + filename = generate_image.create_filename("A cute robot dancing in the rain") + assert filename == "a_cute_robot_dancing_in" + + # Test with special characters + filename = generate_image.create_filename("A cute robot! 
With @#$% special chars") + assert filename == "a_cute_robot_with_special" + + # Test long prompt + long_prompt = "This is a very long prompt " + "word " * 50 + filename = generate_image.create_filename(long_prompt) + assert len(filename) <= 100 + + +def test_generate_image_via_agent( + agent, mock_boto3_client, mock_validate_model, mock_os_path_exists, mock_os_makedirs, mock_file_open +): + """Test image generation via the agent interface.""" + # Set up mock response for agent test + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": ["base64_encoded_image_data"]}).encode("utf-8") + + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": mock_body} + + # This simulates how the tool would be used through the Agent interface + # We mock the agent's tool method to return a comprehensive response + mock_response = { + "toolUseId": "test-tool-use-id", + "status": "success", + "content": [ + {"text": "The generated image has been saved locally to output/test_via_agent.png"}, + {"image": {"format": "png", "source": {"bytes": b"decoded_image_data"}}}, + ], + } + + with patch.object(agent.tool, "generate_image", return_value=mock_response) as mock_generate: + result = agent.tool.generate_image(prompt="Test via agent") + + # Verify the agent tool method was called with correct parameters + mock_generate.assert_called_once_with(prompt="Test via agent") + + # Extract and verify result content + result_text = extract_result_text(result) + assert "The generated image has been saved locally" in result_text + + # Verify complete result structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert len(result["content"]) == 2 + + # Verify text content + assert "The generated image has been saved locally" in result["content"][0]["text"] + + # Verify image content structure + image_content = result["content"][1] + assert "image" in image_content + assert image_content["image"]["format"] == "png" + assert isinstance(image_content["image"]["source"]["bytes"], bytes) + assert image_content["image"]["source"]["bytes"] == b"decoded_image_data" + + +@pytest.mark.parametrize( + "model_id,response_fixture", + [ + ("stability.stable-image-ultra-v1:1", "stability_model_response"), + ("amazon.titan-image-generator-v1", "amazon_model_v1_response"), + ("amazon.titan-image-generator-v2:0", "amazon_model_v2_response"), + ], +) +def test_model_response_parsing( + model_id, + response_fixture, + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + request, +): + """Test parsing of responses from different model families.""" + # Get the response fixture + mock_body = request.getfixturevalue(response_fixture) + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": mock_body} + + # Create tool use with the specified model + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": model_id, + }, + } + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) + + # Verify model validation was called + mock_validate_model.assert_called_once_with(model_id, "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called + 
mock_client_instance.invoke_model.assert_called_once() + + # Verify complete success response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 # Text and image content + + # Verify text content + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + assert "The generated image has been saved locally" in result["content"][0]["text"] + + # Verify image content + assert isinstance(result["content"][1], dict) + assert "image" in result["content"][1] + assert result["content"][1]["image"]["format"] == "png" + assert isinstance(result["content"][1]["image"]["source"]["bytes"], bytes) + + # Verify file operations + mock_os_makedirs.assert_called_once() + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + +def test_response_parsing_error(mock_boto3_client, mock_validate_model): + """Test error handling for unexpected response formats.""" + # Set up mock response with unexpected format + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"unexpected_format": "data"}).encode("utf-8") + + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": mock_body} + + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + + result = generate_image.generate_image(tool=tool_use) + + # Verify AWS service calls were attempted + verify_aws_service_calls(mock_boto3_client, mock_validate_model, "amazon.titan-image-generator-v2:0") + + # Verify error response structure and content + error_text = verify_error_response_structure(result) + assert "Error generating image" in error_text + + +def test_missing_prompt_error(): + """Test handling when prompt is missing - should use default prompt.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + # No prompt provided - should use default + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + + # Mock the necessary components to avoid actual AWS calls + with patch("strands_tools.generate_image.validate_model_in_region", return_value=(True, [])): + with patch("boto3.client") as mock_client: + mock_client_instance = MagicMock() + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": [{"imageBase64": "base64_data"}]}).encode("utf-8") + mock_client_instance.invoke_model.return_value = {"body": mock_body} + mock_client.return_value = mock_client_instance + + with patch("os.path.exists", return_value=False): + with patch("os.makedirs"): + with patch("builtins.open", MagicMock()): + with patch("base64.b64decode", return_value=b"image_data"): + result = generate_image.generate_image(tool=tool_use) - assert request_body["text_prompts"][0]["text"] == "A cute robot" - assert request_body["seed"] == 123 - assert request_body["steps"] == 30 - assert request_body["cfg_scale"] == 10 - assert request_body["style_preset"] == "photographic" + # Verify complete success response structure (should succeed with default prompt) + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 # Text and image content + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify success message + success_text = result["content"][0]["text"] + assert 
"The generated image has been saved locally" in success_text + + # Verify image content + assert isinstance(result["content"][1], dict) + assert "image" in result["content"][1] + assert result["content"][1]["image"]["format"] == "png" + + +def test_file_path_construction( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + stability_model_response, +): + """Test that file paths are constructed correctly.""" + # Create a tool use dictionary + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot dancing", + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) + + # Verify model validation was called + mock_validate_model.assert_called_once_with("amazon.titan-image-generator-v2:0", "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called + mock_client_instance.invoke_model.assert_called_once() # Verify directory creation mock_os_makedirs.assert_called_once() @@ -107,79 +668,469 @@ def test_generate_image_direct(mock_boto3_client, mock_os_path_exists, mock_os_m mock_open, mock_file = mock_file_open mock_file.write.assert_called_once() - # Check the result + # Check that the file path contains the expected filename pattern + file_path_arg = mock_open.call_args[0][0] + assert "output" in file_path_arg + assert "a_cute_robot_dancing" in file_path_arg + assert file_path_arg.endswith(".png") + + # Verify successful result assert result["toolUseId"] == "test-tool-use-id" assert result["status"] == "success" assert "The generated image has been saved locally" in result["content"][0]["text"] - assert result["content"][1]["image"]["format"] == "png" - assert isinstance(result["content"][1]["image"]["source"]["bytes"], bytes) -def test_generate_image_default_params(mock_boto3_client, mock_os_path_exists, mock_os_makedirs, mock_file_open): - """Test generate_image with default parameters.""" - tool_use = {"toolUseId": "test-tool-use-id", "input": {"prompt": "A cute robot"}} +def test_custom_region_handling( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + stability_model_response, +): + """Test handling of custom region specification.""" + # Create a tool use dictionary with custom region + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "amazon.titan-image-generator-v2:0", + "region": "us-east-1", + }, + } + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} - with patch("random.randint", return_value=42): + # Mock base64 decode to avoid errors + with mock_base64_decode(): result = generate_image.generate_image(tool=tool_use) - # Check the default parameters were used + # Verify model validation was called with custom region + mock_validate_model.assert_called_once_with("amazon.titan-image-generator-v2:0", "us-east-1") + + # Verify the region was used correctly for client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-east-1") + + # Verify 
invoke_model was called + mock_client_instance.invoke_model.assert_called_once() + + # Verify file operations + mock_os_makedirs.assert_called_once() + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + # Verify successful result + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert "The generated image has been saved locally" in result["content"][0]["text"] + + +def test_environment_variable_region( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + stability_model_response, +): + """Test handling of region from environment variable.""" + # Create a tool use dictionary without region + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + + # Set up mock response mock_client_instance = mock_boto3_client.return_value - args, kwargs = mock_client_instance.invoke_model.call_args - request_body = json.loads(kwargs["body"]) + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} + + # Mock environment variable + with patch.dict(os.environ, {"AWS_REGION": "eu-west-1"}): + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) - assert request_body["seed"] == 42 # From our mocked random.randint - assert request_body["steps"] == 30 - assert request_body["cfg_scale"] == 10 - assert request_body["style_preset"] == "photographic" + # Verify model validation was called with environment region + mock_validate_model.assert_called_once_with("amazon.titan-image-generator-v2:0", "eu-west-1") + # Verify the region from environment variable was used + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="eu-west-1") + + # Verify invoke_model was called + mock_client_instance.invoke_model.assert_called_once() + + # Verify file operations + mock_os_makedirs.assert_called_once() + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + # Verify successful result + assert result["toolUseId"] == "test-tool-use-id" assert result["status"] == "success" + assert "The generated image has been saved locally" in result["content"][0]["text"] + + +def test_duplicate_filename_handling( + mock_boto3_client, mock_validate_model, mock_os_makedirs, mock_file_open, stability_model_response +): + """Test handling of duplicate filenames.""" + # Mock os.path.exists to simulate existing files + with patch("os.path.exists") as mock_exists: + # First return False for output directory check, then True for all file checks + mock_exists.side_effect = [False, True, True, True, False] + + # Create a tool use dictionary + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) + + # Verify model validation was called + mock_validate_model.assert_called_once_with("amazon.titan-image-generator-v2:0", "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called + 
mock_client_instance.invoke_model.assert_called_once() + + # Verify directory creation + mock_os_makedirs.assert_called_once() + + # Verify file operations - should try multiple filenames + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + file_path_arg = mock_open.call_args[0][0] + + # Should have a number appended to the filename due to duplicates + assert "_3" in file_path_arg + assert "output" in file_path_arg + assert file_path_arg.endswith(".png") + + # Verify successful result + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert "The generated image has been saved locally" in result["content"][0]["text"] + + +def test_client_creation_error(): + """Test error handling when client creation fails.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "amazon.titan-image-generator-v2:0", + }, + } + # Mock validate_model_in_region to avoid model validation errors + with patch("strands_tools.generate_image.validate_model_in_region", return_value=(True, [])): + # Mock boto3.client to raise an exception + with patch("boto3.client", side_effect=Exception("Failed to create client")) as mock_client: + result = generate_image.generate_image(tool=tool_use) -def test_generate_image_error_handling(mock_boto3_client): - """Test error handling in generate_image.""" - # Setup boto3 client to raise an exception + # Verify boto3.client was called and failed + mock_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify error message indicates client creation failure + error_text = result["content"][0]["text"] + assert "Error generating image" in error_text + + +def test_unsupported_model_family(): + """Test error handling for unsupported model families.""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "unsupported.model-family-v1:0", + }, + } + + # Mock validate_model_in_region to return valid (to pass validation) + with patch("strands_tools.generate_image.validate_model_in_region", return_value=(True, [])): + with patch("boto3.client") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + + result = generate_image.generate_image(tool=tool_use) + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify error message indicates unsupported model + error_text = result["content"][0]["text"] + assert "Unsupported model" in error_text + + +def test_validation_exception_model_not_found(mock_boto3_client, mock_validate_model, basic_tool_use): + """Test error handling for ValidationException when model is not found.""" + # Setup boto3 client to raise ValidationException mock_client_instance = mock_boto3_client.return_value - mock_client_instance.invoke_model.side_effect = Exception("API error") - tool_use = {"toolUseId": "test-tool-use-id", "input": {"prompt": "A cute robot"}} + # Create a 
mock ValidationException + validation_exception = ClientError( + error_response={"Error": {"Code": "ValidationException", "Message": "Model not found"}}, + operation_name="InvokeModel", + ) + mock_client_instance.invoke_model.side_effect = validation_exception + + # Update tool use with model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "amazon.nonexistent-model-v1:0" result = generate_image.generate_image(tool=tool_use) - # Verify error handling + # Verify model validation was called + mock_validate_model.assert_called_once_with("amazon.nonexistent-model-v1:0", "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called and raised the exception + mock_client_instance.invoke_model.assert_called_once() + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" assert result["status"] == "error" - assert "Error generating image: API error" in result["content"][0]["text"] + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + # Verify error message content + error_text = result["content"][0]["text"] + assert "not found" in error_text + assert "amazon.nonexistent-model-v1:0" in error_text -def test_filename_creation(): - """Test the filename creation logic using regex patterns similar to create_filename.""" - # Since create_filename is defined inside the function, we'll replicate its functionality - def create_filename_test(prompt: str) -> str: - import re +def test_bedrock_client_error_in_validation(basic_tool_use): + """Test error handling when Bedrock client fails during model validation.""" + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "amazon.titan-image-generator-v2:0" - words = re.findall(r"\w+", prompt.lower())[:5] - filename = "_".join(words) - filename = re.sub(r"[^\w\-_\.]", "_", filename) - return filename[:100] + # Mock validate_model_in_region to raise a generic exception + with patch("strands_tools.generate_image.validate_model_in_region", side_effect=Exception("Bedrock API error")): + result = generate_image.generate_image(tool=tool_use) - # Test normal prompt - filename = create_filename_test("A cute robot dancing in the rain") - assert filename == "a_cute_robot_dancing_in" + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] - # Test prompt with special characters - filename = create_filename_test("A cute robot! 
With @#$% special chars") - assert filename == "a_cute_robot_with_special" + # Verify error message content + error_text = result["content"][0]["text"] + assert "Could not validate model availability" in error_text + assert "us-west-2" in error_text - # Test long prompt - long_prompt = "This is a very long prompt " + "word " * 50 - filename = create_filename_test(long_prompt) - assert len(filename) <= 100 +def test_auto_model_selection_no_models_available(basic_tool_use): + """Test auto model selection when no models are available in region.""" + # Remove model_id to trigger auto-selection + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + # Don't set model_id to trigger auto-selection -def test_generate_image_via_agent(agent, mock_boto3_client, mock_os_path_exists, mock_os_makedirs, mock_file_open): - """Test image generation via the agent interface.""" - # This simulates how the tool would be used through the Agent interface - result = agent.tool.generate_image(prompt="Test via agent") + # Mock validate_model_in_region to return no available models + with patch("strands_tools.generate_image.validate_model_in_region", return_value=(False, [])): + result = generate_image.generate_image(tool=tool_use) + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify error message content + error_text = result["content"][0]["text"] + assert "No text-to-image models available" in error_text + assert "us-west-2" in error_text + + +def test_auto_model_selection_validation_error(basic_tool_use): + """Test auto model selection when validation fails.""" + # Remove model_id to trigger auto-selection + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + # Don't set model_id to trigger auto-selection + + # Mock validate_model_in_region to raise an exception during auto-selection + with patch("strands_tools.generate_image.validate_model_in_region", side_effect=Exception("API error")): + result = generate_image.generate_image(tool=tool_use) + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify error message content + error_text = result["content"][0]["text"] + assert "Error determining available models" in error_text + assert "us-west-2" in error_text + assert "specify a model_id explicitly" in error_text + + +def test_amazon_model_v1_response_format( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + basic_tool_use, +): + """Test parsing Amazon Titan v1 response format (string array).""" + # Set up mock response for Amazon Titan v1 (string format) + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"images": ["base64_encoded_image_data"]}).encode("utf-8") + + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": mock_body} + + # Update tool use with Amazon model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = 
"amazon.titan-image-generator-v1" + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) + + # Verify model validation was called + mock_validate_model.assert_called_once_with("amazon.titan-image-generator-v1", "us-west-2") + + # Verify boto3 client creation + mock_boto3_client.assert_called_once_with("bedrock-runtime", region_name="us-west-2") + + # Verify invoke_model was called + mock_client_instance.invoke_model.assert_called_once() + + # Check the parameters passed to invoke_model + args, kwargs = mock_client_instance.invoke_model.call_args + assert kwargs["modelId"] == "amazon.titan-image-generator-v1" + + request_body = json.loads(kwargs["body"]) + assert request_body["taskType"] == "TEXT_IMAGE" + assert request_body["textToImageParams"]["text"] == "A cute robot" + + # Verify file operations + mock_os_makedirs.assert_called_once() + mock_open, mock_file = mock_file_open + mock_file.write.assert_called_once() + + # Check the result + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success" + assert "The generated image has been saved locally" in result["content"][0]["text"] + assert result["content"][1]["image"]["format"] == "png" + assert isinstance(result["content"][1]["image"]["source"]["bytes"], bytes) + + +def test_image_extraction_key_error(mock_boto3_client, mock_validate_model, basic_tool_use): + """Test error handling when image extraction fails due to KeyError.""" + # Set up mock response with missing keys + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"no_images_key": "data"}).encode("utf-8") + + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": mock_body} + + # Update tool use with model + tool_use = basic_tool_use.copy() + tool_use["input"] = tool_use["input"].copy() + tool_use["input"]["model_id"] = "amazon.titan-image-generator-v2:0" + + result = generate_image.generate_image(tool=tool_use) + + # Verify complete error response structure + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "error" + assert isinstance(result["content"], list) + assert len(result["content"]) == 1 + assert isinstance(result["content"][0], dict) + assert "text" in result["content"][0] + + # Verify specific error message content (based on actual implementation behavior) + error_text = result["content"][0]["text"] + assert "Error generating image" in error_text + assert "Unexpected Amazon model response structure" in error_text + + +def test_custom_parameters_validation( + mock_boto3_client, + mock_validate_model, + mock_os_path_exists, + mock_os_makedirs, + mock_file_open, + stability_model_response, +): + """Test that custom parameters (seed, cfg_scale) are properly used.""" + # Create a tool use dictionary with custom parameters + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "prompt": "A cute robot", + "model_id": "stability.stable-image-ultra-v1:1", + "seed": 12345, + "cfg_scale": 15, + }, + } + + # Set up mock response + mock_client_instance = mock_boto3_client.return_value + mock_client_instance.invoke_model.return_value = {"body": stability_model_response} + + # Mock base64 decode to avoid errors + with mock_base64_decode(): + result = generate_image.generate_image(tool=tool_use) + + # Check the parameters passed to invoke_model + args, kwargs = mock_client_instance.invoke_model.call_args + assert kwargs["modelId"] == "stability.stable-image-ultra-v1:1" + 
+ request_body = json.loads(kwargs["body"]) + assert request_body["prompt"] == "A cute robot" + assert request_body["seed"] == 12345 # Custom seed should be used - result_text = extract_result_text(result) - assert "The generated image has been saved locally" in result_text + # Verify successful result + assert result["toolUseId"] == "test-tool-use-id" + assert result["status"] == "success"
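
A minimal usage sketch of the revised tool, mirroring the direct-invocation pattern used in the tests above. The `ToolUse` payload shape and parameter names come from the diff; the model ID and region are example values, and the call assumes AWS credentials with Bedrock access in that region.

```python
from strands_tools import generate_image

# Illustrative direct invocation; omitting "model_id" would instead trigger the
# automatic model selection path, and "region" falls back to AWS_REGION or us-west-2.
tool_use = {
    "toolUseId": "example-tool-use-id",
    "input": {
        "prompt": "A futuristic city with flying cars",
        "model_id": "amazon.titan-image-generator-v2:0",  # assumed available in the region
        "region": "us-west-2",
        "seed": 42,
        "cfg_scale": 12,
    },
}

result = generate_image.generate_image(tool=tool_use)
print(result["status"])  # "success" once the image is saved to the local output directory
```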