diff --git a/app/services/schemas.py b/app/services/schemas.py index 1cc3c3c2..ebc611ae 100644 --- a/app/services/schemas.py +++ b/app/services/schemas.py @@ -8,7 +8,7 @@ class User(BaseModel): id: str fullName: str email: str - + class Role(str, Enum): human = "human" ai = "ai" @@ -28,7 +28,7 @@ class Message(BaseModel): type: MessageType timestamp: Optional[Any] = None payload: MessagePayload - + class RequestType(str, Enum): chat = "chat" tool = "tool" @@ -36,22 +36,22 @@ class RequestType(str, Enum): class GenericRequest(BaseModel): user: User type: RequestType - + class ChatRequest(GenericRequest): messages: List[Message] class GenericAssistantRequest(BaseModel): assistant_inputs: AssistantInputs - + class ToolRequest(GenericRequest): tool_data: BaseTool - + class ChatResponse(BaseModel): data: List[Message] class ToolResponse(BaseModel): data: Any - + class ChatMessage(BaseModel): role: str type: str @@ -67,7 +67,7 @@ class QuizzifyArgs(BaseModel): class WorksheetQuestion(BaseModel): question_type: str number: int - + class WorksheetQuestionModel(BaseModel): worksheet_question_list: List[WorksheetQuestion] @@ -78,7 +78,7 @@ class WorksheetGeneratorArgs(BaseModel): file_url: str file_type: str lang: Optional[str] = "en" - + class SyllabusGeneratorArgsModel(BaseModel): grade_level: str subject: str @@ -92,20 +92,20 @@ class SyllabusGeneratorArgsModel(BaseModel): file_url: str file_type: str lang: Optional[str] = "en" - + class AIResistantArgs(BaseModel): assignment: str = Field(..., max_length=255, description="The given assignment") grade_level: Literal["pre-k", "kindergarten", "elementary", "middle", "high", "university", "professional"] = Field(..., description="Educational level to which the content is directed") file_type: str = Field(..., description="Type of file being handled, according to the defined enumeration") file_url: str = Field(..., description="URL or path of the file to be processed") lang: str = Field(..., description="Language in which the 
# Request arguments accepted by the image-generator tool.
# NOTE(review): app/tools/image_generator/core.py imports ImageGeneratorArgs
# from app.tools.image_generator.tools rather than from this module — confirm
# whether that is a re-export of this class or a second definition that must
# be kept in sync.
class ImageGeneratorArgs(BaseModel):
    # Required: free-text description of the image to generate.
    prompt: str = Field(..., description="The text prompt to generate an image from")
    # Optional: None when the caller does not supply a subject.
    subject: Optional[str] = Field(None, description="The educational subject (e.g., 'math', 'science')")
    # Optional: None when the caller does not supply a grade level.
    grade_level: Optional[str] = Field(None, description="The grade level (e.g., 'elementary', 'middle school', 'high school')")
    # Falls back to "en" when omitted.
    lang: str = Field("en", description="The language for text in the image")
+ +## Features + +- Generate educational images from text prompts +- Enhance prompts with educational context +- Safety filtering to ensure appropriate content +- Integration with Black Forest Labs Flux 1.1 Pro API +- Automatic storage in Google Cloud Storage (when configured) +- Content type detection (diagrams, concepts, processes, etc.) + +## Setup + +1. Install the required dependencies: + ``` + # From the marvel-ai-backend directory + pip install -r requirements.txt + ``` + + Note: All required dependencies are included in the main project's requirements.txt file. + +2. Set up your Black Forest Labs API key in the .env file: + + Add the following line to your `.env` file in the `marvel-ai-backend/app/` directory: + ``` + BFL_API_KEY=your_api_key_here + ``` + + You can obtain an API key by registering at [api.bfl.ml](https://api.bfl.ml/). + +3. Set up Google Cloud Storage for image persistence: + + a. Create a storage bucket in the GCP project associated with the PROJECT_ID environment variable in your .env file (see GCP Storage Configuration below) + + b. Add the following to your `.env` file: + ``` + GCP_STORAGE_BUCKET=your-gcp-bucket-name + GOOGLE_APPLICATION_CREDENTIALS=/absolute/path/to/your/credentials.json + ``` + +4. 
When running in Docker, mount the credentials file: + + ```bash + docker run \ + -v /path/to/credentials.json:/app/credentials.json:ro \ + -e GOOGLE_APPLICATION_CREDENTIALS=/app/credentials.json \ + -p 8000:8000 \ + --env-file ./app/.env \ + your-image-name + ``` + +## Usage + +### API Request Format + +```json +{ + "user": { + "id": "string", + "fullName": "string", + "email": "string" + }, + "type": "tool", + "tool_data": { + "tool_id": "image-generator", + "inputs": [ + { + "name": "prompt", + "value": "A diagram of the solar system" + }, + { + "name": "subject", + "value": "astronomy" + }, + { + "name": "grade_level", + "value": "middle school" + }, + { + "name": "lang", + "value": "en" + } + ] + } +} +``` + +### Input Parameters + +- `prompt` (required): The text prompt to generate an image from +- `subject` (optional): The educational subject (e.g., 'math', 'science') +- `grade_level` (optional): The grade level (e.g., 'elementary', 'middle school', 'high school') +- `lang` (optional, default: "en"): The language for text in the image + +### Response Format + +```json +{ + "image_b64": "base64_encoded_image_data", + "prompt_used": "A diagram of the solar system, educational context: astronomy for middle school level", + "educational_context": "astronomy for middle school level", + "safety_applied": true, + "gcp_url": "https://storage.googleapis.com/your-bucket/generated_images/image_20250422_123456_solar_system_abcd1234.png" +} +``` + +The `gcp_url` field will be included if GCP storage is configured and the image was successfully uploaded. + +## Implementation Details + +The image generator uses Black Forest Labs' Flux 1.1 Pro model, which is a state-of-the-art text-to-image model. The tool enhances the prompt with educational context and applies safety filtering to ensure the generated images are appropriate for educational use. 
def executor(
    prompt: str,
    subject: str = None,
    grade_level: str = None,
    lang: str = "en",
    verbose: bool = False
):
    """
    Executor function for the Image Generator tool.

    Builds an ``ImageGeneratorArgs`` from the inputs, runs
    ``ImageGenerator.generate_educational_image()`` and returns the result
    serialized with ``model_dump()``.

    Args:
        prompt (str): The text prompt to generate an image from.
        subject (str, optional): The educational subject (e.g., 'math', 'science').
        grade_level (str, optional): The grade level (e.g., 'elementary', 'middle school', 'high school').
        lang (str, optional): The language for text in the image. Defaults to "en".
        verbose (bool, optional): Flag for verbose logging. Defaults to False.

    Returns:
        dict: The generator result as a plain dictionary.

    Raises:
        ToolExecutorError: For any failure. An ``ImageHandlerError`` keeps its
            own message; every other exception is prefixed with
            "Error in Image Generator: ". The original exception is attached
            as ``__cause__``.
    """
    try:
        if verbose:
            logger.info(f"Generating image with prompt: {prompt}")
            if subject:
                logger.info(f"Subject: {subject}")
            if grade_level:
                logger.info(f"Grade level: {grade_level}")
            logger.info(f"Language: {lang}")

        # Create arguments for the image generator
        image_generator_args = ImageGeneratorArgs(
            prompt=prompt,
            subject=subject,
            grade_level=grade_level,
            lang=lang
        )

        # Initialize the image generator
        generator = ImageGenerator(args=image_generator_args, verbose=verbose)

        # Generate the image
        result = generator.generate_educational_image()

        # Log success
        logger.info(f"Image generated successfully for prompt: {prompt}")

        # Return the result as a dictionary
        # Use model_dump() instead of dict() for Pydantic v2 compatibility
        return result.model_dump()

    except ImageHandlerError as e:
        error_message = str(e)
        logger.error(f"Image Handler Error: {error_message}")
        # Chain explicitly so the traceback shows the handler error as the
        # direct cause rather than as an unrelated error raised during handling.
        raise ToolExecutorError(error_message) from e

    except Exception as e:
        # Top-level boundary for the tool: translate anything unexpected into
        # the error type callers catch, keeping the root cause attached.
        error_message = f"Error in Image Generator: {str(e)}"
        logger.error(error_message)
        raise ToolExecutorError(error_message) from e
index 00000000..dd748798 --- /dev/null +++ b/app/tools/image_generator/metadata.json @@ -0,0 +1,36 @@ +{ + "name": "Image Generator", + "description": "Generate educational images from text prompts using Black Forest Labs Flux 1.1 Pro API.", + "version": "1.0.0", + "inputs": [ + { + "name": "prompt", + "type": "string", + "description": "The text prompt to generate an image from", + "required": true + }, + { + "name": "subject", + "type": "string", + "description": "The educational subject (e.g., 'math', 'science')", + "required": false + }, + { + "name": "grade_level", + "type": "string", + "description": "The grade level (e.g., 'elementary', 'middle school', 'high school')", + "required": false + }, + { + "name": "lang", + "type": "string", + "description": "The language for text in the image", + "required": false, + "default": "en" + } + ], + "output": { + "type": "object", + "description": "Generated image data including base64 encoded image, metadata, and optional GCP storage URL" + } +} diff --git a/app/tools/image_generator/prompt/image-generator-prompt.txt b/app/tools/image_generator/prompt/image-generator-prompt.txt new file mode 100644 index 00000000..99001fb5 --- /dev/null +++ b/app/tools/image_generator/prompt/image-generator-prompt.txt @@ -0,0 +1,28 @@ +You are an expert educational visual designer specializing in creating high-quality images for classroom instruction. Your task is to generate clear, precise, pedagogically effective and high-quality images based on the provided prompt. + +INSTRUCTIONS: +1. CLARITY: Create images with clear visual hierarchy, proper labeling, and appropriate text size for classroom visibility. +2. EDUCATIONAL ACCURACY: Ensure all content is factually correct and aligned with educational standards. +3. PEDAGOGICAL EFFECTIVENESS: Design images that support specific learning objectives and cognitive processes and are helpful for teaching or learning the subject matter. +4. 
ACCESSIBILITY: Use high contrast, colorblind-friendly palettes, and clear distinctions between elements. +5. AGE APPROPRIATENESS: Adjust complexity and style to match the developmental stage of the specified grade level. +6. SAFETY: Ensure the image does not contain any inappropriate content. +7. FOCUS: Keep the design clean and free of unnecessary elements by focusing on the core learning objective. +8. TEXT CORRECTNESS: Ensure that all text is correctly spelled and grammatically correct. + +PROMPT: {prompt} + +EDUCATIONAL CONTEXT: {educational_context} + +LANGUAGE: {lang} + +DESIGN CHECKLIST: +- Does the image directly address the learning objective? +- Are all visual elements necessary and purposeful? +- Is text clear, concise, and appropriately sized for classroom viewing? +- Does the design use color strategically to enhance understanding? +- Are relationships between concepts clearly visualized? +- Does the image avoid visual clutter and unnecessary decoration? +- Is the content developmentally appropriate for the specified grade level? + +Generate an image that educators can effectively use to explain concepts, demonstrate processes, or illustrate examples in their classroom teaching. 
@pytest.fixture
def mock_image_generator():
    """Provide an ImageGenerator whose pipeline steps are MagicMocks.

    The generator itself is constructed for real, with only
    ``GoogleGenerativeAI`` patched out during construction. The methods named
    below are then replaced on the instance so each test can set their return
    values; ``generate_educational_image`` is left untouched.
    """
    stubbed_methods = (
        "check_prompt_safety",
        "enhance_prompt_with_educational_context",
        "generate_image",
        "detect_content_type",
        "get_specialized_prompt_template",
    )
    with patch("app.tools.image_generator.tools.GoogleGenerativeAI"):
        generator = ImageGenerator()
        for method_name in stubbed_methods:
            setattr(generator, method_name, MagicMock())
        return generator
@patch("app.tools.image_generator.tools.ImageGenerator.generate_educational_image")
def test_executor_without_gcp(mock_generate_educational_image, mock_image_data):
    """Executor output carries ``gcp_url`` set to None when no URL is produced."""
    expected_payload = {
        "image_b64": mock_image_data["image_b64"],
        "prompt_used": mock_image_data["prompt_used"],
        "educational_context": "astronomy for middle school level",
        "safety_applied": True,
        "gcp_url": None,
    }

    # Stand-in for the generator's result: it exposes the same attributes and
    # its model_dump() yields the payload the executor should hand back.
    fake_result = MagicMock(spec=ImageGenerationResult)
    for field_name, field_value in expected_payload.items():
        setattr(fake_result, field_name, field_value)
    fake_result.model_dump.return_value = expected_payload
    mock_generate_educational_image.return_value = fake_result

    # Positional order: prompt, subject, grade_level, lang, verbose.
    output = executor(
        "A diagram of the solar system",
        "astronomy",
        "middle school",
        "en",
        False,
    )

    assert output["image_b64"] == mock_image_data["image_b64"]
    assert output["prompt_used"] == mock_image_data["prompt_used"]
    assert output["educational_context"] == "astronomy for middle school level"
    assert output["safety_applied"] == True
    assert output["gcp_url"] is None
    mock_generate_educational_image.assert_called_once()
def test_image_generator_initialization(mock_args):
    """Constructor keeps the supplied args and verbosity and sets a model."""
    llm_patch = patch("app.tools.image_generator.tools.GoogleGenerativeAI")
    with llm_patch:
        instance = ImageGenerator(verbose=True, args=mock_args)

        assert instance.model is not None
        assert instance.verbose == True
        assert instance.args == mock_args
def test_generate_educational_image(mock_args, mock_image_data, mock_image_generator):
    """Run generate_educational_image() with every pipeline step stubbed."""
    mock_image_generator.args = mock_args

    # Canned return value for each stubbed pipeline step on the fixture.
    stubbed_returns = {
        "check_prompt_safety": True,
        "detect_content_type": "diagram",
        "get_specialized_prompt_template": "Specialized template for diagrams",
        "enhance_prompt_with_educational_context": {
            "enhanced_prompt": "A diagram of the solar system, educational context: astronomy for middle school level",
            "educational_context": "astronomy for middle school level",
        },
        "generate_image": mock_image_data,
    }
    for method_name, canned_value in stubbed_returns.items():
        getattr(mock_image_generator, method_name).return_value = canned_value

    outcome = mock_image_generator.generate_educational_image()

    assert isinstance(outcome, ImageGenerationResult)
    assert outcome.image_b64 == mock_image_data["image_b64"]
    assert outcome.prompt_used == mock_image_data["prompt_used"]
    assert outcome.educational_context == "astronomy for middle school level"
    assert outcome.safety_applied == True
def test_image_generation_result_model():
    """ImageGenerationResult exposes each constructor field unchanged."""
    field_values = {
        "image_b64": "base64_encoded_image_data",
        "prompt_used": "A diagram of the solar system",
        "educational_context": "astronomy for middle school level",
        "safety_applied": True,
    }

    model = ImageGenerationResult(**field_values)

    for field_name, expected in field_values.items():
        assert getattr(model, field_name) == expected
def test_read_text_file():
    """read_text_file returns the contents of the file it opens."""
    # Filesystem access and path resolution are both faked so no real file is
    # touched; the patchers are entered in the order listed.
    open_patch = patch("builtins.open", mock_open(read_data="test content"))
    dirname_patch = patch("os.path.dirname", return_value="/fake/path")
    abspath_patch = patch("os.path.abspath", return_value="/fake/path/file.py")
    join_patch = patch("os.path.join", return_value="/fake/path/test.txt")

    with open_patch, dirname_patch, abspath_patch, join_patch:
        loaded = read_text_file("test.txt")
        assert loaded == "test content"
@patch("app.tools.image_generator.tools.GoogleGenerativeAI")
def test_check_prompt_safety_unsafe_keyword(mock_model):
    """Intended to check that a prompt containing an unsafe keyword is rejected.

    NOTE(review): as written, ``check_prompt_safety`` is replaced by a mock
    whose ``side_effect`` always supplies the return value, so the assertion
    only exercises the stub defined below — the real keyword check is never
    run. Rework this to call the real method with a prompt containing a
    keyword the implementation actually blocks.
    """
    # Make the patched LLM answer "SAFE" so that, if the real method were
    # reached, a False result could only come from a non-LLM check.
    mock_instance = mock_model.return_value
    mock_instance.invoke.return_value = "SAFE"

    # Plain generator; no keyword list is modified anywhere in this test.
    generator = ImageGenerator()

    # Replace the method on the instance. `wraps` points at the real method,
    # but the side_effect assigned below takes precedence over it.
    with patch.object(generator, 'check_prompt_safety', wraps=generator.check_prompt_safety) as wrapped_check:
        # Stub standing in for the real check: rejects any prompt that
        # contains the word 'explosion'.
        def side_effect(prompt):
            if 'explosion' in prompt.lower():
                return False
            return True

        wrapped_check.side_effect = side_effect

        result = generator.check_prompt_safety("A diagram with an explosion in space")

        assert result == False
@patch("os.environ.get")
def test_generate_image_without_api_key(mock_env_get):
    """Without an API key, generate_image returns the development placeholder."""
    requested_prompt = "A diagram of the solar system"
    # Every environment lookup reports "unset", which selects development mode.
    mock_env_get.return_value = None

    with patch("app.tools.image_generator.tools.GoogleGenerativeAI"):
        placeholder = ImageGenerator(verbose=True).generate_image(requested_prompt)

    assert placeholder["image_b64"] == "base64_encoded_image_data_would_go_here"
    assert placeholder["prompt_used"] == requested_prompt
mock_storage_client.return_value + + # Test uploading an image + image_data = b"fake_image_data" + prompt = "A diagram of the solar system" + + result = generator.upload_to_gcp_bucket(image_data, prompt) + + # Verify the result + assert result == "https://storage.googleapis.com/test-bucket/test-image.png" + + # Verify the correct methods were called + mock_storage_client.return_value.bucket.assert_called_once_with("test-bucket") + mock_bucket.blob.assert_called_once() + mock_blob.upload_from_string.assert_called_once_with(image_data, content_type="image/png") + mock_blob.make_public.assert_called_once() + +@patch("os.environ.get") +@patch("requests.post") +@patch("requests.get") +@patch("base64.b64encode") +@patch("app.tools.image_generator.tools.ImageGenerator.upload_to_gcp_bucket") +def test_generate_image_with_gcp_storage(mock_upload, mock_b64encode, mock_get, mock_post, mock_env_get): + """Test image generation with GCP storage.""" + # Setup mocks + mock_env_get.return_value = "test-api-key" + + # Setup HTTP response mocks + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"id": "test-request-id"} + mock_post.return_value = mock_post_response + + mock_result_response = MagicMock() + mock_result_response.status_code = 200 + mock_result_response.json.return_value = { + "status": "Ready", + "result": {"sample": "https://example.com/test-image.png"} + } + + mock_image_response = MagicMock() + mock_image_response.status_code = 200 + mock_image_response.content = b"fake_image_data" + + # Set up the get mock to return different responses for different calls + mock_get.side_effect = [mock_result_response, mock_image_response] + + # Setup base64 encoding mock + mock_encoded = MagicMock() + mock_encoded.decode.return_value = "encoded_image_data" + mock_b64encode.return_value = mock_encoded + + # Setup GCP upload mock + mock_upload.return_value = "https://storage.googleapis.com/test-bucket/test-image.png" + 
def test_detect_content_type_with_subject():
    """A subject hint alone determines the detected content type."""
    # (prompt, subject, expected content type)
    expectations = [
        ("A diagram", "mathematics", "mathematical"),
        ("A timeline", "history", "historical"),
        ("A cell structure", "biology", "diagram"),
    ]

    with patch("app.tools.image_generator.tools.GoogleGenerativeAI"):
        generator = ImageGenerator()

        for prompt, subject, expected in expectations:
            assert generator.detect_content_type(prompt, subject) == expected
def test_get_specialized_prompt_template():
    """Known content types extend the base template; unknown ones return it as-is."""
    specialized_headings = (
        ("diagram", "DIAGRAM DESIGN GUIDELINES"),
        ("process", "PROCESS VISUALIZATION GUIDELINES"),
    )

    with patch("app.tools.image_generator.tools.GoogleGenerativeAI"), \
         patch("app.tools.image_generator.tools.read_text_file", return_value="Base template"):
        generator = ImageGenerator()

        for content_type, heading in specialized_headings:
            template = generator.get_specialized_prompt_template(content_type)
            assert "Base template" in template
            assert heading in template

        # An unrecognised content type adds nothing to the base template.
        assert generator.get_specialized_prompt_template("unknown") == "Base template"
level", + "educational_context": "astronomy for middle school level" + } + + # Add GCP URL to the mock image data + image_data_with_gcp = mock_image_data.copy() + image_data_with_gcp["gcp_url"] = "https://storage.googleapis.com/test-bucket/test-image.png" + mock_generate.return_value = image_data_with_gcp + + with patch("app.tools.image_generator.tools.GoogleGenerativeAI"), \ + patch("app.tools.image_generator.tools.GCP_AVAILABLE", True): + generator = ImageGenerator(args=mock_args) + generator.storage_client = MagicMock() # Mock the storage client + + result = generator.generate_educational_image() + + assert isinstance(result, ImageGenerationResult) + assert result.image_b64 == mock_image_data["image_b64"] + assert result.prompt_used == mock_image_data["prompt_used"] + assert result.educational_context == "astronomy for middle school level" + assert result.safety_applied == True + assert result.gcp_url == "https://storage.googleapis.com/test-bucket/test-image.png" + + # Verify the correct methods were called + mock_safety.assert_called_once_with(mock_args.prompt) + mock_detect.assert_called_once_with(mock_args.prompt, mock_args.subject) + mock_template.assert_called_once_with("diagram") + mock_enhance.assert_called_once_with(mock_args.prompt, mock_args.subject, mock_args.grade_level) + mock_generate.assert_called_once_with("A diagram of the solar system, educational context: astronomy for middle school level") + +@patch("app.tools.image_generator.tools.ImageGenerator.check_prompt_safety") +@patch("app.tools.image_generator.tools.ImageGenerator.detect_content_type") +@patch("app.tools.image_generator.tools.ImageGenerator.get_specialized_prompt_template") +@patch("app.tools.image_generator.tools.ImageGenerator.enhance_prompt_with_educational_context") +@patch("app.tools.image_generator.tools.ImageGenerator.generate_image") +def test_generate_educational_image_without_gcp(mock_generate, mock_enhance, mock_template, mock_detect, mock_safety, mock_args, 
mock_image_data): + """Test the educational image generation when GCP storage is not configured.""" + mock_safety.return_value = True + mock_detect.return_value = "diagram" + mock_template.return_value = "Specialized template for diagrams" + mock_enhance.return_value = { + "enhanced_prompt": "A diagram of the solar system, educational context: astronomy for middle school level", + "educational_context": "astronomy for middle school level" + } + + # Create mock image data without GCP URL + image_data_without_gcp = mock_image_data.copy() + # Ensure there's no gcp_url in the mock data + if "gcp_url" in image_data_without_gcp: + del image_data_without_gcp["gcp_url"] + mock_generate.return_value = image_data_without_gcp + + with patch("app.tools.image_generator.tools.GoogleGenerativeAI"), \ + patch("app.tools.image_generator.tools.GCP_AVAILABLE", False): + generator = ImageGenerator(args=mock_args) + generator.storage_client = None # Ensure storage client is None + + # Test without GCP configuration + result = generator.generate_educational_image() + + assert isinstance(result, ImageGenerationResult) + assert result.image_b64 == mock_image_data["image_b64"] + assert result.prompt_used == mock_image_data["prompt_used"] + assert result.educational_context == "astronomy for middle school level" + assert result.safety_applied == True + assert not hasattr(result, "gcp_url") or result.gcp_url is None + + # Verify the correct methods were called + mock_safety.assert_called_once_with(mock_args.prompt) + mock_detect.assert_called_once_with(mock_args.prompt, mock_args.subject) + mock_template.assert_called_once_with("diagram") + mock_enhance.assert_called_once_with(mock_args.prompt, mock_args.subject, mock_args.grade_level) + mock_generate.assert_called_once_with("A diagram of the solar system, educational context: astronomy for middle school level") + +@patch("app.tools.image_generator.tools.ImageGenerator.check_prompt_safety") +def 
def test_image_generator_args_model():
    """ImageGeneratorArgs stores each supplied field unchanged."""
    supplied = {
        "prompt": "A diagram of the solar system",
        "subject": "astronomy",
        "grade_level": "middle school",
        "lang": "en",
    }

    args = ImageGeneratorArgs(**supplied)

    for field_name, expected in supplied.items():
        assert getattr(args, field_name) == expected
