|
2 | 2 | from uuid import UUID
|
3 | 3 | import urllib.parse
|
4 | 4 | import orjson
|
5 |
| -from typing import Any, Optional, cast, Tuple, Sequence, Dict |
| 5 | +from typing import Any, Optional, cast, Tuple, Sequence, Dict, List |
6 | 6 | import logging
|
7 | 7 | import httpx
|
8 | 8 | from overrides import override
|
|
26 | 26 | from chromadb.utils.async_to_sync import async_to_sync
|
27 | 27 | from chromadb.types import Database, Tenant, Collection as CollectionModel
|
28 | 28 | from chromadb.api.types import optional_embeddings_to_base64_strings
|
| 29 | +from chromadb.execution.expression.plan import SearchPayload |
29 | 30 |
|
30 | 31 | from chromadb.api.types import (
|
31 | 32 | Documents,
|
|
39 | 40 | WhereDocument,
|
40 | 41 | GetResult,
|
41 | 42 | QueryResult,
|
| 43 | + SearchResult, |
| 44 | + SearchRecord, |
42 | 45 | CollectionMetadata,
|
43 | 46 | validate_batch,
|
44 | 47 | convert_np_embeddings_to_list,
|
45 | 48 | IncludeMetadataDocuments,
|
46 | 49 | IncludeMetadataDocumentsDistances,
|
| 50 | +) |
| 51 | + |
| 52 | +from chromadb.api.types import ( |
47 | 53 | IncludeMetadataDocumentsEmbeddings,
|
48 | 54 | )
|
49 | 55 |
|
@@ -395,6 +401,40 @@ async def _fork(
|
395 | 401 | model = CollectionModel.from_json(resp_json)
|
396 | 402 | return model
|
397 | 403 |
|
| 404 | + @trace_method("AsyncFastAPI._search", OpenTelemetryGranularity.OPERATION) |
| 405 | + @override |
| 406 | + async def _search( |
| 407 | + self, |
| 408 | + collection_id: UUID, |
| 409 | + searches: List[SearchPayload], |
| 410 | + tenant: str = DEFAULT_TENANT, |
| 411 | + database: str = DEFAULT_DATABASE, |
| 412 | + ) -> SearchResult: |
| 413 | + """Performs hybrid search on a collection""" |
| 414 | + # Convert SearchPayload objects to dictionaries |
| 415 | + payload = {"searches": [s.to_dict() for s in searches]} |
| 416 | + |
| 417 | + resp_json = await self._make_request( |
| 418 | + "post", |
| 419 | + f"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search", |
| 420 | + json=payload, |
| 421 | + ) |
| 422 | + |
| 423 | + # Parse response into SearchResult |
| 424 | + results = [] |
| 425 | + for batch_results in resp_json.get("results", []): |
| 426 | + batch = [] |
| 427 | + for record in batch_results: |
| 428 | + batch.append(SearchRecord( |
| 429 | + id=record["id"], |
| 430 | + document=record.get("document"), |
| 431 | + embedding=record.get("embedding"), |
| 432 | + metadata=record.get("metadata"), |
| 433 | + score=record.get("score"), |
| 434 | + )) |
| 435 | + results.append(batch) |
| 436 | + return results |
| 437 | + |
398 | 438 | @trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
|
399 | 439 | @override
|
400 | 440 | async def delete_collection(
|
|
0 commit comments