diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b05e001..4660909 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -50,3 +50,5 @@ jobs: - name: Test with pytest run: | pytest --color=yes + env: + PYTHONWARNINGS: "ignore::DeprecationWarning:pkg_resources.*" diff --git a/docs/milvus.md b/docs/milvus.md index d219e0f..88c62a2 100644 --- a/docs/milvus.md +++ b/docs/milvus.md @@ -1,43 +1,49 @@ # Tutorial of Rule-based Retrieval through Milvus -The `whyhow_rbr` package helps create customized RAG pipelines. It is built on top +The `rule-based-retrieval` package helps create customized RAG pipelines. It is built on top of the following technologies (and their respective Python SDKs) -- **OpenAI** - text generation - **Milvus** - vector database +- **OpenAI** - text generation ## Initialization +Install package +```shell +pip install rule-based-retrieval +``` + Please import some essential package ```python -from pymilvus import DataType - -from src.whyhow_rbr.rag_milvus import ClientMilvus +from whyhow_rbr import ClientMilvus, MilvusRule ``` -## Client +## ClientMilvus -The central object is a `ClientMilvus`. It manages all necessary resources +The central object is `ClientMilvus`. It manages all necessary resources and provides a simple interface for all the RAG related tasks. First of all, to instantiate it one needs to provide the following credentials: -- `OPENAI_API_KEY` -- `Milvus_URI` -- `Milvus_API_TOKEN` +- `milvus_uri` +- `milvus_token` (optional) +- `openai_api_key` + +You need to create a file with the format "xxx.db" in your current directory +and use the file path as milvus_uri. 
Initialize the ClientMilvus like this: ```python -# Set up your Milvus Cloud information -YOUR_MILVUS_CLOUD_END_POINT="YOUR_MILVUS_CLOUD_END_POINT" -YOUR_MILVUS_CLOUD_TOKEN="YOUR_MILVUS_CLOUD_TOKEN" +# Set up your Milvus Client information +YOUR_MILVUS_LITE_FILE_PATH = "./milvus_demo.db" # random name for milvus lite local db +OPENAI_API_KEY="" # Initialize the ClientMilvus milvus_client = ClientMilvus( - milvus_uri=YOUR_MILVUS_CLOUD_END_POINT, - milvus_token=YOUR_MILVUS_CLOUD_TOKEN + milvus_uri=YOUR_MILVUS_LITE_FILE_PATH, + openai_api_key=OPENAI_API_KEY ) ``` @@ -45,59 +51,22 @@ milvus_client = ClientMilvus( This tutorial `whyhow_rbr` uses Milvus for everything related to vector databses. -### Defining necessary variables +### Create the collection ```python # Define collection name COLLECTION_NAME="YOUR_COLLECTION_NAME" # take your own collection name - # Define vector dimension size DIMENSION=1536 # decide by the model you use -``` - -### Add schema - -Before inserting any data into Milvus database, we need to first define the data field, which is called schema in here. Through create object `CollectionSchema` and add data field through `addd_field()`, we can control our data type and their characteristics. This step is required. -```python -schema = milvus_client.create_schema(auto_id=True) # Enable id matching - -schema = milvus_client.add_field(schema=schema, field_name="id", datatype=DataType.INT64, is_primary=True) -schema = milvus_client.add_field(schema=schema, field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=DIMENSION) -``` -We only defined `id` and `embedding` here because we need to define a primary field for each collection. For embedding, we need to define the dimension. We allow `enable_dynamic_field` which support auto adding schema, but we still encourage you to add schema by yourself. 
This method is a thin wrapper around the official Milvus implementation ([official docs](https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Collections/create_schema.md)) - -### Creating an index - -For each schema, it is better to have an index so that the querying will be much more efficient. To create an index, we first need an index_params and later add more index data on this `IndexParams` object. -```python -# Start to indexing data field -index_params = milvus_client.prepare_index_params() -index_params = milvus_client.add_index( - index_params=index_params, # pass in index_params object - field_name="embedding", - index_type="AUTOINDEX", # use autoindex instead of other complex indexing method - metric_type="COSINE", # L2, COSINE, or IP -) -``` -This method is a thin wrapper around the official Milvus implementation ([official docs](https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Management/add_index.md)). - -### Create Collection - -After defining all the data field and indexing them, we now need to create our database collection so that we can access our data quick and precise. What's need to be mentioned is that we initialized the `enable_dynamic_field` to be true so that you can upload any data freely. The cost is the data querying might be inefficient. -```python # Create Collection -milvus_client.create_collection( - collection_name=COLLECTION_NAME, - schema=schema, - index_params=index_params -) +milvus_client.create_collection(collection_name=COLLECTION_NAME, dimension=DIMENSION) ``` ## Uploading documents After creating a collection, we are ready to populate it with documents. In -`whyhow_rbr` this is done using the `upload_documents` method of the `MilvusClient`. +`whyhow_rbr` this is done using the `upload_documents` method of the `ClientMilvus`. 
It performs the following steps under the hood: - **Preprocessing**: Reading and splitting the provided PDF files into chunks @@ -112,7 +81,6 @@ pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"] # replace to your pdfs path # Uploading the PDF document milvus_client.upload_documents( - collection_name=COLLECTION_NAME, documents=pdfs ) ``` @@ -120,20 +88,19 @@ milvus_client.upload_documents( Now we can finally move to retrieval augmented generation. -In `whyhow_rbr` with Milvus, it can be done via the `search` method. +In `whyhow_rbr` with Milvus, it can be done via the `query` method. -1. Simple example: +1. Simple example without rules: ```python # Search data and implement RAG! -res = milvus_client.search( - question='What food does Harry Potter like to eat?', - collection_name=COLLECTION_NAME, - anns_field='embedding', - output_fields='text' +result = milvus_client.query( + question="What is Harry Potter's favorite food?", + process_rules_separately=True, + keyword_trigger=False, ) -print(res['answer']) -print(res['matches']) +print(result["answer"]) +print(result["matches"]) ``` The `result` is a dictionary that has the following keys @@ -142,7 +109,7 @@ The `result` is a dictionary that has the following keys - `matches` - the `limit` most relevant documents from the index Note that the number of matches will be in general equal to `limit` which -can be specified as a parameter. +can be specified as a parameter. The default value is 5. ### Clean up @@ -150,14 +117,12 @@ At last, after implemented all the instructuons, you can clean up the database by calling `drop_collection()`. ```python # Clean up -milvus_client.drop_collection( - collection_name=COLLECTION_NAME -) +milvus_client.drop_collection() ``` ### Rules -In the previous example, every single document in our index was considered. +In the previous example, every single document in our collection was considered. 
However, sometimes it might be beneficial to only retrieve documents satisfying some predefined conditions (e.g. `filename=harry-potter.pdf`). In `whyhow_rbr` through Milvus, this can be done via adjusting searching parameters. @@ -166,37 +131,41 @@ A rule can control the following metadata attributes - `filename` - name of the file - `page_numbers` - list of integers corresponding to page numbers (0 indexing) -- `id` - unique identifier of a chunk (this is the most "extreme" filter) +- `uuid` - unique identifier of a chunk (this is the most "extreme" filter) +- `keywords` - list of keywords to trigger the rule - Other rules base on [Boolean Expressions](https://milvus.io/docs/boolean.md) Rules Example: ```python -# RULES(search on book harry-potter on page 8): -PARTITION_NAME='harry-potter' # search on books -page_number='page_number == 8' - -# first create a partitions to store the book and later search on this specific partition: -milvus_client.crate_partition( - collection_name=COLLECTION_NAME, - partition_name=PARTITION_NAME # separate base on your pdfs type -) +# RULES: +rules = [ + MilvusRule( + # Replace with your rule + filename="harry-potter.pdf", + page_numbers=[120, 121, 150], + ), + MilvusRule( + # Replace with your rule + filename="harry-potter.pdf", + page_numbers=[120, 121, 150], + keywords=["food", "favorite", "likes to eat"] + ), +] # search with rules -res = milvus_client.search( - question='Tell me about the greedy method', - collection_name=COLLECTION_NAME, - partition_names=PARTITION_NAME, - filter=page_number, # append any rules follow the Boolean Expression Rule - anns_field='embedding', - output_fields='text' +res = milvus_client.query( + question="What is Harry Potter's favorite food?", + rules=rules, + process_rules_separately=True, + keyword_trigger=False, ) -print(res['answer']) -print(res['matches']) +print(res["answer"]) +print(res["matches"]) ``` -In this example, we first create a partition that store harry-potter related pdfs, and 
through searching within this partition, we can get the most direct information. -Also, we apply page number as a filter to specify the exact page we wish to search on. -Remember, the filer parameter need to follow the [boolean rule](https://milvus.io/docs/boolean.md). +In this example, the process_rules_separately parameter is set to True. This means that each rule will be processed independently, ensuring that both rules contribute to the final result set. + +When process_rules_separately is False, all rules are run as one joined query, which means that one rule can dominate the others and, given the return limit, a lower-priority rule might not return any results. Setting process_rules_separately to True (the default of `query`) processes each rule independently, so that every rule returns results, and the results are combined at the end. That's all for the Milvus implementation of Rule-based Retrieval. \ No newline at end of file diff --git a/examples/milvus_tutorial.py b/examples/milvus_tutorial.py index 01a6311..e60b088 100644 --- a/examples/milvus_tutorial.py +++ b/examples/milvus_tutorial.py @@ -1,94 +1,59 @@ """Script that demonstrates how to use the RAG model with Milvus to implement rule-based retrieval.""" -import os +from whyhow_rbr.rag_milvus import ClientMilvus, MilvusRule -from pymilvus import DataType - -from src.whyhow_rbr.rag_milvus import ClientMilvus - -# Set up your Milvus Cloud information -YOUR_MILVUS_CLOUD_END_POINT = os.getenv("YOUR_MILVUS_CLOUD_END_POINT") -YOUR_MILVUS_CLOUD_TOKEN = os.getenv("YOUR_MILVUS_CLOUD_TOKEN") - -# Initialize the ClientMilvus -milvus_client = ClientMilvus( - milvus_uri=YOUR_MILVUS_CLOUD_END_POINT, - milvus_token=YOUR_MILVUS_CLOUD_TOKEN, -) +# Set up your Milvus Client information +YOUR_MILVUS_LITE_FILE_PATH = "./milvus_demo.db" # local file name used by Milvus Lite to persist data # Define collection name COLLECTION_NAME = "YOUR_COLLECTION_NAME" # take your own collection name -# Create necessary schema to store data -DIMENSION = 1536 
# decide by the model you use - -schema = milvus_client.create_schema(auto_id=True) # Enable id matching - -schema = milvus_client.add_field( - schema=schema, field_name="id", datatype=DataType.INT64, is_primary=True -) -schema = milvus_client.add_field( - schema=schema, - field_name="embedding", - datatype=DataType.FLOAT_VECTOR, - dim=DIMENSION, -) - - -# Start to indexing data field -index_params = milvus_client.prepare_index_params() -index_params = milvus_client.add_index( - index_params=index_params, # pass in index_params object - field_name="embedding", - index_type="AUTOINDEX", # use autoindex instead of other complex indexing method - metric_type="COSINE", # L2, COSINE, or IP +# Initialize the ClientMilvus +milvus_client = ClientMilvus( + milvus_uri=YOUR_MILVUS_LITE_FILE_PATH, + openai_api_key="", ) # Create Collection -milvus_client.create_collection( - collection_name=COLLECTION_NAME, schema=schema, index_params=index_params -) - - -# Create a Partition, list it out -milvus_client.crate_partition( - collection_name=COLLECTION_NAME, - partition_name="xxx", # Put in your own partition name, better fit the document you upload -) - -partitions = milvus_client.list_partition(collection_name=COLLECTION_NAME) -print(partitions) +milvus_client.create_collection(collection_name=COLLECTION_NAME) # Uploading the PDF document -# get pdfs -pdfs = ["harry-potter.pdf", "game-of-thrones.pdf"] # replace to your pdfs path +# get pdfs from data directory in current directory +pdfs = ["data/1.pdf", "data/2.pdf"] # replace to your pdfs path -milvus_client.upload_documents( - collection_name=COLLECTION_NAME, partition_name="xxx", documents=pdfs -) + +milvus_client.upload_documents(documents=pdfs) # add your rules: -filter = "" -partition_names = None +rules = [ + MilvusRule( + # Replace with your filename + filename="data/1.pdf", + page_numbers=[], + ), + MilvusRule( + # Replace with your filename + filename="data/2.pdf", + page_numbers=[], + ), +] # Search data and implement 
RAG! -res = milvus_client.search( - question="Tell me about the greedy method", - collection_name=COLLECTION_NAME, - filter=filter, - partition_names=None, - anns_field="embedding", - output_fields="text", +res = milvus_client.query( + question="YOUR_QUESTIONS", + rules=rules, + process_rules_separately=True, + keyword_trigger=False, ) print(res["answer"]) print(res["matches"]) # Clean up -milvus_client.drop_collection(collection_name=COLLECTION_NAME) +milvus_client.drop_collection() diff --git a/pyproject.toml b/pyproject.toml index 7ba3f9e..5c1bc4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,13 +20,14 @@ dependencies = [ "pydantic>1", "pypdf", "tiktoken", + "pymilvus", ] dynamic = ["version"] [project.urls] Homepage = "https://whyhow.ai" -Documentation = "https://whyhow-ai.github.io/rule-based-retrieval/" -"Issue Tracker" = "https://github.com/whyhow-ai/rule-based-retrieval/issues" +Documentation = "https://whyhow-ai.github.io/rule-based-retrieval/" +"Issue Tracker" = "https://github.com/whyhow-ai/rule-based-retrieval/issues" [project.optional-dependencies] @@ -103,7 +104,10 @@ warn_required_dynamic_aliases = true [tool.pytest.ini_options] filterwarnings = [ - "error" + "error", + "ignore::DeprecationWarning:pkg_resources.*", + "ignore::UserWarning", + "ignore::DeprecationWarning" ] testpaths = [ "tests", diff --git a/src/whyhow_rbr/__init__.py b/src/whyhow_rbr/__init__.py index 678f15f..e3e582a 100644 --- a/src/whyhow_rbr/__init__.py +++ b/src/whyhow_rbr/__init__.py @@ -1,15 +1,29 @@ """SDK.""" +# TODO from whyhow_rbr.exceptions import ( + CollectionAlreadyExistsException, + CollectionCreateFailureException, + CollectionNotFoundException, IndexAlreadyExistsException, IndexNotFoundException, + OpenAIException, ) from whyhow_rbr.rag import Client, Rule +from whyhow_rbr.rag_milvus import ClientMilvus, MilvusRule __version__ = "v0.1.4" __all__ = [ + # Client "Client", + "ClientMilvus", + "Rule", + "MilvusRule", + # Error "IndexAlreadyExistsException", 
"IndexNotFoundException", - "Rule", + "OpenAIException", + "CollectionNotFoundException", + "CollectionAlreadyExistsException", + "CollectionCreateFailureException", ] diff --git a/src/whyhow_rbr/exceptions.py b/src/whyhow_rbr/exceptions.py index 2b2105e..ca93091 100644 --- a/src/whyhow_rbr/exceptions.py +++ b/src/whyhow_rbr/exceptions.py @@ -31,37 +31,7 @@ class CollectionAlreadyExistsException(Exception): pass -class SchemaCreateFailureException(Exception): - """Raised when fail to create a new schema.""" - - pass - - class CollectionCreateFailureException(Exception): """Raised when fail to create a new collection.""" pass - - -class AddSchemaFieldFailureException(Exception): - """Raised when fail to add a field to schema.""" - - pass - - -class PartitionCreateFailureException(Exception): - """Raised when fail to create a partition.""" - - pass - - -class PartitionDropFailureException(Exception): - """Raised when fail to drop a partition.""" - - pass - - -class PartitionListFailureException(Exception): - """Raised when fail to list all partitions.""" - - pass diff --git a/src/whyhow_rbr/rag_milvus.py b/src/whyhow_rbr/rag_milvus.py index ddc7778..8ddb7fa 100644 --- a/src/whyhow_rbr/rag_milvus.py +++ b/src/whyhow_rbr/rag_milvus.py @@ -4,32 +4,27 @@ import os import pathlib import re +import uuid from typing import Any, Dict, List, Optional, TypedDict, cast from langchain_core.documents import Document from openai import OpenAI -from pydantic import BaseModel, ValidationError -from pymilvus import CollectionSchema, DataType, MilvusClient, MilvusException -from pymilvus.milvus_client import IndexParams +from pydantic import BaseModel, Field, ValidationError, field_validator +from pymilvus import DataType, MilvusClient, MilvusException from whyhow_rbr.embedding import generate_embeddings from whyhow_rbr.exceptions import ( - AddSchemaFieldFailureException, CollectionAlreadyExistsException, CollectionCreateFailureException, CollectionNotFoundException, OpenAIException, 
- PartitionCreateFailureException, - PartitionDropFailureException, - PartitionListFailureException, - SchemaCreateFailureException, ) from whyhow_rbr.processing import clean_chunks, parse_and_split logger = logging.getLogger(__name__) -class MilvusMetadata(BaseModel): +class MilvusMetadata(BaseModel, extra="forbid"): """The metadata to be stored in Milvus. Attributes @@ -51,10 +46,89 @@ class MilvusMetadata(BaseModel): page_number: int chunk_number: int filename: str - vector: List[float] + uuid: str = Field(default_factory=lambda: str(uuid.uuid4())) -"""Custom classes for constructing prompt, output and query result with examples""" +class MilvusMatch(BaseModel, extra="ignore"): + """The match returned from Milvus. + + Attributes + ---------- + id : str + The ID of the document. + + score : float + The score of the match. Its meaning depends on the metric used for + the index. + + metadata : MilvusMetadata + The metadata of the document. + + """ + + id: str + score: float + metadata: MilvusMetadata + + +class MilvusRule(BaseModel): + """Retrieval rule. + + The rule is used to filter the documents in the index. + + Attributes + ---------- + filename : str | None + The filename of the document. + + uuid : str | None + The UUID of the document. + + page_numbers : list[int] | None + The page numbers of the document. + + keywords : list[str] | None + The keywords to trigger a rule. 
+ """ + + filename: str | None = None + uuid: str | None = None + page_numbers: list[int] | None = None + keywords: list[str] | None = None + + @field_validator("page_numbers", mode="before") + @classmethod + def convert_empty_to_none( + cls, v: Optional[List[int]] + ) -> Optional[List[int]]: + """Convert empty list to None.""" + if v is not None and not v: + return None + return v + + def convert_empty_str_to_none( + cls, s: Optional[List[str]] + ) -> Optional[List[str]]: + """Convert empty string list to None.""" + if s is not None and not s: + return None + return s + + def to_filter(self) -> str: + """Convert rule to Milvus filter format.""" + if not any([self.filename, self.uuid, self.page_numbers]): + return "" + + conditions: List[str] = [] + if self.filename is not None: + conditions.append(f'filename == "{self.filename}"') + if self.uuid is not None: + conditions.append(f'id == "{self.uuid}"') + if self.page_numbers is not None: + conditions.append(f"page_number in {self.page_numbers}") + + filter_ = " and ".join(conditions) + return filter_ class Input(BaseModel): @@ -198,10 +272,8 @@ class ClientMilvus: def __init__( self, milvus_uri: str, - milvus_token: str, - milvus_db_name: Optional[str] = None, - timeout: float | None = None, - openai_api_key: str | None = None, + milvus_token: Optional[str] = None, + openai_api_key: Optional[str] = None, ): if openai_api_key is None: openai_api_key = os.environ.get("OPENAI_API_KEY") @@ -214,457 +286,103 @@ def __init__( self.milvus_client = MilvusClient( uri=milvus_uri, token=milvus_token, - db_name=milvus_db_name, - timeout=timeout, ) - def get_collection_stats( - self, collection_name: str, timeout: Optional[float] = None - ) -> Dict[str, Any]: - """Get an existing collection. + def create_collection( + self, + collection_name: str, + dimension: int = 1536, + ) -> None: + """ + Initialize a collection. Parameters ---------- - collection_name : str - The name of the collection. 
- - timeout : Optional[float] - The timeout duration for this operation. - Setting this to None indicates that this operation timeouts when any response returns or error occurs. + dimension: int + The dimension of the collection. Returns ------- - Dict - A dictionary that contains detailed information about the specified collection. - - Raises - ------ - CollectionNotFoundException - If the collection does not exist. + None """ - try: - collection_stats = self.milvus_client.describe_collection( - collection_name, timeout + self.collection_name = collection_name + if self.milvus_client.has_collection( + collection_name=self.collection_name + ): + raise CollectionAlreadyExistsException( + f"Collection {self.collection_name} already exists" ) - except MilvusException as e: - raise CollectionNotFoundException( - f"Collection {collection_name} does not exist" - ) from e - - return collection_stats - - def create_schema( - self, - auto_id: bool = False, - enable_dynamic_field: bool = True, - **kwargs: Any, - ) -> CollectionSchema: - """Create a schema to add in collection. - - Parameters - ---------- - auto_id : bool - Whether allows the primary field to automatically increment. - enable_dynamic_field : bool - Whether allows Milvus saves the values of undefined fields in a dynamic field - if the data being inserted into the target collection includes fields that are not defined in the collection's schema. - - Returns - ------- - CollectionSchema - A Schema instance represents the schema of a collection. - - Raises - ------ - SchemaCreateFailureException - If schema create failure. 
- """ try: - schema = MilvusClient.create_schema( - auto_id=auto_id, - enable_dynamic_field=enable_dynamic_field, - **kwargs, + # create schema, with dynamic field available + schema = self.milvus_client.create_schema( + auto_id=False, + enable_dynamic_field=True, ) - except MilvusException as e: - raise SchemaCreateFailureException("Schema create failure.") from e - - return schema - def add_field( - self, - schema: CollectionSchema, - field_name: str, - datatype: DataType, - is_primary: bool = False, - **kwargs: Any, - ) -> CollectionSchema: - """Add Field to current schema. - - Parameters - ---------- - schema : CollectionSchema - The exist schema object. - - field_name : str - The name of the new field. - - datatype : DataType - The data type of the field. - You can choose from the following options when selecting a data type for different fields: - - Primary key field: Use DataType.INT64 or DataType.VARCHAR. - - Scalar fields: Choose from a variety of options, including: - - DataType.BOOL, DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, - DataType.FLOAT, DataType.DOUBLE, DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, - DataType.FLOAT16_VECTOR, __DataType.BFLOAT16_VECTOR, DataType.VARCHAR, - DataType.JSON, and DataType.ARRAY. - - Vector fields: Select DataType.BINARY_VECTOR or DataType.FLOAT_VECTOR. - - is_primary : bool - Whether the current field is the primary field in a collection. - **Each collection has only one primary field. - - **kwargs : Any - - max_length (int) - - The maximum length of the field value. - This is mandatory for a DataType.VARCHAR field. - - element_type (str) - - The data type of the elements in the field value. - This is mandatory for a DataType.Array field. - - max_capacity (int) - - The number of elements in an Array field value. - This is mandatory for a DataType.Array field. - - dim (int) - - The dimension of the vector embeddings. 
- This is mandatory for a DataType.FLOAT_VECTOR field or a DataType.BINARY_VECTOR field. - - Returns - ------- - CollectionSchema - A Schema instance represents the schema of a collection. - - Raises - ------ - AddSchemaFieldFailureException - If schema create failure. - """ - try: + # add fields to schema schema.add_field( - field_name=field_name, - datatype=datatype, - is_primary=is_primary, - **kwargs, + field_name="id", + datatype=DataType.VARCHAR, + is_primary=True, + max_length=36, + ) + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=dimension, ) - except MilvusException as e: - raise AddSchemaFieldFailureException( - f"Fail to add {field_name} to current schema." - ) from e - - return schema - - def prepare_index_params(self) -> IndexParams: - """Prepare an index object.""" - index_params = self.milvus_client.prepare_index_params() - - return index_params - - def add_index( - self, - index_params: IndexParams, - field_name: str, - index_type: str = "AUTOINDEX", - index_name: Optional[str] = None, - metric_type: str = "COSINE", - params: Optional[Dict[str, Any]] = None, - ) -> IndexParams: - """Add an index to IndexParams Object. - - Parameters - ---------- - index_params : IndexParams - index object - - field_name : str - The name of the target file to apply this object applies. - - index_name : str - The name of the index file generated after this object has been applied. - - index_type : str - The name of the algorithm used to arrange data in the specific field. - - metric_type : str - The algorithm that is used to measure similarity between vectors. Possible values are IP, L2, and COSINE. - - params : dict - The fine-tuning parameters for the specified index type. For details on possible keys and value ranges, refer to In-memory Index. 
- """ - index_params.add_index( - field_name=field_name, - index_type=index_type, - index_name=index_name, - metric_type=metric_type, - params=params, - ) - - return index_params - - def create_index( - self, - collection_name: str, - index_params: IndexParams, - timeout: Optional[float] = None, - **kwargs: Dict[str, Any], - ) -> None: - """Create an index. - - Parameters - ---------- - index_params : IndexParams - index object - - collection_name : str - The name of the collection. - - timeout : Optional[float] - The maximum duration to wait for the operation to complete before timing out. - """ - self.milvus_client.create_index( - collection_name=collection_name, - index_params=index_params, - timeout=timeout, - **kwargs, - ) - - def create_collection( - self, - collection_name: str, - dimension: Optional[int] = None, - metric_type: str = "COSINE", - timeout: Optional[float] = None, - schema: Optional[CollectionSchema] = None, - index_params: Optional[IndexParams] = None, - enable_dynamic_field: bool = True, - **kwargs: Any, - ) -> None: - """Create a new collection. - - If the collection does not exist, it creates a new collection with the specified. - - Parameters - ---------- - collection_name : str - [REQUIRED] - The name of the collection to create. - - dimension : int - The dimension of the vector field in the collection. - The reason choosing 1024 as default is that the model - "text-embedding-3-small" we use generates a size of 1024 embeddings - - metric_type : str - The metric used to measure similarities between vector embeddings in the collection. - - timeout : Optional[float] - The maximum duration to wait for the operation to complete before timing out. - - schema : Optional[CollectionSchema] - Defines the structure of the collection. - - enable_dynamic_field: bool: - True can insert data without creating a schema first. - Raises - ------ - CollectionAlreadyExistsException - If the collection already exists. 
- """ - try: - # Detect whether the collection exist or not - self.get_collection_stats(collection_name, timeout) - except CollectionNotFoundException: - pass - else: - raise CollectionAlreadyExistsException( - f"Collection {collection_name} already exists" + # prepare index parameters + index_params = self.milvus_client.prepare_index_params() + index_params.add_index( + index_type="AUTOINDEX", + field_name="vector", + metric_type="COSINE", ) - try: + # create a collection self.milvus_client.create_collection( - collection_name=collection_name, - dimension=dimension, - metric_type=metric_type, + collection_name=self.collection_name, schema=schema, index_params=index_params, - timeout=timeout, - enable_dynamic_field=enable_dynamic_field, - **kwargs, - ) - except MilvusException as e: - raise CollectionCreateFailureException( - f"Collection {collection_name} fail to create" - ) from e - - def crate_partition( - self, - collection_name: str, - partition_name: str, - timeout: Optional[float] = None, - ) -> None: - """Create a partition in collection. - - Parameters - ---------- - collection_name : str - [REQUIRED] - The name of the collection to add partition. - - partition_name : str - [REQUIRED] - The name of the partition to create. - - timeout : Optional[float] - The timeout duration for this operation. - Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. - - Raises - ------ - PartitionCreateFailureException - If partition create failure. - """ - try: - self.milvus_client.create_partition( - collection_name=collection_name, - partition_name=partition_name, - timeout=timeout, - ) - except MilvusException as e: - raise PartitionCreateFailureException( - f"Partition {partition_name} fail to create" - ) from e - - def drop_partition( - self, - collection_name: str, - partition_name: str, - timeout: Optional[float] = None, - ) -> None: - """Drop a partition in collection. 
- - Parameters - ---------- - collection_name : str - [REQUIRED] - The name of the collection to drop partition. - - partition_name : str - [REQUIRED] - The name of the partition to drop. - - timeout : Optional[float] - The timeout duration for this operation. - Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. - - Raises - ------ - PartitionDropFailureException - If partition drop failure. - """ - try: - self.milvus_client.drop_partition( - collection_name=collection_name, - partition_name=partition_name, - timeout=timeout, + consistency_level=0, ) - except MilvusException as e: - raise PartitionDropFailureException( - f"Partition {partition_name} fail to drop" - ) from e - - def list_partition( - self, collection_name: str, timeout: Optional[float] = None - ) -> List[str]: - """List all partitions in the specific collection. - - Parameters - ---------- - collection_name : str - [REQUIRED] - The name of the collection to add partition. - - timeout : Optional[float] - The timeout duration for this operation. - Setting this to None indicates that this operation timeouts when any response arrives or any error occurs. - Returns - ------- - partitions : list[str] - All the partitions in that specific collection. - - Raises - ------ - PartitionListFailureException - If partition listing failure. - """ - try: - partitions = self.milvus_client.list_partitions( - collection_name=collection_name, timeout=timeout + except Exception as e: + raise CollectionCreateFailureException( + f"Error {e} occurred while attempting to create collection {self.collection_name}." ) - except MilvusException as e: - raise PartitionListFailureException( - f"Partitions from {collection_name} fail to list" - ) from e - return partitions - - def drop_collection(self, collection_name: str) -> None: + def drop_collection(self) -> None: """Delete an existing collection. 
- Parameters - ---------- - collection_name : str - The name of the collection. - Raises ------ CollectionNotFoundException If the collection does not exist. """ try: - self.milvus_client.drop_collection(collection_name=collection_name) + self.milvus_client.drop_collection( + collection_name=self.collection_name + ) except MilvusException as e: raise CollectionNotFoundException( - f"Collection {collection_name} not found" + f"Collection {self.collection_name} not found" ) from e def upload_documents( self, - collection_name: str, documents: List[str | pathlib.Path], - partition_name: Optional[str] = None, embedding_model: str = "text-embedding-3-small", ) -> None: """Upload documents to the index. Parameters ---------- - collection_name : str - The name of the collection - documents : list[str | pathlib.Path] The documents to upload. - partition_name : str | None - The name of the partition in that collection to insert the data - embedding_model : str The OpenAI embedding model to use. """ @@ -693,29 +411,28 @@ def upload_documents( "Number of embeddings does not match number of chunks" ) - data = [] + datas = [] for i, (chunk, embedding) in enumerate(zip(all_chunks, embeddings)): - rawdata = MilvusMetadata( + metadata = MilvusMetadata( text=chunk.page_content, page_number=chunk.metadata["page"], chunk_number=chunk.metadata["chunk"], filename=chunk.metadata["source"], - vector=embedding, ) - metadata = { - "text": rawdata.text, - "page_number": str(rawdata.page_number), - "chunk_number": str(rawdata.chunk_number), - "filename": rawdata.filename, - "embedding": list(rawdata.vector), + data = { + "id": metadata.uuid, + "vector": embedding, + "text": metadata.text, + "page_number": metadata.page_number, + "chunk_number": metadata.chunk_number, + "filename": metadata.filename, } - data.append(metadata) + datas.append(data) response = self.milvus_client.insert( - collection_name=collection_name, - partition_name=partition_name, - data=data, + 
collection_name=self.collection_name, + data=datas, ) insert_count = response["insert_count"] @@ -750,42 +467,30 @@ def create_search_params( return search_params - def search( + def query( self, question: str, - collection_name: str, - anns_field: Optional[str] = None, - partition_names: Optional[List[str]] = None, - filter: str = "", + rules: list[MilvusRule] | None = None, limit: int = 5, - output_fields: Optional[List[str]] = None, - search_params: Optional[Dict[str, Any]] = None, chat_model: str = "gpt-4-1106-preview", chat_temperature: float = 0.0, chat_max_tokens: int = 1000, chat_seed: int = 2, embedding_model: str = "text-embedding-3-small", + process_rules_separately: bool = True, + keyword_trigger: bool = False, **kwargs: Dict[str, Any], ) -> QueryReturnType: """Query the index. Parameters ---------- - collection_name : str - Name of the collection. - - anns_field : str - Specific Field to search on. - question : str The question to ask. limit : int The maximum number of answers to return. - output_fields : str - The field that should return. - chat_model : str The OpenAI chat model to use. @@ -818,43 +523,115 @@ def search( include the chat model not finishing or the response not being valid JSON. 
""" - if output_fields is None: - output_fields = ["text", "filename", "page_number"] + logger.info(f"Raw rules: {rules}") + + if rules is None: + rules = [] + + if keyword_trigger: + triggered_rules = [] + clean_question = self.clean_text(question).split(" ") + + for rule in rules: + if rule.keywords: + clean_keywords = [ + self.clean_text(keyword) for keyword in rule.keywords + ] - if search_params is None: - search_params = {} + if bool(set(clean_keywords) & set(clean_question)): + triggered_rules.append(rule) - logger.info(f"Filter: {filter} and Search params: {search_params}") + rules = triggered_rules + + rule_filters = [rule.to_filter() for rule in rules if rule is not None] # size of 1024 question_embedding = generate_embeddings( openai_api_key=self.openai_client.api_key, chunks=[question], model=embedding_model, - )[0] + ) + matches = [] match_texts: List[str] = [] - results: Optional[List[Any]] = [] - i = 0 - while results is not None and i < 5: - results = self.milvus_client.search( - collection_name=collection_name, - anns_field=anns_field, - partition_names=partition_names, - filter=filter, - data=[question_embedding], - output_fields=[output_fields], + # Check if there are any rule filters, and if not, proceed with a default query + if not rule_filters: + # Perform a default query + query_response = self.milvus_client.search( + collection_name=self.collection_name, limit=limit, - search_params=search_params, - **kwargs, + data=question_embedding, + output_fields=["*"], ) - i += 1 + matches = [ + MilvusMatch( + id=m["id"], + score=m["distance"], + metadata=MilvusMetadata( + text=m["entity"]["text"], + page_number=m["entity"]["page_number"], + chunk_number=m["entity"]["chunk_number"], + filename=m["entity"]["filename"], + ), + ) + for m in query_response[0] + ] + match_texts = [m.metadata.text for m in matches] - if results is not None: - for result in results: - text = result[0]["entity"]["text"] - match_texts.append(text) + else: + + if 
process_rules_separately: + for rule_filter in rule_filters: + query_response = self.milvus_client.search( + collection_name=self.collection_name, + data=question_embedding, + filter=rule_filter, + limit=limit, + output_fields=["*"], + ) + matches = [ + MilvusMatch( + id=m["id"], + score=m["distance"], + metadata=MilvusMetadata( + text=m["entity"]["text"], + page_number=m["entity"]["page_number"], + chunk_number=m["entity"]["chunk_number"], + filename=m["entity"]["filename"], + ), + ) + for m in query_response[0] + ] + match_texts = [m.metadata.text for m in matches] + match_texts = list( + set(match_texts) + ) # Ensure unique match texts + else: + if rule_filters: + rule_filter = " or ".join(rule_filters) + + query_response = self.milvus_client.search( + collection_name=self.collection_name, + data=question_embedding, + filter=rule_filter, + limit=limit, + output_fields=["*"], + ) + matches = [ + MilvusMatch( + id=m["id"], + score=m["distance"], + metadata=MilvusMetadata( + text=m["entity"]["text"], + page_number=m["entity"]["page_number"], + chunk_number=m["entity"]["chunk_number"], + filename=m["entity"]["filename"], + ), + ) + for m in query_response[0] + ] + match_texts = [m.metadata.text for m in matches] # Proceed to create prompt, send it to OpenAI, and handle the response prompt = self.create_prompt(question, match_texts) @@ -870,13 +647,10 @@ def search( return_dict: QueryReturnType = { "answer": output.answer, - "matches": [], + "matches": [m.model_dump() for m in matches], "used_contexts": output.contexts, } - if results is not None and len(results) > 0: - return_dict["matches"] = results[0] - return return_dict def create_prompt(self, question: str, match_texts: list[str]) -> str:
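Review note: the rule handling in the new `query` method can be read in isolation as two steps: keep only the rules whose keywords appear in the question, then turn the surviving rules into a single filter string. The sketch below is a minimal illustration of that logic, not the package's implementation. `SketchRule`, `triggered_rules`, and `combined_filter` are hypothetical names. The body of `clean_text` is an assumption (lowercase and strip punctuation), since the real helper is not shown in this hunk. The `filename == "..."` filter syntax is likewise assumed for illustration.

```python
import re
from dataclasses import dataclass, field
from typing import List


@dataclass
class SketchRule:
    """Hypothetical stand-in for MilvusRule; only the parts used in query()."""

    filename: str
    keywords: List[str] = field(default_factory=list)

    def to_filter(self) -> str:
        # Assumed filter format: a boolean expression on a scalar field.
        return f'filename == "{self.filename}"'


def clean_text(text: str) -> str:
    # Assumption: lowercase and drop punctuation; the real helper may differ.
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def triggered_rules(question: str, rules: List[SketchRule]) -> List[SketchRule]:
    """Keep rules that share at least one cleaned keyword with the question."""
    question_words = set(clean_text(question).split())
    kept = []
    for rule in rules:
        cleaned_keywords = {clean_text(keyword) for keyword in rule.keywords}
        if cleaned_keywords & question_words:
            kept.append(rule)
    return kept


def combined_filter(rules: List[SketchRule]) -> str:
    """Join per-rule filters with 'or', as in the non-separate branch."""
    return " or ".join(rule.to_filter() for rule in rules)


rules = [
    SketchRule(filename="doc1.pdf", keywords=["Revenue"]),
    SketchRule(filename="doc2.pdf", keywords=["hiring"]),
]
kept = triggered_rules("What was the revenue in 2023?", rules)
print([rule.filename for rule in kept])
print(combined_filter(rules))
```

Because the question is split into single words before the set intersection, a multi-word keyword such as "annual report" would likely never trigger under this scheme, assuming `clean_text` leaves inner spaces intact. Matching on substrings of the cleaned question would be one way to support phrases.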