Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
53 changes: 53 additions & 0 deletions app/features/text_rewriter/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from app.api.error_utilities import LoaderError, ToolExecutorError
from app.features.text_rewriter.tools import TextRewriterPipeline
from app.services.logger import setup_logger
from app.services.schemas import TextRewriterArgs
from app.utils.document_loaders import get_docs

logger = setup_logger()


def executor(
text_input: str,
file_type: str,
file_url: str,
rewrite_instruction: str,
lang: str,
verbose=False
):
try:
logger.info(f"Generating docs. from {file_type}")

if not any([file_url, file_type, text_input]): raise ValueError("No input provided for text rewriter.")
if not rewrite_instruction: raise ValueError("Rewrite instruction not provided.")

if file_url and file_type:
logger.info(f"Generating docs. from {file_type}")
docs = get_docs(file_url, file_type, lang)
else:
docs = None

# Initialize the TextRewriterArgs schema
text_rewriter_args = TextRewriterArgs(
text_input=text_input,
file_type=file_type,
file_url=file_url,
rewrite_instruction=rewrite_instruction,
lang=lang
)

output = TextRewriterPipeline(text_rewriter_args, verbose).rewrite_text(docs)

logger.info(f"Text rewritten successfully.")

except LoaderError as e:
error_message = f"Error in Text Rewriter Pipeline ->: {e}"
logger.error(error_message)
raise ToolExecutorError(error_message)

except Exception as e:
error_message = f"Error in executor: {e}"
logger.error(error_message)
raise ValueError(error_message)

return output
29 changes: 29 additions & 0 deletions app/features/text_rewriter/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"inputs": [
{
"label": "Text Input",
"name": "text_input",
"type": "text"
},
{
"label": "File Type",
"name": "file_type",
"type": "text"
},
{
"label": "File URL",
"name": "file_url",
"type": "text"
},
{
"label": "Rewrite Instruction",
"name": "rewrite_instruction",
"type": "text"
},
{
"label": "Language",
"name": "lang",
"type": "text"
}
]
}
12 changes: 12 additions & 0 deletions app/features/text_rewriter/prompt/text-rewriter-prompt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
You are an advanced text rewriting assistant.
Your task is to rewrite the given text according to the provided instructions while preserving its meaning and key information.
If a file is provided, extract relevant context and use it to enhance the rewrite.

Here is the original text and rewriting instructions:
{attribute_collection}

If additional context is available from the uploaded file, use it:
{context}

Your response should be formatted as follows:
{format_instructions}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
You are an advanced text rewriting assistant.
Your task is to rewrite the given text according to the provided instructions while preserving its meaning and key information.

Here is the original text and rewriting instructions:
{attribute_collection}

Your response should be formatted as follows:
{format_instructions}
Empty file.
131 changes: 131 additions & 0 deletions app/features/text_rewriter/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import pytest

from app.features.text_rewriter.core import executor

base_attributes = {
"text_input": "Photosynthesis is a process used by plants to convert sunlight into energy.",
"rewrite_instruction": "Summarize into bullet points",
"lang": "en"
}


# PDF Tests
def test_executor_pdf_url_valid():
ai_resistant_assignment = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/pdf/sample1.pdf",
file_type="pdf"
)
assert isinstance(ai_resistant_assignment, dict)


def test_executor_pdf_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/pdf/sample1.pdf",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


# CSV Tests
def test_executor_csv_url_valid():
ai_resistant_assignment = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/csv/sample1.csv",
file_type="csv"
)
assert isinstance(ai_resistant_assignment, dict)


def test_executor_csv_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/csv/sample1.csv",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


# TXT Tests
def test_executor_txt_url_valid():
ai_resistant_assignment = executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/txt/sample1.txt",
file_type="txt"
)
assert isinstance(ai_resistant_assignment, dict)


def test_executor_txt_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/txt/sample1.txt",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


# PPTX Tests
def test_executor_pptx_url_valid():
ai_resistant_assignment = executor(
**base_attributes,
file_url="https://scholar.harvard.edu/files/torman_personal/files/samplepptx.pptx",
file_type="pptx"
)
assert isinstance(ai_resistant_assignment, dict)


def test_executor_pptx_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://scholar.harvard.edu/files/torman_personal/files/samplepptx.pptx",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


def test_executor_xls_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/xls/sample1.xls",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


def test_executor_xlsx_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesamples.com/samples/document/xlsx/sample1.xlsx",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


# XML Tests
def test_executor_xml_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://filesampleshub.com/download/code/xml/dummy.xml",
file_type=1
)
assert isinstance(exc_info.value, ValueError)


