Skip to content

Commit 700e575

Browse files
committed
[ENH] Implement search for python client
1 parent 4e544d0 commit 700e575

File tree

11 files changed

+461
-11
lines changed

11 files changed

+461
-11
lines changed

chromadb/api/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Sequence, Optional
2+
from typing import Sequence, Optional, List
33
from uuid import UUID
44

55
from overrides import override
@@ -26,8 +26,10 @@
2626
QueryResult,
2727
GetResult,
2828
WhereDocument,
29+
SearchResult,
2930
)
3031
from chromadb.auth import UserIdentity
32+
from chromadb.execution.expression.plan import SearchPayload
3133
from chromadb.config import Component, Settings
3234
from chromadb.types import Database, Tenant, Collection as CollectionModel
3335
import chromadb.utils.embedding_functions as ef
@@ -648,6 +650,16 @@ def _fork(
648650
) -> CollectionModel:
649651
pass
650652

653+
@abstractmethod
654+
def _search(
655+
self,
656+
collection_id: UUID,
657+
searches: List[SearchPayload],
658+
tenant: str = DEFAULT_TENANT,
659+
database: str = DEFAULT_DATABASE,
660+
) -> SearchResult:
661+
pass
662+
651663
@abstractmethod
652664
@override
653665
def _count(

chromadb/api/async_api.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Sequence, Optional
2+
from typing import Sequence, Optional, List
33
from uuid import UUID
44

55
from overrides import override
@@ -28,7 +28,9 @@
2828
WhereDocument,
2929
IncludeMetadataDocuments,
3030
IncludeMetadataDocumentsDistances,
31+
SearchResult,
3132
)
33+
from chromadb.execution.expression.plan import SearchPayload
3234
from chromadb.config import Component, Settings
3335
from chromadb.types import Database, Tenant, Collection as CollectionModel
3436
import chromadb.utils.embedding_functions as ef
@@ -641,6 +643,16 @@ async def _fork(
641643
) -> CollectionModel:
642644
pass
643645

