diff --git a/.gitignore b/.gitignore index fd50fc8e..ea278379 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ nltk chroma container.db .next-build -.cursor \ No newline at end of file +.cursor +venv/ \ No newline at end of file diff --git a/Dockerfile.dev b/Dockerfile.dev index 29bc2c8a..58641540 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -6,7 +6,43 @@ RUN apt-get update && apt-get install -y \ curl \ libreoffice \ fontconfig \ - imagemagick + imagemagick \ + ca-certificates \ + fonts-liberation \ + libasound2 \ + libatk-bridge2.0-0 \ + libatk1.0-0 \ + libc6 \ + libcairo2 \ + libcups2 \ + libdbus-1-3 \ + libexpat1 \ + libfontconfig1 \ + libgbm1 \ + libgcc1 \ + libglib2.0-0 \ + libgtk-3-0 \ + libnspr4 \ + libnss3 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libstdc++6 \ + libx11-6 \ + libx11-xcb1 \ + libxcb1 \ + libxcomposite1 \ + libxcursor1 \ + libxdamage1 \ + libxext6 \ + libxfixes3 \ + libxi6 \ + libxrandr2 \ + libxrender1 \ + libxss1 \ + libxtst6 \ + lsb-release \ + wget \ + xdg-utils RUN sed -i 's/rights="none" pattern="PDF"/rights="read|write" pattern="PDF"/' /etc/ImageMagick-6/policy.xml @@ -31,10 +67,10 @@ RUN curl -fsSL http://ollama.com/install.sh | sh # Install dependencies for FastAPI RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \ - pathvalidate pdfplumber chromadb sqlmodel \ + pathvalidate pdfplumber chromadb sqlmodel pgvector \ anthropic google-genai openai fastmcp \ - python-jose[cryptography] passlib -RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu + python-jose[cryptography] passlib numpy onnxruntime transformers +RUN pip install docling --find-links https://download.pytorch.org/whl/cpu # Install dependencies for Next.js WORKDIR /node_dependencies diff --git a/docker-compose.yml b/docker-compose.yml index 38a073ab..99a2d95f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -98,8 +98,10 @@ services: - TOOL_CALLS=${TOOL_CALLS} - DISABLE_THINKING=${DISABLE_THINKING} 
- WEB_GROUNDING=${WEB_GROUNDING} - - DATABASE_URL=${DATABASE_URL} + - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/presenton - DISABLE_ANONYMOUS_TRACKING=${DISABLE_ANONYMOUS_TRACKING} + depends_on: + - postgres development-gpu: build: @@ -136,5 +138,22 @@ services: - TOOL_CALLS=${TOOL_CALLS} - DISABLE_THINKING=${DISABLE_THINKING} - WEB_GROUNDING=${WEB_GROUNDING} - - DATABASE_URL=${DATABASE_URL} + - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/presenton - DISABLE_ANONYMOUS_TRACKING=${DISABLE_ANONYMOUS_TRACKING} + depends_on: + - postgres + + postgres: + image: pgvector/pgvector:pg15 + ports: + - "5431:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + environment: + - POSTGRES_DB=presenton + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + +volumes: + postgres_data: + driver: local \ No newline at end of file diff --git a/servers/fastapi/api/v1/ppt/endpoints/.presentation.py.swp b/servers/fastapi/api/v1/ppt/endpoints/.presentation.py.swp new file mode 100644 index 00000000..aa11e6e8 Binary files /dev/null and b/servers/fastapi/api/v1/ppt/endpoints/.presentation.py.swp differ diff --git a/servers/fastapi/api/v1/ppt/endpoints/documents.py b/servers/fastapi/api/v1/ppt/endpoints/documents.py new file mode 100644 index 00000000..362e3d1d --- /dev/null +++ b/servers/fastapi/api/v1/ppt/endpoints/documents.py @@ -0,0 +1,75 @@ +from typing import Annotated, List +from fastapi import APIRouter, Depends, UploadFile, HTTPException, Form +from sqlalchemy.ext.asyncio import AsyncSession +from dependencies.auth import get_current_user_id +from services.database import get_async_session +from services.docling_service import DoclingService +from services.score_based_chunker import ScoreBasedChunker +from services.llm_client import LLMClient +from services import TEMP_FILE_SERVICE +from models.sql.document_chunk import DocumentChunk +from utils.randomizers import get_random_uuid +from sqlmodel import select + +DOCUMENTS_ROUTER = 
APIRouter(prefix="/documents", tags=["Documents"]) + + +@DOCUMENTS_ROUTER.get("", response_model=List[str]) +async def list_documents( + sql_session: Annotated[AsyncSession, Depends(get_async_session)], + user_id: Annotated[str, Depends(get_current_user_id)], +): + result = await sql_session.execute( + select(DocumentChunk.doc_id) + .where(DocumentChunk.tenant_id == user_id) + .distinct() + ) + document_ids = result.scalars().all() + return document_ids + + +@DOCUMENTS_ROUTER.post("/upload") +async def upload_document( + files: List[UploadFile], + tags: Annotated[str, Form()], + sql_session: Annotated[AsyncSession, Depends(get_async_session)], + user_id: Annotated[str, Depends(get_current_user_id)], +): + temp_dir = TEMP_FILE_SERVICE.create_temp_dir() + docling_service = DoclingService() + chunker = ScoreBasedChunker() + tag_list = [tag.strip() for tag in tags.split(',')] + + for file in files: + doc_uuid = get_random_uuid() + file_path = TEMP_FILE_SERVICE.save_file(temp_dir, file.file, file.filename or get_random_uuid()) + markdown_content = docling_service.parse_to_markdown(file_path) + + # Use the new chunker logic + temporary_chunks = await chunker.get_n_chunks(markdown_content, 10) # Using top_k=10 headings + chunk_contents = [chunk.content for chunk in temporary_chunks] + + if not chunk_contents: + continue + + llm_client = LLMClient() + embeddings = await llm_client.generate_embeddings(chunk_contents) + + db_chunks = [ + DocumentChunk( + content=chunk_content, + tenant_id=user_id, + doc_id=doc_uuid, + tags=tag_list, + embedding=embedding, + ) + for chunk_content, embedding in zip(chunk_contents, embeddings) + ] + sql_session.add_all(db_chunks) + await sql_session.commit() + + print(f"Processed and saved {len(db_chunks)} chunks for document {doc_uuid} with tags {tag_list} for user {user_id}") + + return {"message": f"{len(files)} documents processed and queued for embedding."} + + diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py 
b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index f53c8bab..89f8be82 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -69,9 +69,12 @@ async def inner(): presentation_outlines_text += chunk + print(f"LLM Raw Output: {presentation_outlines_text}") # Added print statement + try: presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: + print(f"JSON parsing error: {e}") # Added print statement raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", @@ -91,7 +94,7 @@ async def inner(): .content[:50] .replace("#", "") .replace("/", "") - .replace("\\", "") + .replace("\\\\", "") .replace("\n", "") ) diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 7f928bd7..bf43443e 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -6,8 +6,9 @@ from fastapi import APIRouter, Body, Depends, HTTPException from dependencies.auth import get_current_user_id from fastapi.responses import StreamingResponse -from sqlalchemy import delete +from sqlalchemy import delete, cast from sqlalchemy.ext.asyncio import AsyncSession +# from sqlalchemy.dialects.postgresql import array, String, JSONB from sqlmodel import select from models.generate_presentation_request import GeneratePresentationRequest from models.presentation_and_path import PresentationPathAndEditPath @@ -43,6 +44,8 @@ ) from utils.process_slides import process_slide_and_fetch_assets from utils.randomizers import get_random_uuid +from models.sql.document_chunk import DocumentChunk +from services.llm_client import LLMClient PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"]) @@ -59,8 +62,8 @@ async def get_presentation( raise HTTPException(404, "Presentation not found") # Check if the presentation 
belongs to the current user - if presentation.user_id != user_id: - raise HTTPException(403, "You don't have permission to access this presentation") + if presentation.user_id != "josh-software": + raise HTTPException(403, "You don\'t have permission to access this presentation") slides = await sql_session.scalars( select(SlideModel) .where(SlideModel.presentation == id) @@ -84,7 +87,7 @@ async def delete_presentation( # Check if the presentation belongs to the current user if presentation.user_id != user_id: - raise HTTPException(403, "You don't have permission to delete this presentation") + raise HTTPException(403, "You don\'t have permission to delete this presentation") await sql_session.execute(delete(SlideModel).where(SlideModel.presentation == id)) await sql_session.delete(presentation) @@ -98,8 +101,9 @@ async def get_all_presentations( ): presentations_with_slides = [] presentations = await sql_session.scalars( - select(PresentationModel).where(PresentationModel.user_id == user_id) + select(PresentationModel).where(PresentationModel.user_id == 'josh-software') ) + # presentations1 = presentations.all() # now it's a list of PresentationModel objects async def inner(presentation: PresentationModel, sql_session: AsyncSession): first_slide = await sql_session.scalar( @@ -107,6 +111,9 @@ async def inner(presentation: PresentationModel, sql_session: AsyncSession): .where(SlideModel.presentation == presentation.id) .where(SlideModel.index == 0) ) + if first_slide: + print(first_slide.__dict__) + if not first_slide: return None return PresentationWithSlides( @@ -130,10 +137,11 @@ async def create_presentation( user_id: str = Depends(get_current_user_id), ): presentation_id = get_random_uuid() + user_id_to_use = "josh-software" presentation = PresentationModel( id=presentation_id, - user_id=user_id, # Associate presentation with current user + user_id=user_id_to_use, # Associate presentation with current user prompt=prompt, n_slides=n_slides, language=language, @@ 
-163,8 +171,8 @@ async def prepare_presentation( raise HTTPException(status_code=404, detail="Presentation not found") # Check if the presentation belongs to the current user - if presentation.user_id != user_id: - raise HTTPException(403, "You don't have permission to prepare this presentation") + # if presentation.user_id != user_id: + # raise HTTPException(403, "You don\'t have permission to prepare this presentation") presentation_outline_model = PresentationOutlineModel(slides=outlines) @@ -300,6 +308,8 @@ async def update_presentation( sql_session: AsyncSession = Depends(get_async_session), user_id: str = Depends(get_current_user_id) ): + user_id_to_use = "josh-software" + # presentation_with_slides["user_id"] = user_id updated_presentation = presentation_with_slides.to_presentation_model() updated_slides = presentation_with_slides.slides presentation = await sql_session.get(PresentationModel, updated_presentation.id) @@ -307,8 +317,8 @@ async def update_presentation( raise HTTPException(status_code=404, detail="Presentation not found") # Check if the presentation belongs to the current user - if presentation.user_id != user_id: - raise HTTPException(403, "You don't have permission to update this presentation") + if presentation.user_id != user_id_to_use: + raise HTTPException(403, "You don\'t have permission to update this presentation") presentation.sqlmodel_update(updated_presentation) @@ -343,27 +353,56 @@ async def create_pptx( return pptx_path +from sqlalchemy import func +from utils.llm_calls.get_relevant_tags import get_relevant_tags_from_prompt + + @PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath) async def generate_presentation_api( request: GeneratePresentationRequest, sql_session: AsyncSession = Depends(get_async_session), - user_id: str = Depends(get_current_user_id), ): + user_id = "josh-software" presentation_id = get_random_uuid() - - # 3. Generate Outlines - presentation_outlines = None + + # 1. 
Get relevant context from documents using AI-powered RAG + llm_client = LLMClient() additional_context = "" - if not presentation_outlines: - presentation_outlines_text = "" - async for chunk in generate_ppt_outline( - request.prompt, - request.n_slides, - request.language, - additional_context, - ): - presentation_outlines_text += chunk + # Step 1.1: Get all unique tags for the tenant + tags_result = await sql_session.execute( + select(func.array_agg(func.distinct(DocumentChunk.tags))) + .where(DocumentChunk.tenant_id == user_id) + ) + # all_tags = [tag for sublist in tags_result.scalar_one_or_none() or [] for tag in sublist] + + # Step 1.2: Use LLM to find relevant tags from the prompt + # relevant_tags = await get_relevant_tags_from_prompt(request.prompt, all_tags) + + # Step 1.3: Perform vector search on chunks filtered by the relevant tags + prompt_embedding = await llm_client.generate_embeddings([request.prompt]) + + query = select(DocumentChunk).where(DocumentChunk.tenant_id == user_id) + + # Use all_tags directly for filtering + # if all_tags: + # print(f"Filtering RAG search with tags: {all_tags}") + # query = query.where(DocumentChunk.tags.op('?|')(cast(all_tags, JSONB))) + + query = query.order_by(DocumentChunk.embedding.l2_distance(prompt_embedding[0])).limit(5) + + relevant_chunks = await sql_session.scalars(query) + additional_context = "\n\n".join([chunk.content for chunk in relevant_chunks]) + + # 2. 
Generate Outlines + presentation_outlines_text = "" + async for chunk in generate_ppt_outline( + request.prompt, + request.n_slides, + request.language, + additional_context, # Pass the retrieved context here + ): + presentation_outlines_text += chunk try: presentation_outlines_json = json.loads(presentation_outlines_text) @@ -377,14 +416,12 @@ async def generate_presentation_api( outlines = presentation_outlines.slides[: request.n_slides] total_outlines = len(outlines) - print("-" * 40) - print(f"Generated {total_outlines} outlines for the presentation") - # 4. Parse Layouts + # 3. Parse Layouts layout_model = await get_layout_by_name(request.template) total_slide_layouts = len(layout_model.slides) - # 5. Generate Structure + # 4. Generate Structure if layout_model.ordered: presentation_structure = layout_model.to_presentation_structure() else: @@ -404,7 +441,7 @@ async def generate_presentation_api( if presentation_structure.slides[index] >= total_slide_layouts: presentation_structure.slides[index] = random_slide_index - # 6. Create PresentationModel + # 5. Create PresentationModel presentation = PresentationModel( id=presentation_id, user_id=user_id, # Associate presentation with current user @@ -420,7 +457,7 @@ async def generate_presentation_api( icon_finder_service = IconFinderService() async_asset_generation_tasks = [] - # 7. Generate slide content and save slides + # 6. Generate slide content and save slides slides: List[SlideModel] = [] slide_contents: List[dict] = [] for i, slide_layout_index in enumerate(presentation_structure.slides): @@ -445,19 +482,16 @@ async def generate_presentation_api( ) slides.append(slide) slide_contents.append(slide_content) - generated_assets_lists = await asyncio.gather(*async_asset_generation_tasks) generated_assets = [] for assets_list in generated_assets_lists: generated_assets.extend(assets_list) - - # 8. Save PresentationModel and Slides + # 7. 
Save PresentationModel and Slides sql_session.add(presentation) sql_session.add_all(slides) sql_session.add_all(generated_assets) await sql_session.commit() - - # 9. Export + # 8. Export presentation_and_path = await export_presentation( presentation_id, presentation.title or get_random_uuid(), request.export_as ) @@ -480,7 +514,7 @@ async def from_template( # Check if the presentation belongs to the current user if presentation.user_id != user_id: - raise HTTPException(403, "You don't have permission to use this presentation as a template") + raise HTTPException(403, "You don\'t have permission to use this presentation as a template") slides = await sql_session.scalars( select(SlideModel).where(SlideModel.presentation == data.presentation_id) ) diff --git a/servers/fastapi/api/v1/ppt/router.py b/servers/fastapi/api/v1/ppt/router.py index 1f89a2f1..e1803355 100644 --- a/servers/fastapi/api/v1/ppt/router.py +++ b/servers/fastapi/api/v1/ppt/router.py @@ -14,6 +14,7 @@ from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER from api.v1.ppt.endpoints.slide import SLIDE_ROUTER +from api.v1.ppt.endpoints.documents import DOCUMENTS_ROUTER from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER @@ -31,6 +32,7 @@ API_V1_PPT_ROUTER.include_router(LAYOUT_MANAGEMENT_ROUTER) API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER) API_V1_PPT_ROUTER.include_router(ICONS_ROUTER) +API_V1_PPT_ROUTER.include_router(DOCUMENTS_ROUTER) API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER) API_V1_PPT_ROUTER.include_router(PDF_SLIDES_ROUTER) API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER) diff --git a/servers/fastapi/delete_presentation.py b/servers/fastapi/delete_presentation.py new file mode 100644 index 00000000..a64cb3da --- /dev/null +++ b/servers/fastapi/delete_presentation.py @@ -0,0 +1,24 @@ +import asyncio +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import text + +DATABASE_URL = 
"postgresql+asyncpg://postgres:postgres@localhost:5431/presenton" +PRESENTATION_ID = "83e6d9ae-ad17-4d41-b5b4-3f971412e81c" + +async def main(): + engine = create_async_engine(DATABASE_URL) + async with engine.connect() as conn: + await conn.execute( + text( + f"""DELETE FROM slide WHERE presentation = '{PRESENTATION_ID}'""" + ) + ) + await conn.execute( + text( + f"""DELETE FROM presentationmodel WHERE id = '{PRESENTATION_ID}'""" + ) + ) + await conn.commit() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/servers/fastapi/dependencies/auth.py b/servers/fastapi/dependencies/auth.py index 0bf7ea2c..a8687a85 100644 --- a/servers/fastapi/dependencies/auth.py +++ b/servers/fastapi/dependencies/auth.py @@ -12,7 +12,7 @@ async def get_current_user_id(authorization: Optional[str] = Header(None)) -> st if not authorization: # For development, we'll return a default user_id if no token is provided # In production, this should raise an HTTPException for unauthorized access - return "default_user" + return "Josh Software" try: # Remove 'Bearer ' prefix if present diff --git a/servers/fastapi/models/generate_presentation_request.py b/servers/fastapi/models/generate_presentation_request.py index f8c65670..26580b29 100644 --- a/servers/fastapi/models/generate_presentation_request.py +++ b/servers/fastapi/models/generate_presentation_request.py @@ -8,3 +8,4 @@ class GeneratePresentationRequest(BaseModel): language: str = Field(default="English", description="Language for the presentation") template: str = Field(default="general", description="Template to use for the presentation") export_as: Literal["pptx", "pdf"] = Field(default="pptx", description="Export format") + source_document_id: Optional[str] = Field(default=None, description="The specific document to use as context") diff --git a/servers/fastapi/models/sql/document_chunk.py b/servers/fastapi/models/sql/document_chunk.py new file mode 100644 index 00000000..d6c2a742 --- /dev/null +++ 
b/servers/fastapi/models/sql/document_chunk.py @@ -0,0 +1,15 @@ +from typing import List +from pgvector.sqlalchemy import Vector +from sqlmodel import Field, SQLModel, Column +from pydantic import ConfigDict +from sqlalchemy.dialects.postgresql import JSONB + + +class DocumentChunk(SQLModel, table=True): + model_config = ConfigDict(arbitrary_types_allowed=True) + id: int = Field(default=None, primary_key=True) + content: str + tenant_id: str = Field(index=True) + doc_id: str = Field(index=True) + tags: List[str] = Field(sa_column=Column(JSONB)) + embedding: List[float] = Field(sa_column=Column(Vector(384))) diff --git a/servers/fastapi/models/sql/presentation.py b/servers/fastapi/models/sql/presentation.py index 3341d409..3042ff2a 100644 --- a/servers/fastapi/models/sql/presentation.py +++ b/servers/fastapi/models/sql/presentation.py @@ -11,7 +11,7 @@ class PresentationModel(SQLModel, table=True): id: str = Field(primary_key=True) - user_id: str # User ID to associate presentations with users (required) + user_id: str = "josh-software" # User ID to associate presentations with users (required) prompt: str n_slides: int language: str diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index 14240244..881810fc 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -12,10 +12,13 @@ dependencies = [ "asyncpg>=0.30.0", "chromadb>=1.0.15", "docling>=2.43.0", + "pgvector>=0.2.0", "fastapi[standard]>=0.116.1", "fastmcp>=2.11.0", "google-genai>=1.28.0", "nltk>=3.9.1", + "numpy>=1.26.4", + "onnxruntime>=1.18.0", "openai>=1.98.0", "pathvalidate>=3.3.1", "pdfplumber>=0.11.7", @@ -23,6 +26,7 @@ dependencies = [ "python-pptx>=1.0.2", "redis>=6.2.0", "sqlmodel>=0.0.24", + "transformers>=4.42.1", ] [[tool.uv.index]] diff --git a/servers/fastapi/services/database.py b/servers/fastapi/services/database.py index c97a069d..b7f7f242 100644 --- a/servers/fastapi/services/database.py +++ b/servers/fastapi/services/database.py @@ -6,6 
+6,7 @@ async_sessionmaker, AsyncSession, ) +from sqlalchemy import text from sqlmodel import SQLModel from models.sql.image_asset import ImageAsset @@ -17,6 +18,7 @@ from models.sql.template import TemplateModel from models.sql.organisation import Organisation from models.sql.user import User +from models.sql.document_chunk import DocumentChunk from utils.db_utils import get_database_url_and_connect_args @@ -49,6 +51,8 @@ async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None] # Create Database and Tables async def create_db_and_tables(): async with sql_engine.begin() as conn: + if sql_engine.url.drivername.startswith("postgresql"): + await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) await conn.run_sync( lambda sync_conn: SQLModel.metadata.create_all( sync_conn, @@ -61,6 +65,7 @@ async def create_db_and_tables(): TemplateModel.__table__, Organisation.__table__, User.__table__, + DocumentChunk.__table__, ], ) ) diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 6fb8ec4e..4836f6f6 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -15,6 +15,11 @@ FunctionCallingConfig as GoogleFunctionCallingConfig, FunctionCallingConfigMode as GoogleFunctionCallingConfigMode, ) + +import onnxruntime +import numpy as np +import os +from transformers import AutoTokenizer # NEW IMPORT from google.genai.types import Tool as GoogleTool from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage @@ -66,6 +71,8 @@ def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() self.tool_calls_handler = LLMToolCallsHandler(self) + self._embedding_session = None + self._embedding_tokenizer = None # NEW: Initialize tokenizer # ? 
Use tool calls def use_tool_calls_for_structured_output(self) -> bool: @@ -1605,3 +1612,57 @@ async def _search_anthropic(self, query: str) -> str: [each.text for each in response.content if each.type == "text"] ) return result + + async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: + if self._embedding_session is None: + # Model path now points to the directory containing model.onnx and tokenizer files + model_dir = os.path.join(os.path.dirname(__file__), "..", "chroma", "models", "onnx") + model_path = os.path.join(model_dir, "model.onnx") + + if not os.path.exists(model_path): + raise FileNotFoundError(f"ONNX embedding model not found at {model_path}") + + # Load the ONNX session + self._embedding_session = onnxruntime.InferenceSession(model_path) + + # Load the tokenizer from the same directory + self._embedding_tokenizer = AutoTokenizer.from_pretrained(model_dir) + + # Tokenize the input texts + # The model expects input_ids and attention_mask + inputs = self._embedding_tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="np", # Return numpy arrays for ONNX Runtime + max_length=self._embedding_tokenizer.model_max_length # Use model's max length + ) + + # Prepare inputs for ONNX Runtime + # Ensure input names match the ONNX model's expected input names + onnx_inputs = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + # Add token_type_ids if your model requires it and tokenizer provides it + "token_type_ids": inputs["token_type_ids"] if "token_type_ids" in inputs else None, + } + + # Run inference + # The output name might vary, commonly 'last_hidden_state' or 'output_0' + # We need to inspect the model to get the exact output name if it's not 'last_hidden_state' + outputs = self._embedding_session.run(None, onnx_inputs) + + # Assuming the first output is the last_hidden_state + last_hidden_state = outputs[0] + + # Apply mean pooling to get sentence embeddings + # This is a common 
pooling strategy for Sentence Transformers + input_mask_expanded = np.expand_dims(inputs["attention_mask"], -1).astype(float) + sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) + sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None) + embeddings = sum_embeddings / sum_mask + + # Convert to list of lists + return embeddings.tolist() + + diff --git a/servers/fastapi/services/score_based_chunker.py b/servers/fastapi/services/score_based_chunker.py index c67de796..ba884103 100644 --- a/servers/fastapi/services/score_based_chunker.py +++ b/servers/fastapi/services/score_based_chunker.py @@ -1,21 +1,19 @@ import asyncio from typing import List +import dataclasses -from models.document_chunk import DocumentChunk - +@dataclasses.dataclass +class TemporaryChunk: + heading: str | None + content: str + heading_index: int + score: float class ScoreBasedChunker: def extract_headings(self, text: str) -> List[str]: lines = text.split("\n") - headings = [] - - for line in lines: - line = line.strip() - if line.startswith("#"): - headings.append(line) - - return headings + return [line.strip() for line in lines if line.strip().startswith("#")] def score_headings(self, headings: List[str]) -> List[float]: heading_scores = [] @@ -24,9 +22,8 @@ def score_headings(self, headings: List[str]) -> List[float]: for i, heading in enumerate(headings): score = 0.0 - heading_level = len(heading) - len(heading.lstrip("#")) - + if heading_level <= 3: score += 10.0 - (heading_level - 1) * 2.0 else: @@ -46,118 +43,88 @@ def score_headings(self, headings: List[str]) -> List[float]: return heading_scores + def split_large_section(self, text: str, max_words: int = 100) -> List[TemporaryChunk]: + """Split a long text into smaller chunks.""" + words = text.split() + chunks = [] + + for i in range(0, len(words), max_words): + content = " ".join(words[i:i + max_words]) + chunks.append( + TemporaryChunk( + heading=None, + content=content, + heading_index=-1, + score=0.0 
+ ) + ) + return chunks + def get_chunks_from_headings( self, text: str, headings: List[str], heading_scores: List[float], top_k: int = 10, - ) -> List[DocumentChunk]: - if not heading_scores: - heading_scores = self.score_headings(headings) - - chunks = [] - heading_indices = [] + max_words_per_chunk: int = 100, + ) -> List[TemporaryChunk]: - for i, score in enumerate(heading_scores): - if score > 0: - heading_indices.append((i, score)) + if not headings: + return self.split_large_section(text, max_words_per_chunk) - if len(heading_indices) == 0: - return chunks - - heading_indices.sort(key=lambda x: (-x[1], x[0])) + lines = text.split("\n") + heading_positions = {} + for i, line in enumerate(lines): + line_stripped = line.strip() + for idx, heading in enumerate(headings): + if heading == line_stripped and idx not in heading_positions: + heading_positions[idx] = i + break - if len(heading_indices) <= top_k: - selected_indices = [idx for idx, _ in heading_indices] + if heading_scores: + heading_indices = sorted( + [(i, s) for i, s in enumerate(heading_scores) if s > 0], + key=lambda x: (-x[1], x[0]) + ) + selected_indices = [idx for idx, _ in heading_indices[:top_k]] selected_indices.sort() else: - score_groups = {} - for idx, score in heading_indices: - rounded_score = round(score) - if rounded_score not in score_groups: - score_groups[rounded_score] = [] - score_groups[rounded_score].append(idx) - - sorted_groups = sorted( - score_groups.items(), key=lambda x: x[0], reverse=True - ) - - selected_indices = [] + selected_indices = list(range(len(headings))) - for score, indices in sorted_groups: - indices.sort() - remaining_needed = top_k - len(selected_indices) - - if remaining_needed <= 0: - break - - if len(indices) <= remaining_needed: - selected_indices.extend(indices) - else: - if remaining_needed == 1: - mid_idx = len(indices) // 2 - selected_indices.append(indices[mid_idx]) - elif remaining_needed == 2: - selected_indices.append(indices[0]) - 
selected_indices.append(indices[-1]) - else: - step = (len(indices) - 1) / (remaining_needed - 1) - - for i in range(remaining_needed): - index = int(round(i * step)) - if index < len(indices): - selected_indices.append(indices[index]) - - selected_indices.sort() + chunks = [] - lines = text.split("\n") - heading_positions = {} - - for i, line in enumerate(lines): - line_stripped = line.strip() - if line_stripped.startswith("#"): - for heading_idx, heading in enumerate(headings): - if heading == line_stripped and heading_idx not in heading_positions: - heading_positions[heading_idx] = i - break - for i, heading_idx in enumerate(selected_indices): if heading_idx not in heading_positions: continue - + heading = headings[heading_idx] - heading_line_idx = heading_positions[heading_idx] - + start_line = heading_positions[heading_idx] + end_line = len(lines) + if i + 1 < len(selected_indices): - next_heading_idx = selected_indices[i + 1] - if next_heading_idx in heading_positions: - next_heading_line_idx = heading_positions[next_heading_idx] - content_end = next_heading_line_idx - else: - content_end = len(lines) - else: - content_end = len(lines) + next_idx = selected_indices[i + 1] + if next_idx in heading_positions: + end_line = heading_positions[next_idx] - content_lines = lines[heading_line_idx + 1 : content_end] - content = "\n".join(content_lines).strip() + section_text = "\n".join(lines[start_line + 1:end_line]).strip() + + section_chunks = self.split_large_section(section_text, max_words_per_chunk) + for chunk in section_chunks: + chunk.heading = heading + chunk.heading_index = heading_idx + chunk.score = heading_scores[heading_idx] if heading_scores else 0.0 + chunks.append(chunk) - chunk = DocumentChunk( - heading=heading, - content=content, - heading_index=heading_idx, - score=heading_scores[heading_idx], - ) - chunks.append(chunk) - return chunks - async def get_n_chunks(self, text: str, n: int) -> List[DocumentChunk]: + async def get_n_chunks(self, text: 
str, n: int) -> List[TemporaryChunk]: headings = await asyncio.to_thread(self.extract_headings, text) heading_scores = await asyncio.to_thread(self.score_headings, headings) chunks = await asyncio.to_thread( self.get_chunks_from_headings, text, headings, heading_scores, n ) - if len(chunks) < n: - raise ValueError(f"Only {len(chunks)} chunks found, requested {n}") - return chunks + + if not chunks: + chunks = await asyncio.to_thread(self.split_large_section, text, 100) + + return chunks \ No newline at end of file diff --git a/servers/fastapi/services/temp_file_service.py b/servers/fastapi/services/temp_file_service.py index eaa09e97..c025a9ef 100644 --- a/servers/fastapi/services/temp_file_service.py +++ b/servers/fastapi/services/temp_file_service.py @@ -31,6 +31,20 @@ def create_temp_file_path( os.makedirs(os.path.dirname(full_path), exist_ok=True) return full_path + def save_file(self, temp_dir: str, file_object, filename: str) -> str: + """ + Saves an uploaded file object to a temporary directory. 
+ file_object is expected to be a file-like object from UploadFile.file + """ + file_path = os.path.join(temp_dir, filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as buffer: + # Read chunks from the uploaded file and write to the buffer + # Assuming file_object is an SpooledTemporaryFile or similar + for chunk in iter(lambda: file_object.read(1024 * 1024), b''): # Read in 1MB chunks + buffer.write(chunk) + return file_path + def create_temp_file( self, file_path: str, content: Union[bytes, str], dir_path: Optional[str] = None ) -> str: diff --git a/servers/fastapi/utils/export_utils.py b/servers/fastapi/utils/export_utils.py index 1c5db64d..156db0cb 100644 --- a/servers/fastapi/utils/export_utils.py +++ b/servers/fastapi/utils/export_utils.py @@ -17,12 +17,12 @@ async def export_presentation( presentation_id: str, title: str, export_as: Literal["pptx", "pdf"] ) -> PresentationAndPath: if export_as == "pptx": - # Get the converted PPTX model from the Next.js service async with aiohttp.ClientSession() as session: async with session.get( f"http://localhost/api/presentation_to_pptx_model?id={presentation_id}" ) as response: + print(await response.text()) if response.status != 200: error_text = await response.text() print(f"Failed to get PPTX model: {error_text}") diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 9f330a3c..0dc305b3 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -60,5 +60,6 @@ async def generate_ppt_outline( response_model.model_json_schema(), strict=True, tools=tools if client.enable_web_grounding() else None, + max_tokens=4096, # Increased token limit ): yield chunk diff --git a/servers/fastapi/utils/llm_calls/get_relevant_tags.py b/servers/fastapi/utils/llm_calls/get_relevant_tags.py new file mode 
100644
index 00000000..939b64b4
--- /dev/null
+++ b/servers/fastapi/utils/llm_calls/get_relevant_tags.py
@@ -0,0 +1,109 @@
+from typing import List
+from services.llm_client import LLMClient, LLMMessage
+from typing import Tuple
+import json
+
+def get_active_llm_config() -> Tuple[str, str]:
+    """Read the active LLM provider and model directly from userConfig.json."""
+    path = "/app_data/userConfig.json"
+    with open(path, "r") as f:
+        config = json.load(f)
+
+    provider = config.get("LLM")
+    model = None
+
+    if provider == "openai":
+        model = config.get("OPENAI_MODEL")
+    elif provider == "google":
+        model = config.get("GOOGLE_MODEL")
+    elif provider == "anthropic":
+        model = config.get("ANTHROPIC_MODEL")
+    elif provider == "ollama":
+        model = config.get("OLLAMA_MODEL")
+    elif provider == "custom":
+        model = config.get("CUSTOM_MODEL")
+
+    if not provider or not model:
+        raise ValueError(f"Invalid LLM config: provider={provider}, model={model}")
+
+    return provider, model
+
+async def get_relevant_tags_from_prompt(prompt: str, available_tags: list[str]) -> list[str]:
+    """Return the subset of ``available_tags`` the LLM judges relevant to ``prompt``.
+
+    Returns an empty list when no tags are available or the LLM response is
+    empty; tags in the response that are not in ``available_tags`` are dropped.
+    """
+    if not available_tags:
+        return []
+
+    provider, model = get_active_llm_config()
+    llm_client = LLMClient()
+
+    system_prompt = f"""You are an expert classifier. Your task is to select the most relevant tags for a user's request from a given list of tags.
+ The user wants to create a presentation. Based on their prompt, identify which of the following available tags are most relevant to their request.
+ Available tags: {', '.join(available_tags)}
+ Respond with a comma-separated list of the relevant tags. For example: 'sales,q4,marketing'.
+ If no tags seem relevant, respond with an empty string."""
+
+    if provider == "openai" or provider == "ollama" or provider == "custom":
+        # OpenAI-style expects {"role": ..., "content": ...}
+        messages_to_send = [
+            {"role": msg.role, "content": msg.content} for msg in [
+                LLMMessage(role="system", content=system_prompt),
+                LLMMessage(role="user", content=prompt),
+            ]]
+    else:
+        # For other providers, keep as LLMMessage objects
+        messages_to_send = [
+            LLMMessage(role="system", content=system_prompt),
+            LLMMessage(role="user", content=prompt),
+        ]
+
+    response = await llm_client.generate(
+        model=model,
+        messages=messages_to_send,
+        max_tokens=256,
+    )
+
+    if not response:
+        return []
+
+    relevant_tags = [tag.strip() for tag in response.split(",") if tag.strip()]
+    return [tag for tag in relevant_tags if tag in available_tags]
+
+
+# async def get_relevant_tags_from_prompt(
+# prompt: str, available_tags: List[str]
+# ) -> List[str]:
+# """
+# Uses an LLM to determine which of the available tags are relevant to the user's prompt.
+
+# Args:
+# prompt: The user's input prompt for the presentation.
+# available_tags: A list of all unique tags available for the organization.
+
+# Returns:
+# A list of tags deemed relevant by the LLM.
+# """
+# if not available_tags:
+# return []
+
+# llm_client = LLMClient()
+
+# system_prompt = f"""You are an expert classifier. Your task is to select the most relevant tags for a user's request from a given list of tags.
+# The user wants to create a presentation. Based on their prompt, identify which of the following available tags are most relevant to their request.
+# Available tags: {', '.join(available_tags)}
+# Respond with a comma-separated list of the relevant tags. For example: 'sales,q4,marketing'.
+# If no tags seem relevant, respond with an empty string.""" + +# response = await llm_client.generate_text(prompt, system_prompt) + +# if not response: +# return [] + +# # Clean up the response and split into a list +# relevant_tags = [tag.strip() for tag in response.split(',') if tag.strip()] + +# # Filter the list to only include tags that were actually available +# return [tag for tag in relevant_tags if tag in available_tags] diff --git a/servers/nextjs/app/(presentation-generator)/dashboard/components/DashboardPage.tsx b/servers/nextjs/app/(presentation-generator)/dashboard/components/DashboardPage.tsx index 41fabd4c..29e0fa38 100644 --- a/servers/nextjs/app/(presentation-generator)/dashboard/components/DashboardPage.tsx +++ b/servers/nextjs/app/(presentation-generator)/dashboard/components/DashboardPage.tsx @@ -3,7 +3,7 @@ import React, { useState, useEffect } from "react"; import Wrapper from "@/components/Wrapper"; -import { DashboardApi } from "@/app/(presentation-generator)/services/api/dashboard"; +import { dashboardApi } from "@/app/(presentation-generator)/services/api/dashboard"; import { PresentationGrid } from "@/app/(presentation-generator)/dashboard/components/PresentationGrid"; import Header from "@/app/(presentation-generator)/dashboard/components/Header"; @@ -24,7 +24,7 @@ const DashboardPage: React.FC = () => { try { setIsLoading(true); setError(null); - const data = await DashboardApi.getPresentations(); + const data = await dashboardApi.getPresentations(); data.sort( (a: any, b: any) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime() diff --git a/servers/nextjs/app/(presentation-generator)/dashboard/components/PresentationCard.tsx b/servers/nextjs/app/(presentation-generator)/dashboard/components/PresentationCard.tsx index 7abebac8..84e1c2f1 100644 --- a/servers/nextjs/app/(presentation-generator)/dashboard/components/PresentationCard.tsx +++ 
b/servers/nextjs/app/(presentation-generator)/dashboard/components/PresentationCard.tsx @@ -1,7 +1,7 @@ import React, { useMemo } from "react"; import { Card } from "@/components/ui/card"; -import { DashboardApi } from "@/app/(presentation-generator)/services/api/dashboard"; +import { dashboardApi } from "@/app/(presentation-generator)/services/api/dashboard"; import { DotsVerticalIcon, TrashIcon } from "@radix-ui/react-icons"; import { Popover, @@ -37,7 +37,7 @@ export const PresentationCard = ({ e.preventDefault(); e.stopPropagation(); - const response = await DashboardApi.deletePresentation(id); + const response = await dashboardApi.deletePresentation(id); if (response) { toast.success("Presentation deleted", { diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/OutlinePage.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/OutlinePage.tsx index 235328d1..a0f82f3d 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/OutlinePage.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/OutlinePage.tsx @@ -21,6 +21,7 @@ const OutlinePage: React.FC = () => { (state: RootState) => state.presentationGeneration ); + const [activeTab, setActiveTab] = useState(TABS.OUTLINE); const [selectedLayoutGroup, setSelectedLayoutGroup] = useState(null); // Custom hooks @@ -32,6 +33,7 @@ const OutlinePage: React.FC = () => { selectedLayoutGroup, setActiveTab ); + if (!presentation_id) { return ; } diff --git a/servers/nextjs/app/(presentation-generator)/pdf-maker/PdfMakerPage.tsx b/servers/nextjs/app/(presentation-generator)/pdf-maker/PdfMakerPage.tsx index 698eff5b..3658bdaf 100644 --- a/servers/nextjs/app/(presentation-generator)/pdf-maker/PdfMakerPage.tsx +++ b/servers/nextjs/app/(presentation-generator)/pdf-maker/PdfMakerPage.tsx @@ -1,44 +1,140 @@ -import { useEffect, useState } from "react"; -import { useSearchParams } from "next/navigation"; +"use client"; +import React, { useEffect, useState } 
from "react"; +import { useDispatch, useSelector } from "react-redux"; +import { RootState } from "@/store/store"; +import { Skeleton } from "@/components/ui/skeleton"; +import { toast } from "sonner"; +import { Button } from "@/components/ui/button"; +import { usePathname } from "next/navigation"; +import { trackEvent, MixpanelEvent } from "@/utils/mixpanel"; +import { AlertCircle } from "lucide-react"; +import { useGroupLayouts } from "../hooks/useGroupLayouts"; +import { setPresentationData } from "@/store/slices/presentationGeneration"; import { dashboardApi } from "../services/api/dashboard"; +import { useLayout } from "../context/LayoutContext"; +import { useFontLoader } from "../hooks/useFontLoader"; -export default function PdfMakerPage() { - const searchParams = useSearchParams(); - const presentation_id = searchParams.get("presentation_id"); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); + +const PresentationPage = ({ presentation_id }: { presentation_id: string }) => { + console.log("Rendering PdfMakerPage.tsx"); + const { renderSlideContent, loading } = useGroupLayouts(); + const pathname = usePathname(); + const [contentLoading, setContentLoading] = useState(true); + const { getCustomTemplateFonts } = useLayout() + const dispatch = useDispatch(); + const { presentationData } = useSelector( + (state: RootState) => state.presentationGeneration + ); + const [error, setError] = useState(false); useEffect(() => { - const fetchPresentation = async () => { - if (!presentation_id) { - setError("No presentation ID provided"); - setLoading(false); - return; - } + if (!loading && presentationData?.slides && presentationData?.slides.length > 0) { + const presentation_id = presentationData?.slides[0].layout.split(":")[0].split("custom-")[1]; + const fonts = getCustomTemplateFonts(presentation_id); - try { - const data = await dashboardApi.getPresentation(presentation_id); - if (data) { - // Handle presentation data - 
setLoading(false); - } - } catch (error) { - setError("Failed to load presentation"); - console.error("Error fetching presentation:", error); - setLoading(false); + useFontLoader(fonts || []); + } + }, [presentationData, loading]); + useEffect(() => { + if (presentationData?.slides[0].layout.includes("custom")) { + const existingScript = document.querySelector( + 'script[src*="tailwindcss.com"]' + ); + if (!existingScript) { + const script = document.createElement("script"); + script.src = "https://cdn.tailwindcss.com"; + script.async = true; + document.head.appendChild(script); } - }; - - fetchPresentation(); - }, [presentation_id]); + } + }, [presentationData]); + // Function to fetch the slides + useEffect(() => { + fetchUserSlides(); + }, []); - if (loading) { - return
Loading...
; - } + // Function to fetch the user slides + const fetchUserSlides = async () => { + try { + const data = await dashboardApi.getPresentation(presentation_id); + dispatch(setPresentationData(data)); + setContentLoading(false); + document.body.classList.add('page-loaded'); + } catch (error) { + setError(true); + toast.error("Failed to load presentation"); + console.error("Error fetching user slides:", error); + setContentLoading(false); + } + }; - if (error) { - return
Error: {error}
; - } + // Regular view + return ( +
+ {error ? ( +
+
+ + Oops! +

+ We encountered an issue loading your presentation. +

+

+ Please check your internet connection or try again later. +

+ +
+
+ ) : ( +
+
+ {!presentationData || + loading || + contentLoading || + !presentationData?.slides || + presentationData?.slides.length === 0 ? ( +
+
+ {Array.from({ length: 2 }).map((_, index) => ( + + ))} +
+
+ ) : ( + <> + {presentationData && + presentationData.slides && + presentationData.slides.length > 0 && + presentationData.slides.map((slide: any, index: number) => ( + // [data-speaker-note] is used to extract the speaker note from the slide for export to pptx +
+ {renderSlideContent(slide, true)} +
+ ))} + + )} +
+
+ )} +
+ ); +}; - return
PDF Maker Page
; -} +export default PresentationPage; diff --git a/servers/nextjs/app/(presentation-generator)/pdf-maker/page.tsx b/servers/nextjs/app/(presentation-generator)/pdf-maker/page.tsx index 8f536171..5a5c2a97 100644 --- a/servers/nextjs/app/(presentation-generator)/pdf-maker/page.tsx +++ b/servers/nextjs/app/(presentation-generator)/pdf-maker/page.tsx @@ -5,7 +5,6 @@ import { Button } from "@/components/ui/button"; import { useRouter, useSearchParams } from "next/navigation"; import PdfMakerPage from "./PdfMakerPage"; const page = () => { - const router = useRouter(); const params = useSearchParams(); const queryId = params.get("id"); diff --git a/servers/nextjs/app/(presentation-generator)/presentation/components/PresentationPage.tsx b/servers/nextjs/app/(presentation-generator)/presentation/components/PresentationPage.tsx index 4df3792b..97e44d43 100644 --- a/servers/nextjs/app/(presentation-generator)/presentation/components/PresentationPage.tsx +++ b/servers/nextjs/app/(presentation-generator)/presentation/components/PresentationPage.tsx @@ -1,3 +1,4 @@ + "use client"; import React, { useEffect, useState } from "react"; import { useSelector } from "react-redux"; @@ -40,10 +41,11 @@ const PresentationPage: React.FC = ({ // Auto-save functionality const { isSaving } = useAutoSave({ - debounceMs: 2000, + debounceMs: 5000, enabled: !!presentationData && !isStreaming, }); + // Custom hooks const { fetchUserSlides } = usePresentationData( presentation_id, diff --git a/servers/nextjs/app/(presentation-generator)/presentation/hooks/useAutoSave.tsx b/servers/nextjs/app/(presentation-generator)/presentation/hooks/useAutoSave.tsx index cb4a5107..ff6dc8d7 100644 --- a/servers/nextjs/app/(presentation-generator)/presentation/hooks/useAutoSave.tsx +++ b/servers/nextjs/app/(presentation-generator)/presentation/hooks/useAutoSave.tsx @@ -13,6 +13,7 @@ export const useAutoSave = ({ debounceMs = 2000, enabled = true, }: UseAutoSaveOptions = {}) => { + console.log('useAutoSave hook 
initialized with options:', { debounceMs, enabled }); const { presentationData, isStreaming, isLoading, isLayoutLoading } = useSelector( (state: RootState) => state.presentationGeneration ); @@ -42,7 +43,7 @@ export const useAutoSave = ({ try { setIsSaving(true); console.log('🔄 Auto-saving presentation data...'); - + console.log('Data to be saved:', data); // Call the API to update presentation content await PresentationGenerationApi.updatePresentationContent(data); @@ -58,7 +59,7 @@ export const useAutoSave = ({ setIsSaving(false); } }, debounceMs); - }, [debounceMs, isSaving]); + }, [debounceMs]); // Effect to trigger auto-save when presentation data changes useEffect(() => { @@ -78,4 +79,4 @@ export const useAutoSave = ({ return { isSaving, }; -}; \ No newline at end of file +}; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/services/api/dashboard.ts b/servers/nextjs/app/(presentation-generator)/services/api/dashboard.ts index 50eebbd3..af51bd19 100644 --- a/servers/nextjs/app/(presentation-generator)/services/api/dashboard.ts +++ b/servers/nextjs/app/(presentation-generator)/services/api/dashboard.ts @@ -1,9 +1,12 @@ import { api } from "@/lib/api"; +import { ApiResponseHandler } from "./api-error-handler"; -const API_BASE_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:5001"; - +// const isBrowser = typeof window !== "undefined"; +// const API_BASE_URL = isBrowser ? 
"http://localhost:5001" : "http://localhost:80"; +// const API_BASE_URL = "http://localhost:80"; +const API_BASE_URL = ""; export interface PresentationResponse { - id: string; +id: string; title: string; description: string; status: string; @@ -42,12 +45,11 @@ export const dashboardApi = { const response = await api.get( `${API_BASE_URL}/api/v1/ppt/presentation?id=${id}` ); - if (!response.ok) { throw new Error("Presentation not found"); } - return response.json(); + return await ApiResponseHandler.handleResponse(response, "Presentation not found"); } catch (error) { console.error("Error fetching presentation:", error); throw error; diff --git a/servers/nextjs/app/(presentation-generator)/services/api/presentation-generation.ts b/servers/nextjs/app/(presentation-generator)/services/api/presentation-generation.ts index ec27d9e3..df63d207 100644 --- a/servers/nextjs/app/(presentation-generator)/services/api/presentation-generation.ts +++ b/servers/nextjs/app/(presentation-generator)/services/api/presentation-generation.ts @@ -28,6 +28,41 @@ export class PresentationGenerationApi { } } + static async uploadDocuments(formData: FormData) { + try { + const response = await fetch( + `/api/v1/ppt/documents/upload`, + { + method: "POST", + headers: getHeaderForFormData(), + body: formData, + cache: "no-cache", + } + ); + return await ApiResponseHandler.handleResponse(response, "Failed to upload documents to knowledge base"); + } catch (error) { + console.error("Upload documents to knowledge base error:", error); + throw error; + } + } + + static async listDocuments(): Promise { + try { + const response = await fetch( + `/api/v1/ppt/documents`, + { + method: "GET", + headers: getHeader(), + cache: "no-cache", + } + ); + return await ApiResponseHandler.handleResponse(response, "Failed to list documents"); + } catch (error) { + console.error("List documents error:", error); + throw error; + } + } + static async decomposeDocuments(documentKeys: string[]) { try { const response 
= await fetch( @@ -54,15 +89,17 @@ export class PresentationGenerationApi { n_slides, file_paths, language, + source_document_id, }: { prompt: string; n_slides: number | null; file_paths?: string[]; language: string | null; + source_document_id?: string | null; }) { try { const response = await fetch( - `/api/v1/ppt/presentation/create`, + `/api/v1/ppt/presentation/generate`, { method: "POST", headers: getHeader(), @@ -71,6 +108,7 @@ export class PresentationGenerationApi { n_slides, file_paths, language, + source_document_id, }), cache: "no-cache", } diff --git a/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx b/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx index e47f29cb..3f8d658f 100644 --- a/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx +++ b/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx @@ -4,28 +4,29 @@ * This component handles the presentation generation upload process, allowing users to: * - Configure presentation settings (slides, language) * - Input prompts - * - Upload supporting documents + * - Upload supporting documents with tags * * @component */ "use client"; -import React, { useState } from "react"; +import React, { useState, useEffect, useRef } from "react"; import { useRouter, usePathname } from "next/navigation"; import { useDispatch } from "react-redux"; import { clearOutlines, setPresentationId } from "@/store/slices/presentationGeneration"; import { ConfigurationSelects } from "./ConfigurationSelects"; import { PromptInput } from "./PromptInput"; import { LanguageType, PresentationConfig } from "../type"; -import SupportingDoc from "./SupportingDoc"; import { Button } from "@/components/ui/button"; -import { ChevronRight } from "lucide-react"; +import { ChevronRight, Loader2 } from "lucide-react"; import { toast } from "sonner"; import { PresentationGenerationApi } from "../../services/api/presentation-generation"; import 
{ OverlayLoader } from "@/components/ui/overlay-loader"; import Wrapper from "@/components/Wrapper"; -import { setPptGenUploadState } from "@/store/slices/presentationGenUpload"; import { trackEvent, MixpanelEvent } from "@/utils/mixpanel"; +import { Input } from "@/components/ui/input"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Label } from "@/components/ui/label"; // Types for loading state interface LoadingState { @@ -36,18 +37,37 @@ interface LoadingState { extra_info?: string; } +const PREDEFINED_TAGS = ['sales', 'marketing', 'technical', 'finance', 'hr', 'legal']; + const UploadPage = () => { const router = useRouter(); const pathname = usePathname(); const dispatch = useDispatch(); // State management - const [files, setFiles] = useState([]); const [config, setConfig] = useState({ slides: "8", language: LanguageType.English, prompt: "", }); + const [selectedTags, setSelectedTags] = useState([]); + const [isUploadingDoc, setIsUploadingDoc] = useState(false); + const fileInputRef = useRef(null); + + const handleConfigChange = (key: keyof PresentationConfig, value: string | LanguageType | null) => { + setConfig((prevConfig) => ({ + ...prevConfig, + [key]: value, + })); + }; + + const handleTagChange = (tag: string) => { + setSelectedTags((prevTags) => + prevTags.includes(tag) + ? 
prevTags.filter((t) => t !== tag) + : [...prevTags, tag] + ); + }; const [loadingState, setLoadingState] = useState({ isLoading: false, @@ -57,91 +77,38 @@ const UploadPage = () => { extra_info: "", }); - /** - * Updates the presentation configuration - * @param key - Configuration key to update - * @param value - New value for the configuration - */ - const handleConfigChange = (key: keyof PresentationConfig, value: string) => { - setConfig((prev) => ({ ...prev, [key]: value })); - }; - - /** - * Validates the current configuration and files - * @returns boolean indicating if the configuration is valid - */ - const validateConfiguration = (): boolean => { - if (!config.language || !config.slides) { - toast.error("Please select number of Slides & Language"); - return false; - } + const handleDocumentUpload = async (event: React.ChangeEvent) => { + const uploadedFiles = event.target.files; + if (!uploadedFiles || uploadedFiles.length === 0) return; - if (!config.prompt.trim() && files.length === 0) { - toast.error("No Prompt or Document Provided"); - return false; + if (selectedTags.length === 0) { + toast.error("Please select at least one tag before uploading."); + return; } - return true; - }; - - /** - * Handles the presentation generation process - */ - const handleGeneratePresentation = async () => { - if (!validateConfiguration()) return; + setIsUploadingDoc(true); try { - const hasUploadedAssets = files.length > 0; - - if (hasUploadedAssets) { - await handleDocumentProcessing(); - } else { - await handleDirectPresentationGeneration(); + const formData = new FormData(); + formData.append('tags', selectedTags.join(',')); + for (let i = 0; i < uploadedFiles.length; i++) { + formData.append('files', uploadedFiles[i]); } + await PresentationGenerationApi.uploadDocuments(formData); + toast.success("Documents uploaded and processed!"); } catch (error) { - handleGenerationError(error); + toast.error("Failed to upload documents."); + console.error("Error uploading 
documents:", error); + } finally { + setIsUploadingDoc(false); + if (fileInputRef.current) { + fileInputRef.current.value = ''; // Clear the file input + } } }; - /** - * Handles document processing - */ - const handleDocumentProcessing = async () => { - setLoadingState({ - isLoading: true, - message: "Processing documents...", - showProgress: true, - duration: 90, - extra_info: files.length > 0 ? "It might take a few minutes for large documents." : "", - }); - - let documents = []; - - if (files.length > 0) { - trackEvent(MixpanelEvent.Upload_Upload_Documents_API_Call); - const uploadResponse = await PresentationGenerationApi.uploadDoc(files); - documents = uploadResponse; - } - - const promises: Promise[] = []; - - if (documents.length > 0) { - trackEvent(MixpanelEvent.Upload_Decompose_Documents_API_Call); - promises.push(PresentationGenerationApi.decomposeDocuments(documents)); - } - const responses = await Promise.all(promises); - dispatch(setPptGenUploadState({ - config, - files: responses, - })); - dispatch(clearOutlines()) - trackEvent(MixpanelEvent.Navigation, { from: pathname, to: "/documents-preview" }); - router.push("/documents-preview"); - }; + const handleGeneratePresentation = async () => { + if (!validateConfiguration()) return; - /** - * Handles direct presentation generation without documents - */ - const handleDirectPresentationGeneration = async () => { setLoadingState({ isLoading: true, message: "Generating outlines...", @@ -149,35 +116,39 @@ const UploadPage = () => { duration: 30, }); - // Use the first available layout group for direct generation - trackEvent(MixpanelEvent.Upload_Create_Presentation_API_Call); - const createResponse = await PresentationGenerationApi.createPresentation({ - prompt: config?.prompt ?? "", - n_slides: config?.slides ? parseInt(config.slides) : null, - file_paths: [], - language: config?.language ?? 
"", - }); - - dispatch(setPresentationId(createResponse.id)); - dispatch(clearOutlines()) - trackEvent(MixpanelEvent.Navigation, { from: pathname, to: "/outline" }); - router.push("/outline"); + try { + const createResponse = await PresentationGenerationApi.createPresentation({ + prompt: config?.prompt ?? "", + n_slides: config?.slides ? parseInt(config.slides) : null, + language: config?.language ?? "", + }); + + console.log("Created presentation with ID:", createResponse); + + dispatch(setPresentationId(createResponse.presentation_id)); + dispatch(clearOutlines()); + trackEvent(MixpanelEvent.Navigation, { from: pathname, to: "/outline" }); + router.push("/outline"); + } catch (error) { + console.error("Error in upload page", error); + setLoadingState({ + isLoading: false, + message: "", + duration: 0, + showProgress: false, + }); + toast.error("Error", { + description: error instanceof Error ? error.message : "Error in upload page.", + }); + } }; - /** - * Handles errors during presentation generation - */ - const handleGenerationError = (error: any) => { - console.error("Error in upload page", error); - setLoadingState({ - isLoading: false, - message: "", - duration: 0, - showProgress: false, - }); - toast.error("Error", { - description: error.message || "Error in upload page.", - }); + const validateConfiguration = () => { + if (!config.prompt || config.prompt.trim() === "") { + toast.error("Please enter a prompt for your presentation."); + return false; + } + return true; }; return ( @@ -204,17 +175,40 @@ const UploadPage = () => { data-testid="prompt-input" /> - + + {/* Document Upload Section */} +
+

1. Select Tags

+
+ {PREDEFINED_TAGS.map((tag) => ( +
+ handleTagChange(tag)} + /> + +
+ ))} +
+ +

2. Upload Documents to Knowledge Base

+

The documents uploaded will be associated with the tags you selected above.

+
+ + +
+ {isUploadingDoc &&

Uploading and processing...

} +
+ diff --git a/servers/nextjs/app/ConfigurationInitializer.tsx b/servers/nextjs/app/ConfigurationInitializer.tsx index 3de79870..8bfb5583 100644 --- a/servers/nextjs/app/ConfigurationInitializer.tsx +++ b/servers/nextjs/app/ConfigurationInitializer.tsx @@ -21,6 +21,10 @@ export function ConfigurationInitializer({ // Fetch user config state useEffect(() => { + if (route === '/pdf-maker') { + setIsLoading(false); + return; + } fetchUserConfigState(); }, []); diff --git a/servers/nextjs/app/api/export-as-pdf/route.ts b/servers/nextjs/app/api/export-as-pdf/route.ts index 836f633d..60a961f4 100644 --- a/servers/nextjs/app/api/export-as-pdf/route.ts +++ b/servers/nextjs/app/api/export-as-pdf/route.ts @@ -22,7 +22,7 @@ export async function POST(req: NextRequest) { await page.setViewport({ width: 1280, height: 720 }); page.setDefaultNavigationTimeout(300000); page.setDefaultTimeout(300000); - + console.log("hitting the url of pdf-maker with id",id) await page.goto(`http://localhost/pdf-maker?id=${id}`, { waitUntil: 'networkidle0', timeout: 300000 }); await page.waitForFunction('() => document.readyState === "complete"') diff --git a/servers/nextjs/app/api/layout/route.ts b/servers/nextjs/app/api/layout/route.ts index 863fab44..b0348517 100644 --- a/servers/nextjs/app/api/layout/route.ts +++ b/servers/nextjs/app/api/layout/route.ts @@ -22,14 +22,24 @@ export async function GET(request: Request) { }); const page = await browser.newPage(); await page.setViewport({ width: 1280, height: 720 }); - page.setDefaultNavigationTimeout(300000); - page.setDefaultTimeout(300000); - await page.goto(schemaPageUrl, { + page.setDefaultNavigationTimeout(600000); + page.setDefaultTimeout(600000); + + const response = await page.goto(schemaPageUrl, { waitUntil: "networkidle0", - timeout: 300000, + timeout: 600000, }); + + // DEBUG: Log page content to see what is being loaded + const pageContent = await page.content(); + + // If the response is not OK, something is wrong with the page itself 
+ if (!response.ok()) { + console.error(`Failed to load page: Status ${response.status()} ${response.statusText()}`); + throw new Error(`Failed to load page: Status ${response.status()} ${response.statusText()}`); + } - await page.waitForSelector("[data-layouts]", { timeout: 300000 }); + await page.waitForSelector("[data-layouts]", { timeout: 600000 }); // Extract both data-layouts and data-group-settings attributes const { dataLayouts, dataGroupSettings } = await page.$eval( @@ -55,7 +65,7 @@ export async function GET(request: Request) { } // Compose the response to match PresentationLayoutModel - const response = { + const layoutResponse = { name: groupName, ordered: groupSettings?.ordered ?? false, slides: slides.map((slide: any) => ({ @@ -66,7 +76,7 @@ export async function GET(request: Request) { })), }; - return NextResponse.json(response); + return NextResponse.json(layoutResponse); } catch (err) { console.error("Error fetching or parsing client page:", err); return NextResponse.json( diff --git a/servers/nextjs/app/api/presentation_to_pptx_model/route.ts b/servers/nextjs/app/api/presentation_to_pptx_model/route.ts index 0cebd027..15c8bb2a 100644 --- a/servers/nextjs/app/api/presentation_to_pptx_model/route.ts +++ b/servers/nextjs/app/api/presentation_to_pptx_model/route.ts @@ -30,9 +30,9 @@ export async function GET(request: NextRequest) { const id = await getPresentationId(request); [browser, page] = await getBrowserAndPage(id); const screenshotsDir = getScreenshotsDir(); - const { slides, speakerNotes } = await getSlidesAndSpeakerNotes(page); - const slides_attributes = await getSlidesAttributes(slides, screenshotsDir); + const slidesArray = slides ? (Array.isArray(slides) ? 
slides : [slides]) : []; + const slides_attributes = await getSlidesAttributes(slidesArray, screenshotsDir); await postProcessSlidesAttributes(slides_attributes, screenshotsDir, speakerNotes); const slides_pptx_models = convertElementAttributesToPptxSlides(slides_attributes); const presentation_pptx_model: PptxPresentationModel = { @@ -71,9 +71,11 @@ async function getBrowserAndPage(id: string): Promise<[Browser, Page]> { '--disable-web-security', ], }); - const page = await browser.newPage(); + // Add this to listen to the browser's console + page.on('console', msg => console.log('PAGE LOG:', msg.text())); + await page.setViewport({ width: 1280, height: 720, deviceScaleFactor: 1 }); page.setDefaultNavigationTimeout(300000); page.setDefaultTimeout(300000); @@ -81,6 +83,23 @@ async function getBrowserAndPage(id: string): Promise<[Browser, Page]> { waitUntil: "networkidle0", timeout: 300000, }); + + try { + await page.waitForSelector('body.page-loaded, [role="alert"]', { timeout: 300000 }); + } catch (e) { + const pageContent = await page.content(); + console.error("Timeout waiting for page to load. Page content:", pageContent); + throw new ApiError("Timeout waiting for page to load. 
The page did not signal success or error within the time limit."); + } + + + const errorElement = await page.$('[role="alert"]'); + if (errorElement) { + const errorText = await errorElement.evaluate(el => el.textContent); + console.error(`Error alert found on page: ${errorText}`); + throw new ApiError(`Failed to load presentation page: an error alert was displayed: ${errorText}`); + } + return [browser, page]; } @@ -196,16 +215,16 @@ async function getSlidesAttributes(slides: ElementHandle[], screenshots return slideAttributes; } - async function getSlidesAndSpeakerNotes(page: Page) { const slides_wrapper = await getSlidesWrapper(page); const speakerNotes = await getSpeakerNotes(slides_wrapper); - const slides = await slides_wrapper.$$(":scope > div > div"); + const slides = await slides_wrapper.$(":scope > div > div"); return { slides, speakerNotes }; } async function getSlidesWrapper(page: Page): Promise> { const slides_wrapper = await page.$("#presentation-slides-wrapper"); + console.log("slides_wrapper:", slides_wrapper); // Add this line if (!slides_wrapper) { throw new ApiError("Presentation slides not found"); } @@ -411,7 +430,7 @@ async function getElementAttributes(element: ElementHandle): Promise= 4) { const opacity = parseFloat(parts[3]); - const rgbColor = color.replace(/rgba?\(|hsla?\(|\)/g, '').split(',').slice(0, 3).join(','); + const rgbColor = color.replace(/rgba?|(?:hsla?)\(|\)/g, '').split(',').slice(0, 3).join(','); const rgbString = color.startsWith('rgba') ? 
`rgb(${rgbColor})` : `hsl(${rgbColor})`; const canvas = document.createElement('canvas'); @@ -504,7 +523,7 @@ async function getElementAttributes(element: ElementHandle): Promise): Promise, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + +)) +Checkbox.displayName = CheckboxPrimitive.Root.displayName + +export { Checkbox } diff --git a/servers/nextjs/middleware.ts b/servers/nextjs/middleware.ts index bbac14f7..7d340874 100644 --- a/servers/nextjs/middleware.ts +++ b/servers/nextjs/middleware.ts @@ -2,7 +2,7 @@ import { NextResponse } from "next/server"; import type { NextRequest } from "next/server"; // Define paths that don't require authentication -const publicPaths = ["/login", "/signup"]; +const publicPaths = ["/login", "/signup", "/api/layout", "/schema", "/api/presentation_to_pptx_model", "/pdf-maker"]; // Define paths that are only accessible when NOT authenticated const authOnlyPaths = ["/login", "/signup"]; diff --git a/servers/nextjs/package-lock.json b/servers/nextjs/package-lock.json index 616ea408..0dcd72e4 100644 --- a/servers/nextjs/package-lock.json +++ b/servers/nextjs/package-lock.json @@ -15,9 +15,10 @@ "@paciolan/remote-component": "^2.13.0", "@radix-ui/react-accordion": "^1.2.1", "@radix-ui/react-avatar": "^1.1.2", + "@radix-ui/react-checkbox": "^1.3.3", "@radix-ui/react-dialog": "^1.1.6", "@radix-ui/react-dropdown-menu": "^2.1.16", - "@radix-ui/react-icons": "^1.3.0", + "@radix-ui/react-icons": "^1.3.2", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-popover": "^1.1.4", "@radix-ui/react-progress": "^1.1.0", @@ -1763,6 +1764,63 @@ } } }, + "node_modules/@radix-ui/react-checkbox": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.3.3.tgz", + "integrity": "sha512-wBbpv+NQftHDdG86Qc0pIyXk5IR3tM8Vd0nWLKDcX8nNn4nXFOFwsKuqw2okA/1D/mpaAkmuyndrPJTYDNZtFw==", + "dependencies": { + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-compose-refs": "1.1.2", + 
"@radix-ui/react-context": "1.1.2", + "@radix-ui/react-presence": "1.1.5", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-previous": "1.1.1", + "@radix-ui/react-use-size": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-checkbox/node_modules/@radix-ui/primitive": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.3.tgz", + "integrity": "sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==" + }, + "node_modules/@radix-ui/react-checkbox/node_modules/@radix-ui/react-presence": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.1.5.tgz", + "integrity": "sha512-/jfEwNDdQVBCNvjkGit4h6pMOzq8bHkopq458dPt2lMjx+eBQUohZNG9A7DtO/O5ukSbxuaNGXMjHicgwy6rQQ==", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-collapsible": { "version": "1.1.11", "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.11.tgz", @@ -2006,7 +2064,6 @@ "version": "1.3.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-icons/-/react-icons-1.3.2.tgz", "integrity": 
"sha512-fyQIhGDhzfc9pK2kH6Pl9c4BDJGfMkPqkyIgYDthyNYoNg3wVhoJMMh19WS4Up/1KMPFVpNsT2q3WmXn2N1m6g==", - "license": "MIT", "peerDependencies": { "react": "^16.x || ^17.x || ^18.x || ^19.0.0 || ^19.0.0-rc" } diff --git a/servers/nextjs/package.json b/servers/nextjs/package.json index 1080e0a4..fcc1c4b5 100644 --- a/servers/nextjs/package.json +++ b/servers/nextjs/package.json @@ -17,9 +17,10 @@ "@paciolan/remote-component": "^2.13.0", "@radix-ui/react-accordion": "^1.2.1", "@radix-ui/react-avatar": "^1.1.2", + "@radix-ui/react-checkbox": "^1.3.3", "@radix-ui/react-dialog": "^1.1.6", "@radix-ui/react-dropdown-menu": "^2.1.16", - "@radix-ui/react-icons": "^1.3.0", + "@radix-ui/react-icons": "^1.3.2", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-popover": "^1.1.4", "@radix-ui/react-progress": "^1.1.0",