# GSheets Tests
def test_executor_gsheets_url_invalid():
with pytest.raises(ValueError) as exc_info:
executor(
**base_attributes,
file_url="https://docs.google.com/spreadsheets/d/16OPtLLSfU/edit",
file_type=1
)
assert isinstance(exc_info.value, ValueError)
109 changes: 109 additions & 0 deletions app/features/text_rewriter/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pydantic import BaseModel, Field
from typing import Optional, List
import os
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import JsonOutputParser
from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings

from app.services.logger import setup_logger

logger = setup_logger(__name__)


def read_text_file(file_path):
# Get the directory containing the script file
script_dir = os.path.dirname(os.path.abspath(__file__))

# Combine the script directory with the relative file path
absolute_file_path = os.path.join(script_dir, file_path)

with open(absolute_file_path, 'r') as file:
return file.read()


class TextRewriterPipeline:
def __init__(self, args=None, model=None, embedding_model=None, vectorstore_class=Chroma, parser=None, prompt=None,
verbose=False):
default_config = {
"model": GoogleGenerativeAI(model="gemini-1.5-flash"),
"embedding_model": GoogleGenerativeAIEmbeddings(model="models/embedding-001"),
"parser": JsonOutputParser(pydantic_object=TextRewriterOutput),
"prompt": read_text_file("./prompt/text-rewriter-prompt.txt"),
"prompt_without_context": read_text_file("./prompt/text-rewriter-without-context-prompt.txt"),
"vectorstore_class": Chroma,
}

self.model = default_config["model"] or model
self.embedding_model = embedding_model or default_config["embedding_model"]

self.parser = parser or default_config["parser"]
self.prompt = prompt or default_config["prompt"]
self.prompt_without_context = default_config["prompt_without_context"]

self.args, self.verbose = args, verbose

self.vectorstore_class = vectorstore_class or default_config["vectorstore_class"]
self.vectorstore, self.retriever, self.runner = None, None, None

if vectorstore_class is None: raise ValueError("Vectorstore must be provided")
if args.text_input is None: raise ValueError("Text input must be provided")
if args.rewrite_instruction is None: raise ValueError("Rewrite instruction must be provided")
if args.lang is None: raise ValueError("Language must be provided")

def compile_with_docs(self, documents: List[Document]):
# Return the chain
prompt = PromptTemplate(
template=self.prompt,
input_variables=["attribute_collection"],
partial_variables={"format_instructions": self.parser.get_format_instructions()}
)

if self.runner is None:
logger.info(f"Creating vectorstore from {len(documents)} documents") if self.verbose else None
self.vectorstore = self.vectorstore_class.from_documents(documents, self.embedding_model)
logger.info(f"Vectorstore created") if self.verbose else None

self.retriever = self.vectorstore.as_retriever()
logger.info(f"Retriever created successfully") if self.verbose else None

self.runner = RunnableParallel(
{"context": self.retriever,
"attribute_collection": RunnablePassthrough()
}
)
logger.info(f"Chain compilation complete")
return self.runner | prompt | self.model | self.parser

def compile_without_docs(self):
# Return the chain
prompt = PromptTemplate(
template=self.prompt_without_context,
input_variables=["attribute_collection"],
partial_variables={"format_instructions": self.parser.get_format_instructions()}
)
logger.info(f"Chain compilation complete")
return prompt | self.model | self.parser

def rewrite_text(self, documents: Optional[List[Document]]):
logger.info(f"Rewriting text")
if documents:
chain = self.compile_with_docs(documents)
else:
chain = self.compile_without_docs()

response = chain.invoke(f"""Original Text: {self.args.text_input},
Rewrite Instruction: {self.args.rewrite_instruction},
Respond in this language: {self.args.lang}""")

if documents:
if self.verbose: print(f"Deleting vectorstore")
self.vectorstore.delete_collection()

return response


class TextRewriterOutput(BaseModel):
rewritten_text: str = Field(..., description="The rewritten text")
7 changes: 7 additions & 0 deletions app/services/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,11 @@ class SlideGeneratorInput(BaseModel):
slides_titles: List[str]
instructional_level: str
topic: str
lang: Optional[str] = "en"

class TextRewriterArgs(BaseModel):
text_input: str
file_type: str
file_url: str
rewrite_instruction: str
lang: Optional[str] = "en"
4 changes: 4 additions & 0 deletions app/tools/utils/tools_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,9 @@
"slide-generator": {
"path": "tools.presentation_generator_updated.slide_generator.core",
"metadata_file": "metadata.json"
},
"text-rewriter": {
"path": "features.text_rewriter.core",
"metadata_file": "metadata.json"
}
}