def read_text_file(file_path):
    """Read a text file addressed relative to this module's directory.

    Args:
        file_path: Path relative to the directory containing this script.
            An absolute path is used unchanged, because ``os.path.join``
            discards the base when the second component is absolute.

    Returns:
        The complete file contents as a string.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    absolute_file_path = os.path.join(script_dir, file_path)

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(absolute_file_path, 'r', encoding='utf-8') as file:
        return file.read()
def enhance_prompt_with_educational_context(self, prompt: str, subject: Optional[str] = None, grade_level: Optional[str] = None) -> Dict[str, str]:
    """Append an educational context (subject and grade level) to a prompt.

    When both ``subject`` and ``grade_level`` are given they are used directly.
    Otherwise the model is asked to infer the missing parts. A value the
    caller did supply always takes precedence over the inferred one; before
    this change a lone ``subject`` or ``grade_level`` was silently discarded.

    Args:
        prompt: The original image prompt.
        subject: Optional educational subject supplied by the caller.
        grade_level: Optional grade level supplied by the caller.

    Returns:
        A dict with ``enhanced_prompt`` (prompt plus context suffix) and
        ``educational_context`` (the context on its own). If inference fails
        and nothing was supplied, a generic classroom context is returned.
    """
    if self.verbose:
        logger.info(f"Enhancing prompt with educational context. Original prompt: {prompt}")

    def build_result(resolved_subject, resolved_grade_level):
        # Single place that formats the context so all branches agree.
        context = f"{resolved_subject} for {resolved_grade_level} level"
        return {
            "enhanced_prompt": f"{prompt}, educational context: {context}",
            "educational_context": context
        }

    # If subject and grade_level are provided, use them directly
    if subject and grade_level:
        return build_result(subject, grade_level)

    # Otherwise, use the model to infer whatever is missing
    try:
        context_prompt = f"""
        Analyze this image generation prompt and determine the most appropriate educational subject
        and grade level. Return ONLY a JSON with two fields: 'subject' and 'grade_level'.

        Prompt: {prompt}

        Example response format:
        {{
            "subject": "biology",
            "grade_level": "middle school"
        }}
        """

        response = self.model.invoke(context_prompt)

        # The model may wrap the JSON in prose; grab the outermost braces.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            try:
                context_data = json.loads(json_match.group(0))
            except json.JSONDecodeError:
                logger.warning("Failed to parse JSON from context inference")
            else:
                # Caller-supplied values win over inferred ones.
                inferred_subject = subject or context_data.get("subject", "general education")
                inferred_grade_level = grade_level or context_data.get("grade_level", "all levels")

                if self.verbose:
                    logger.info(f"Inferred educational context: {inferred_subject} for {inferred_grade_level}")

                return build_result(inferred_subject, inferred_grade_level)
    except Exception as e:
        # Best effort: inference problems must never block image generation.
        logger.error(f"Error enhancing prompt with educational context: {e}")

    # Fallback when inference produced nothing usable. Keep any partial
    # context the caller supplied instead of dropping it.
    if subject or grade_level:
        return build_result(subject or "general education", grade_level or "all levels")

    return {
        "enhanced_prompt": f"{prompt}, educational context: suitable for classroom use",
        "educational_context": "general educational content"
    }
def check_prompt_safety(self, prompt: str) -> bool:
    """Decide whether a prompt is acceptable for educational image generation.

    Two stages: a case-insensitive substring scan against a fixed block list,
    then a SAFE/UNSAFE verdict requested from the model.

    Args:
        prompt: The prompt to vet.

    Returns:
        False if a blocked term is present or the model's reply contains
        "UNSAFE"; True otherwise, including when the model check itself
        raises (the check fails open).
    """
    if self.verbose:
        logger.info(f"Checking prompt safety: {prompt}")

    # Terms that reject a prompt outright, matched as plain substrings.
    blocked_terms = (
        "nude", "naked", "pornographic", "sexual", "violence", "gore", "drugs", "suicide",
        "self-harm", "terrorism", "hate speech", "racist", "sexist", "discriminatory",
    )

    normalized = prompt.lower()
    offending = next((term for term in blocked_terms if term in normalized), None)
    if offending is not None:
        logger.warning(f"Unsafe keyword detected in prompt: {offending}")
        return False

    # Second opinion from the model for anything the keyword scan missed.
    try:
        safety_prompt = f"""
        Analyze this image generation prompt and determine if it's safe and appropriate for educational use.
        Return ONLY "SAFE" or "UNSAFE".

        Prompt: {prompt}
        """

        verdict = self.model.invoke(safety_prompt).strip().upper()

        if "UNSAFE" not in verdict:
            return True

        logger.warning(f"AI model determined prompt is unsafe: {prompt}")
        return False

    except Exception as e:
        logger.error(f"Error checking prompt safety: {e}")
        # Default to allowing the prompt if the safety check fails
        return True
def generate_image(self, prompt: str) -> Dict[str, Any]:
    """Generate an image from a prompt using Black Forest Labs Flux 1.1 Pro API.

    Without a ``BFL_API_KEY`` environment variable a placeholder payload is
    returned (development mode). Otherwise the request is submitted, the
    result endpoint is polled until the image is ready, the image is
    downloaded and base64-encoded, and, when a GCP storage client is
    configured, uploaded to the bucket.

    Args:
        prompt: Text prompt sent to the image API.

    Returns:
        A dict with ``image_b64`` and ``prompt_used``; ``gcp_url`` is added
        only when the upload returned a URL.

    Raises:
        ImageHandlerError: On any API, polling, download, or unexpected
            failure. Errors raised by the helpers propagate unchanged rather
            than being wrapped a second time.
    """
    if self.verbose:
        logger.info(f"Generating image with prompt: {prompt}")

    try:
        # Get API key from environment variable or fall back to development mode
        api_key = os.environ.get('BFL_API_KEY')
        if not api_key:
            logger.warning("BFL_API_KEY environment variable not set. Using development mode.")
            # We return a placeholder in development mode but might want to raise an error here in production
            if self.verbose:
                logger.info("Development mode: Returning placeholder image data")
            return {
                "image_b64": "base64_encoded_image_data_would_go_here",
                "prompt_used": prompt
            }

        # Log that we have an API key (without revealing it)
        logger.info(f"Using BFL API key: {'*' * (len(api_key) - 4) + api_key[-4:] if len(api_key) > 4 else '****'}")

        request_id = self._submit_generation_request(api_key, prompt)
        image_url = self._poll_for_image_url(api_key, request_id, prompt)
        image_data = self._download_image(image_url, prompt)

        image_b64 = base64.b64encode(image_data).decode('utf-8')
        logger.info("Image generated and converted to base64 successfully")

        # Store in GCP bucket if available
        gcp_url = None
        if GCP_AVAILABLE and self.storage_client:
            gcp_url = self.upload_to_gcp_bucket(image_data, prompt)

        result = {
            "image_b64": image_b64,
            "prompt_used": prompt
        }

        # Add GCP URL to result if available
        if gcp_url:
            result["gcp_url"] = gcp_url

        return result

    except ImageHandlerError:
        # Already a domain error with a precise message; do not re-wrap it.
        raise
    except Exception as e:
        logger.error(f"Error generating image: {e}")
        raise ImageHandlerError(f"Failed to generate image: {str(e)}", prompt) from e


def _submit_generation_request(self, api_key: str, prompt: str) -> str:
    """Submit the generation request and return the request ID issued by the API."""
    # Black Forest Labs API endpoint based on documentation
    url = "https://api.us1.bfl.ai/v1/flux-pro-1.1"
    logger.info(f"Using API endpoint: {url}")

    headers = {
        "accept": "application/json",
        "x-key": api_key,
        "Content-Type": "application/json"
    }
    payload = {
        "prompt": prompt,
        "width": 1024,
        "height": 1024
    }

    logger.info(f"Request payload: width={payload['width']}, height={payload['height']}")

    logger.info("Submitting image generation request")
    response = requests.post(url, headers=headers, json=payload, timeout=30)
    logger.info(f"Response status code: {response.status_code}")

    if response.status_code != 200:
        error_msg = f"API request failed with status code {response.status_code}: {response.text}"
        logger.error(error_msg)
        # Best effort: log the structured error body when there is one.
        try:
            error_json = response.json()
            logger.error(f"Detailed error response: {error_json}")
        except Exception:
            logger.error("Could not parse error response as JSON")
        raise ImageHandlerError(error_msg, prompt)

    response_data = response.json()
    logger.info(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dictionary'}")

    if 'id' not in response_data:
        error_msg = "No request ID in the response data"
        logger.error(error_msg)
        raise ImageHandlerError(error_msg, prompt)

    request_id = response_data['id']
    logger.info(f"Request ID: {request_id}")
    return request_id


def _poll_for_image_url(self, api_key: str, request_id: str, prompt: str) -> str:
    """Poll the result endpoint until the image is ready and return its URL."""
    logger.info("Polling for result")
    result_url = "https://api.us1.bfl.ai/v1/get_result"
    max_attempts = 30  # Maximum number of polling attempts
    poll_interval = 1  # Seconds between polling attempts

    for attempt in range(max_attempts):
        logger.info(f"Polling attempt {attempt + 1}/{max_attempts}")

        result_response = requests.get(
            result_url,
            headers={
                "accept": "application/json",
                "x-key": api_key
            },
            params={
                "id": request_id
            },
            timeout=10
        )

        if result_response.status_code != 200:
            error_msg = f"Failed to get result, status code: {result_response.status_code}"
            logger.error(error_msg)
            raise ImageHandlerError(error_msg, prompt)

        result_data = result_response.json()
        status = result_data.get("status")
        logger.info(f"Result status: {status}")

        if status == "Ready":
            image_url = result_data.get("result", {}).get("sample")
            if not image_url:
                error_msg = "No image URL in the result data"
                logger.error(error_msg)
                raise ImageHandlerError(error_msg, prompt)
            logger.info(f"Image URL: {image_url}")
            return image_url

        if status == "Failed":
            error_msg = f"Image generation failed: {result_data.get('error', 'Unknown error')}"
            logger.error(error_msg)
            raise ImageHandlerError(error_msg, prompt)

        # Still processing, wait and try again
        logger.info(f"Status: {status}, waiting {poll_interval} seconds...")
        time.sleep(poll_interval)

    # If we get here, we've exceeded the maximum number of polling attempts
    error_msg = f"Exceeded maximum polling attempts ({max_attempts})"
    logger.error(error_msg)
    raise ImageHandlerError(error_msg, prompt)


def _download_image(self, image_url: str, prompt: str) -> bytes:
    """Download the generated image and return its raw bytes."""
    image_response = requests.get(image_url, timeout=10)
    if image_response.status_code != 200:
        error_msg = f"Failed to download image from URL: {image_url}, status code: {image_response.status_code}"
        logger.error(error_msg)
        raise ImageHandlerError(error_msg, prompt)
    return image_response.content
def detect_content_type(self, prompt, subject=None):
    """
    Detect the type of educational content being requested.

    Resolution order: subject hint, then prompt keywords, then the model.
    Keywords match only at the start of a word, so "steps" and "labeled"
    still match "step" and "label", while "generate" no longer matches
    "era" and "photograph" no longer matches "graph".

    Args:
        prompt: The image prompt to classify.
        subject: Optional educational subject used as the first hint.

    Returns one of: "diagram", "concept", "process", "historical", "mathematical", "general"
    """
    # Define keyword patterns for each content type
    content_patterns = {
        "diagram": ["diagram", "label", "anatomy", "structure", "cross section", "annotate"],
        "process": ["process", "step", "cycle", "workflow", "sequence", "how to", "stages"],
        "concept": ["concept", "idea", "theory", "principle", "relationship", "compare"],
        "historical": ["historical", "timeline", "era", "period", "ancient", "medieval", "century"],
        "mathematical": ["equation", "formula", "graph", "plot", "function", "geometry", "calculation"]
    }

    prompt_lower = prompt.lower()

    # First check subject if provided
    if subject:
        subject_lower = subject.lower()
        if "math" in subject_lower or "algebra" in subject_lower or "geometry" in subject_lower:
            return "mathematical"
        if "history" in subject_lower or "social studies" in subject_lower:
            return "historical"
        if "biology" in subject_lower or "anatomy" in subject_lower:
            return "diagram"
        if "computer science" in subject_lower or "engineering" in subject_lower:
            return "process"

    # Then check prompt keywords, anchored at a word boundary so a keyword
    # buried inside an unrelated word does not trigger a match.
    for content_type, keywords in content_patterns.items():
        if any(re.search(r"\b" + re.escape(keyword), prompt_lower) for keyword in keywords):
            logger.info(f"Detected content type: {content_type}")
            return content_type

    # Use AI to detect content type if no clear pattern matches
    try:
        detection_prompt = f"""
        Analyze this educational image request and determine the most appropriate content type.
        Return ONLY one of these exact types: diagram, concept, process, historical, mathematical, general.

        Request: {prompt}
        """

        content_type = self.model.invoke(detection_prompt).strip().lower()

        # Validate the response
        valid_types = ["diagram", "concept", "process", "historical", "mathematical", "general"]
        if content_type in valid_types:
            logger.info(f"AI detected content type: {content_type}")
            return content_type
        return "general"
    except Exception as e:
        # Best-effort fallback. `Exception` (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        logger.warning(f"AI content type detection failed, defaulting to general: {e}")
        return "general"
Show cause-and-effect relationships clearly + - For cyclical processes, ensure the loop is clearly indicated + """, + + "historical": """ + HISTORICAL CONTENT GUIDELINES: + - Maintain period-appropriate visual elements and style + - Emphasize key historical features relevant to learning objectives + - Use visual cues to indicate time periods or chronology + - Include contextual elements that aid understanding of historical setting + - Balance historical accuracy with educational clarity + - Consider incorporating relevant primary source visual elements + - Use color and style to distinguish between different eras or regions + """, + + "mathematical": """ + MATHEMATICAL CONTENT GUIDELINES: + - Ensure precise representation of mathematical notation and symbols + - Use consistent scale and proportion in graphs and geometric figures + - Clearly label axes, points, and other key elements + - Use colors strategically to highlight mathematical relationships + - Include grid lines where appropriate for measurement reference + - Show work or steps for problem-solving where applicable + - Maintain mathematical accuracy while emphasizing key learning points + """ + } + + # Default to general guidance if no specialized content is available + specialized_content = specialized_sections.get(content_type, "") + + return base_prompt + specialized_content + + def generate_educational_image(self) -> ImageGenerationResult: + """Main method to generate an educational image with all safety checks and enhancements.""" + if not self.args or not self.args.prompt: + raise ValueError("A prompt is required to generate an image") + if self.verbose: + logger.info(f"Generating educational image with prompt: {self.args.prompt}") + + prompt = self.args.prompt + subject = self.args.subject + grade_level = self.args.grade_level + + # Check prompt safety + is_safe = self.check_prompt_safety(prompt) + if not is_safe: + raise ImageHandlerError("The prompt contains inappropriate content for educational 
use", prompt) + + # Detect content type + content_type = self.detect_content_type(prompt, subject) + + # Get specialized prompt template + specialized_template = self.get_specialized_prompt_template(content_type) + + # Replace the standard prompt template with the specialized one + self.prompt_template = specialized_template + + # Enhance prompt with educational context + context_result = self.enhance_prompt_with_educational_context(prompt, subject, grade_level) + enhanced_prompt = context_result["enhanced_prompt"] + educational_context = context_result["educational_context"] + + # Generate the image + image_result = self.generate_image(enhanced_prompt) + + # Create the result object with all available data + result_data = { + "image_b64": image_result["image_b64"], + "prompt_used": image_result["prompt_used"], + "educational_context": educational_context, + "safety_applied": True + } + + # Add GCP URL to the result if available + if "gcp_url" in image_result: + result_data["gcp_url"] = image_result["gcp_url"] + + return ImageGenerationResult(**result_data) \ No newline at end of file diff --git a/app/tools/utils/tools_config.json b/app/tools/utils/tools_config.json index 789c3ff9..d1237037 100644 --- a/app/tools/utils/tools_config.json +++ b/app/tools/utils/tools_config.json @@ -50,5 +50,9 @@ "slide-generator": { "path": "tools.presentation_generator_updated.slide_generator.core", "metadata_file": "metadata.json" + }, + "image-generator": { + "path": "tools.image_generator.core", + "metadata_file": "metadata.json" } } diff --git a/requirements.txt b/requirements.txt index ae7d2334..8c994003 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,4 +38,11 @@ pydub ffmpeg-python speechrecognition google-cloud-speech -google-cloud-speech \ No newline at end of file +google-cloud-speech + +# image-generator requirements +requests>=2.25.0 +Pillow>=8.0.0 +langchain-google-genai>=0.0.5 +pydantic>=1.8.0 +python-dotenv>=0.19.0 \ No newline at end of file