646+
@abstractmethod
647+
async def _search(
648+
self,
649+
collection_id: UUID,
650+
searches: List[SearchPayload],
651+
tenant: str = DEFAULT_TENANT,
652+
database: str = DEFAULT_DATABASE,
653+
) -> SearchResult:
654+
pass
655+
644656
@abstractmethod
645657
@override
646658
async def _count(

chromadb/api/async_fastapi.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from uuid import UUID
33
import urllib.parse
44
import orjson
5-
from typing import Any, Optional, cast, Tuple, Sequence, Dict
5+
from typing import Any, Optional, cast, Tuple, Sequence, Dict, List
66
import logging
77
import httpx
88
from overrides import override
@@ -26,6 +26,7 @@
2626
from chromadb.utils.async_to_sync import async_to_sync
2727
from chromadb.types import Database, Tenant, Collection as CollectionModel
2828
from chromadb.api.types import optional_embeddings_to_base64_strings
29+
from chromadb.execution.expression.plan import SearchPayload
2930

3031
from chromadb.api.types import (
3132
Documents,
@@ -39,11 +40,16 @@
3940
WhereDocument,
4041
GetResult,
4142
QueryResult,
43+
SearchResult,
44+
SearchRecord,
4245
CollectionMetadata,
4346
validate_batch,
4447
convert_np_embeddings_to_list,
4548
IncludeMetadataDocuments,
4649
IncludeMetadataDocumentsDistances,
50+
)
51+
52+
from chromadb.api.types import (
4753
IncludeMetadataDocumentsEmbeddings,
4854
)
4955

@@ -395,6 +401,40 @@ async def _fork(
395401
model = CollectionModel.from_json(resp_json)
396402
return model
397403

404+
@trace_method("AsyncFastAPI._search", OpenTelemetryGranularity.OPERATION)
405+
@override
406+
async def _search(
407+
self,
408+
collection_id: UUID,
409+
searches: List[SearchPayload],
410+
tenant: str = DEFAULT_TENANT,
411+
database: str = DEFAULT_DATABASE,
412+
) -> SearchResult:
413+
"""Performs hybrid search on a collection"""
414+
# Convert SearchPayload objects to dictionaries
415+
payload = {"searches": [s.to_dict() for s in searches]}
416+
417+
resp_json = await self._make_request(
418+
"post",
419+
f"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
420+
json=payload,
421+
)
422+
423+
# Parse response into SearchResult
424+
results = []
425+
for batch_results in resp_json.get("results", []):
426+
batch = []
427+
for record in batch_results:
428+
batch.append(SearchRecord(
429+
id=record["id"],
430+
document=record.get("document"),
431+
embedding=record.get("embedding"),
432+
metadata=record.get("metadata"),
433+
score=record.get("score"),
434+
))
435+
results.append(batch)
436+
return results
437+
398438
@trace_method("AsyncFastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
399439
@override
400440
async def delete_collection(

chromadb/api/fastapi.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import orjson
22
import logging
3-
from typing import Any, Dict, Optional, cast, Tuple
3+
from typing import Any, Dict, Optional, cast, Tuple, List
44
from typing import Sequence
55
from uuid import UUID
66
import httpx
@@ -17,6 +17,7 @@
1717
from chromadb.api.base_http_client import BaseHTTPClient
1818
from chromadb.types import Database, Tenant, Collection as CollectionModel
1919
from chromadb.api import ServerAPI
20+
from chromadb.execution.expression.plan import SearchPayload
2021

2122
from chromadb.api.types import (
2223
Documents,
@@ -30,11 +31,16 @@
3031
WhereDocument,
3132
GetResult,
3233
QueryResult,
34+
SearchResult,
35+
SearchRecord,
3336
CollectionMetadata,
3437
validate_batch,
3538
convert_np_embeddings_to_list,
3639
IncludeMetadataDocuments,
3740
IncludeMetadataDocumentsDistances,
41+
)
42+
43+
from chromadb.api.types import (
3844
IncludeMetadataDocumentsEmbeddings,
3945
optional_embeddings_to_base64_strings,
4046
)
@@ -352,6 +358,40 @@ def _fork(
352358
model = CollectionModel.from_json(resp_json)
353359
return model
354360

361+
@trace_method("FastAPI._search", OpenTelemetryGranularity.OPERATION)
362+
@override
363+
def _search(
364+
self,
365+
collection_id: UUID,
366+
searches: List[SearchPayload],
367+
tenant: str = DEFAULT_TENANT,
368+
database: str = DEFAULT_DATABASE,
369+
) -> SearchResult:
370+
"""Performs hybrid search on a collection"""
371+
# Convert SearchPayload objects to dictionaries
372+
payload = {"searches": [s.to_dict() for s in searches]}
373+
374+
resp_json = self._make_request(
375+
"post",
376+
f"/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search",
377+
json=payload,
378+
)
379+
380+
# Parse response into SearchResult
381+
results = []
382+
for batch_results in resp_json.get("results", []):
383+
batch = []
384+
for record in batch_results:
385+
batch.append(SearchRecord(
386+
id=record["id"],
387+
document=record.get("document"),
388+
embedding=record.get("embedding"),
389+
metadata=record.get("metadata"),
390+
score=record.get("score"),
391+
))
392+
results.append(batch)
393+
return results
394+
355395
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
356396
@override
357397
def delete_collection(

chromadb/api/models/AsyncCollection.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Optional, Union
1+
from typing import TYPE_CHECKING, Optional, Union, List
22

33
from chromadb.api.types import (
44
URI,
@@ -16,10 +16,12 @@
1616
ID,
1717
OneOrMany,
1818
WhereDocument,
19+
SearchResult,
1920
)
2021

2122
from chromadb.api.models.CollectionCommon import CollectionCommon
2223
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
24+
from chromadb.execution.expression.plan import SearchPayload
2325

2426
if TYPE_CHECKING:
2527
from chromadb.api import AsyncServerAPI # noqa: F401
@@ -288,6 +290,49 @@ async def fork(
288290
data_loader=self._data_loader,
289291
)
290292

293+
async def search(
294+
self,
295+
searches: List[SearchPayload],
296+
) -> SearchResult:
297+
"""Perform hybrid search on the collection.
298+
This is an experimental API that only works for Hosted Chroma for now.
299+
300+
Args:
301+
searches: List of SearchPayload objects, each containing:
302+
- filter: Optional filter criteria (user_ids, where)
303+
- score: Scoring expression for hybrid search
304+
- limit: Optional limit configuration (skip, fetch)
305+
- project: Optional projection configuration (fields to return)
306+
307+
Returns:
308+
SearchResult: List of search results for each search payload.
309+
Each result is a list of SearchRecord objects.
310+
311+
Raises:
312+
NotImplementedError: For local/segment API implementations
313+
314+
Example:
315+
from chromadb.execution.expression.operator import (
316+
DenseKnn, RankScore, Val, Sum, Filter, Limit, Project
317+
)
318+
from chromadb.execution.expression.plan import SearchPayload
319+
320+
payload = SearchPayload(
321+
filter=Filter(where={"category": "science"}),
322+
score=RankScore(source=DenseKnn(embedding=[0.1, 0.2, 0.3], limit=100)),
323+
limit=Limit(skip=0, fetch=10),
324+
project=Project(fields={"$document", "$score", "$metadata"})
325+
)
326+
327+
results = await collection.search([payload])
328+
"""
329+
return await self._client._search(
330+
collection_id=self.id,
331+
searches=searches,
332+
tenant=self.tenant,
333+
database=self.database,
334+
)
335+
291336
async def update(
292337
self,
293338
ids: OneOrMany[ID],

chromadb/api/models/Collection.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Optional, Union
1+
from typing import TYPE_CHECKING, Optional, Union, List
22

33
from chromadb.api.models.CollectionCommon import CollectionCommon
44
from chromadb.api.types import (
@@ -17,8 +17,10 @@
1717
ID,
1818
OneOrMany,
1919
WhereDocument,
20+
SearchResult,
2021
)
2122
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
23+
from chromadb.execution.expression.plan import SearchPayload
2224

2325
import logging
2426

@@ -292,6 +294,49 @@ def fork(
292294
data_loader=self._data_loader,
293295
)
294296

297+
def search(
298+
self,
299+
searches: List[SearchPayload],
300+
) -> SearchResult:
301+
"""Perform hybrid search on the collection.
302+
This is an experimental API that only works for Hosted Chroma for now.
303+
304+
Args:
305+
searches: List of SearchPayload objects, each containing:
306+
- filter: Optional filter criteria (user_ids, where)
307+
- score: Scoring expression for hybrid search
308+
- limit: Optional limit configuration (skip, fetch)
309+
- project: Optional projection configuration (fields to return)
310+
311+
Returns:
312+
SearchResult: List of search results for each search payload.
313+
Each result is a list of SearchRecord objects.
314+
315+
Raises:
316+
NotImplementedError: For local/segment API implementations
317+
318+
Example:
319+
from chromadb.execution.expression.operator import (
320+
DenseKnn, RankScore, Val, Sum, Filter, Limit, Project
321+
)
322+
from chromadb.execution.expression.plan import SearchPayload
323+
324+
payload = SearchPayload(
325+
filter=Filter(where={"category": "science"}),
326+
score=RankScore(source=DenseKnn(embedding=[0.1, 0.2, 0.3], limit=100)),
327+
limit=Limit(skip=0, fetch=10),
328+
project=Project(fields={"$document", "$score", "$metadata"})
329+
)
330+
331+
results = collection.search([payload])
332+
"""
333+
return self._client._search(
334+
collection_id=self.id,
335+
searches=searches,
336+
tenant=self.tenant,
337+
database=self.database,
338+
)
339+
295340
def update(
296341
self,
297342
ids: OneOrMany[ID],

chromadb/api/rust.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@
3434
IncludeMetadataDocuments,
3535
IncludeMetadataDocumentsDistances,
3636
IncludeMetadataDocumentsEmbeddings,
37+
SearchResult,
3738
)
3839

3940
# TODO(hammadb): Unify imports across types vs root __init__.py
4041
from chromadb.types import Database, Tenant, Collection as CollectionModel
42+
from chromadb.execution.expression.plan import SearchPayload
4143
import chromadb_rust_bindings
4244

4345

44-
from typing import Optional, Sequence
46+
from typing import Optional, Sequence, List
4547
from overrides import override
4648
from uuid import UUID
4749
import json
@@ -310,6 +312,18 @@ def _fork(
310312
"Collection forking is not implemented for Local Chroma"
311313
)
312314

315+
@override
316+
def _search(
317+
self,
318+
collection_id: UUID,
319+
searches: List[SearchPayload],
320+
tenant: str = DEFAULT_TENANT,
321+
database: str = DEFAULT_DATABASE,
322+
) -> SearchResult:
323+
raise NotImplementedError(
324+
"Search is not implemented for Local Chroma"
325+
)
326+
313327
@override
314328
def _count(
315329
self,

0 commit comments

Comments
 (0)