Skip to content

Commit 0e3d03e

Browse files
authored
[ENH] Implement new search endpoint (#5323)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - N/A
- New functionality
  - Introduces a new `search` endpoint for hybrid search

## Test plan

_How are these changes tested?_

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent d1d95f5 commit 0e3d03e

File tree

15 files changed

+1668
-37
lines changed

15 files changed

+1668
-37
lines changed

clients/new-js/packages/chromadb/scripts/gen-api.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ const main = async () => {
3636

3737
// Fix the HashMap type to include null and remove duplicate number
3838
typesContent = typesContent.replace(
39-
/export type HashMap = \{\s*\[key: string\]: boolean \| number \| number \| string;\s*\};/,
40-
"export type HashMap = {\n [key: string]: boolean | number | string | null;\n};",
39+
/export type HashMap = \{\s*\[key: string\]: boolean \| number \| number \| string \| SparseVector;\s*};/,
40+
"export type HashMap = {\n [key: string]: boolean | number | string | SparseVector | null;\n};",
4141
);
4242

4343
await writeFile(typesPath, typesContent);

clients/new-js/packages/chromadb/src/api/sdk.gen.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// This file is auto-generated by @hey-api/openapi-ts
22

33
import type { Options as ClientOptions, TDataShape, Client } from '@hey-api/client-fetch';
4-
import type { GetUserIdentityData, GetUserIdentityResponse2, GetUserIdentityError, GetCollectionByCrnData, GetCollectionByCrnResponse, GetCollectionByCrnError, HealthcheckData, HealthcheckResponse, HealthcheckError, HeartbeatData, HeartbeatResponse2, HeartbeatError, PreFlightChecksData, PreFlightChecksResponse, PreFlightChecksError, ResetData, ResetResponse, ResetError, CreateTenantData, CreateTenantResponse2, CreateTenantError, GetTenantData, GetTenantResponse2, GetTenantError, UpdateTenantData, UpdateTenantResponse2, UpdateTenantError, ListDatabasesData, ListDatabasesResponse, ListDatabasesError, CreateDatabaseData, CreateDatabaseResponse2, CreateDatabaseError, DeleteDatabaseData, DeleteDatabaseResponse2, DeleteDatabaseError, GetDatabaseData, GetDatabaseResponse, GetDatabaseError, ListCollectionsData, ListCollectionsResponse, ListCollectionsError, CreateCollectionData, CreateCollectionResponse, CreateCollectionError, DeleteCollectionData, DeleteCollectionResponse, DeleteCollectionError, GetCollectionData, GetCollectionResponse, GetCollectionError, UpdateCollectionData, UpdateCollectionResponse2, UpdateCollectionError, CollectionAddData, CollectionAddResponse, CollectionCountData, CollectionCountResponse, CollectionCountError, CollectionDeleteData, CollectionDeleteResponse, CollectionDeleteError, ForkCollectionData, ForkCollectionResponse, ForkCollectionError, CollectionGetData, CollectionGetResponse, CollectionGetError, CollectionQueryData, CollectionQueryResponse, CollectionQueryError, CollectionUpdateData, CollectionUpdateResponse, CollectionUpsertData, CollectionUpsertResponse, CollectionUpsertError, CountCollectionsData, CountCollectionsResponse, CountCollectionsError, VersionData, VersionResponse } from './types.gen';
4+
import type { GetUserIdentityData, GetUserIdentityResponse2, GetUserIdentityError, GetCollectionByCrnData, GetCollectionByCrnResponse, GetCollectionByCrnError, HealthcheckData, HealthcheckResponse, HealthcheckError, HeartbeatData, HeartbeatResponse2, HeartbeatError, PreFlightChecksData, PreFlightChecksResponse, PreFlightChecksError, ResetData, ResetResponse, ResetError, CreateTenantData, CreateTenantResponse2, CreateTenantError, GetTenantData, GetTenantResponse2, GetTenantError, UpdateTenantData, UpdateTenantResponse2, UpdateTenantError, ListDatabasesData, ListDatabasesResponse, ListDatabasesError, CreateDatabaseData, CreateDatabaseResponse2, CreateDatabaseError, DeleteDatabaseData, DeleteDatabaseResponse2, DeleteDatabaseError, GetDatabaseData, GetDatabaseResponse, GetDatabaseError, ListCollectionsData, ListCollectionsResponse, ListCollectionsError, CreateCollectionData, CreateCollectionResponse, CreateCollectionError, DeleteCollectionData, DeleteCollectionResponse, DeleteCollectionError, GetCollectionData, GetCollectionResponse, GetCollectionError, UpdateCollectionData, UpdateCollectionResponse2, UpdateCollectionError, CollectionAddData, CollectionAddResponse, CollectionCountData, CollectionCountResponse, CollectionCountError, CollectionDeleteData, CollectionDeleteResponse, CollectionDeleteError, ForkCollectionData, ForkCollectionResponse, ForkCollectionError, CollectionGetData, CollectionGetResponse, CollectionGetError, CollectionQueryData, CollectionQueryResponse, CollectionQueryError, CollectionSearchData, CollectionSearchResponse, CollectionSearchError, CollectionUpdateData, CollectionUpdateResponse, CollectionUpsertData, CollectionUpsertResponse, CollectionUpsertError, CountCollectionsData, CountCollectionsResponse, CountCollectionsError, VersionData, VersionResponse } from './types.gen';
55
import { client as _heyApiClient } from './client.gen';
66

77
export type Options<TData extends TDataShape = TDataShape, ThrowOnError extends boolean = boolean> = ClientOptions<TData, ThrowOnError> & {
@@ -299,6 +299,20 @@ export class DefaultService {
299299
});
300300
}
301301

302+
/**
303+
* Search records from a collection with hybrid criteria.
304+
*/
305+
public static collectionSearch<ThrowOnError extends boolean = true>(options: Options<CollectionSearchData, ThrowOnError>) {
306+
return (options.client ?? _heyApiClient).post<CollectionSearchResponse, CollectionSearchError, ThrowOnError>({
307+
url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search',
308+
...options,
309+
headers: {
310+
'Content-Type': 'application/json',
311+
...options?.headers
312+
}
313+
});
314+
}
315+
302316
/**
303317
* Updates records in a collection by ID.
304318
*/

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ export type GetUserIdentityResponse = {
131131
};
132132

133133
export type HashMap = {
134-
[key: string]: boolean | number | number | string | SparseVector;
134+
[key: string]: boolean | number | string | SparseVector | null;
135135
};
136136

137137
export type HeartbeatResponse = {
@@ -175,6 +175,52 @@ export type RawWhereFields = {
175175
where_document?: unknown;
176176
};
177177

178+
/**
179+
* Payload for hybrid search
180+
*/
181+
export type SearchPayload = {
182+
/**
183+
* Filter criteria for search
184+
*/
185+
filter: {
186+
query_ids?: Array<string>;
187+
where_clause?: {
188+
[key: string]: unknown;
189+
};
190+
};
191+
limit: {
192+
fetch?: number;
193+
skip: number;
194+
};
195+
/**
196+
* Ranking expression for hybrid search
197+
*/
198+
rank: {
199+
[key: string]: {
200+
[key: string]: unknown;
201+
};
202+
};
203+
select: {
204+
fields: Array<string>;
205+
};
206+
};
207+
208+
export type SearchRecord = {
209+
document?: string | null;
210+
embedding?: Array<number> | null;
211+
id: string;
212+
metadata?: null | HashMap;
213+
score?: number | null;
214+
};
215+
216+
export type SearchRequestPayload = {
217+
searches: Array<SearchPayload>;
218+
};
219+
220+
export type SearchResponse = {
221+
results: Array<Array<SearchRecord>>;
222+
};
223+
178224
export type SpannConfiguration = {
179225
ef_construction?: number | null;
180226
ef_search?: number | null;
@@ -1205,6 +1251,52 @@ export type CollectionQueryResponses = {
12051251

12061252
export type CollectionQueryResponse = CollectionQueryResponses[keyof CollectionQueryResponses];
12071253

1254+
export type CollectionSearchData = {
1255+
body: SearchRequestPayload;
1256+
path: {
1257+
/**
1258+
* Tenant ID
1259+
*/
1260+
tenant: string;
1261+
/**
1262+
* Database name for the collection
1263+
*/
1264+
database: string;
1265+
/**
1266+
* Collection ID to search records from
1267+
*/
1268+
collection_id: string;
1269+
};
1270+
query?: never;
1271+
url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search';
1272+
};
1273+
1274+
export type CollectionSearchErrors = {
1275+
/**
1276+
* Unauthorized
1277+
*/
1278+
401: ErrorResponse;
1279+
/**
1280+
* Collection not found
1281+
*/
1282+
404: ErrorResponse;
1283+
/**
1284+
* Server error
1285+
*/
1286+
500: ErrorResponse;
1287+
};
1288+
1289+
export type CollectionSearchError = CollectionSearchErrors[keyof CollectionSearchErrors];
1290+
1291+
export type CollectionSearchResponses = {
1292+
/**
1293+
* Records searched from the collection
1294+
*/
1295+
200: SearchResponse;
1296+
};
1297+
1298+
export type CollectionSearchResponse = CollectionSearchResponses[keyof CollectionSearchResponses];
1299+
12081300
export type CollectionUpdateData = {
12091301
body: UpdateCollectionRecordsPayload;
12101302
path: {

clients/new-js/packages/chromadb/src/types.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { GetUserIdentityResponse, Include } from "./api";
1+
import { GetUserIdentityResponse, Include, SparseVector } from "./api";
22

33
/**
44
* User identity information including tenant and database access.
@@ -11,14 +11,17 @@ export type UserIdentity = GetUserIdentityResponse;
1111
*/
1212
export type CollectionMetadata = Record<
1313
string,
14-
boolean | number | string | null
14+
boolean | number | string | SparseVector | null
1515
>;
1616

1717
/**
1818
* Metadata that can be associated with individual records.
1919
* Values must be boolean, number, or string types.
2020
*/
21-
export type Metadata = Record<string, boolean | number | string | null>;
21+
export type Metadata = Record<
22+
string,
23+
boolean | number | string | SparseVector | null
24+
>;
2225

2326
/**
2427
* Base interface for record sets containing optional fields.

idl/chromadb/proto/query_executor.proto

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,88 @@ message KNNBatchResult {
8989
uint64 pulled_log_bytes = 2;
9090
}
9191

92+
message QueryVector {
93+
oneof vector {
94+
Vector dense = 1;
95+
SparseVector sparse = 2;
96+
}
97+
}
98+
99+
message Rank {
100+
message Division {
101+
Rank left = 1;
102+
Rank right = 2;
103+
}
104+
105+
message Knn {
106+
QueryVector embedding = 1;
107+
string key = 2;
108+
uint32 limit = 3;
109+
optional float default = 4;
110+
bool ordinal = 5;
111+
}
112+
113+
message Subtraction {
114+
Rank left = 1;
115+
Rank right = 2;
116+
}
117+
118+
message RankList {
119+
repeated Rank ranks = 1;
120+
}
121+
122+
oneof rank {
123+
Rank absolute = 1;
124+
Division division = 2;
125+
Rank exponentiation = 3;
126+
Knn knn = 4;
127+
Rank logarithm = 5;
128+
RankList maximum = 6;
129+
RankList minimum = 7;
130+
RankList multiplication = 8;
131+
Subtraction subtraction = 9;
132+
RankList summation = 10;
133+
float value = 11;
134+
}
135+
}
136+
137+
message SelectOperator {
138+
repeated string fields = 1;
139+
}
140+
141+
message SearchPayload {
142+
FilterOperator filter = 1;
143+
Rank rank = 2;
144+
LimitOperator limit = 3;
145+
SelectOperator select = 4;
146+
}
147+
148+
message SearchPlan {
149+
ScanOperator scan = 1;
150+
repeated SearchPayload payloads = 2;
151+
}
152+
153+
message SearchRecord {
154+
string id = 1;
155+
optional string document = 2;
156+
optional Vector embedding = 3;
157+
optional UpdateMetadata metadata = 4;
158+
optional float score = 5;
159+
}
160+
161+
message SearchPayloadResult {
162+
repeated SearchRecord records = 1;
163+
}
164+
165+
message SearchResult {
166+
repeated SearchPayloadResult results = 1;
167+
uint64 pulled_log_bytes = 2;
168+
}
169+
92170
service QueryExecutor {
93171
rpc Count(CountPlan) returns (CountResult) {}
94172
rpc Get(GetPlan) returns (GetResult) {}
95173
rpc KNN(KNNPlan) returns (KNNBatchResult) {}
174+
rpc Search(SearchPlan) returns (SearchResult) {}
96175
}
97176

rust/frontend/src/executor/distributed.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ use chroma_system::System;
1515
use chroma_types::chroma_proto::query_executor_client::QueryExecutorClient;
1616
use chroma_types::SegmentType;
1717
use chroma_types::{
18-
operator::{CountResult, GetResult, KnnBatchResult},
19-
plan::{Count, Get, Knn},
18+
operator::{CountResult, GetResult, KnnBatchResult, SearchResult},
19+
plan::{Count, Get, Knn, Search},
2020
ExecutorError,
2121
};
2222

@@ -210,6 +210,42 @@ impl DistributedExecutor {
210210
Ok(res.into_inner().try_into()?)
211211
}
212212

213+
pub async fn search(&mut self, plan: Search) -> Result<SearchResult, ExecutorError> {
214+
// Get the collection ID from the plan
215+
let collection_id = &plan
216+
.scan
217+
.collection_and_segments
218+
.collection
219+
.collection_id
220+
.to_string();
221+
222+
let clients = self
223+
.client_assigner
224+
.clients(collection_id)
225+
.map_err(|e| ExecutorError::Internal(e.boxed()))?;
226+
227+
// Convert plan to proto
228+
let request: chroma_types::chroma_proto::SearchPlan = plan.try_into()?;
229+
230+
let attempt_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
231+
let config = self.client_selection_config.clone();
232+
let res = {
233+
let attempt_count = attempt_count.clone();
234+
(|| async {
235+
let current_attempt =
236+
attempt_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
237+
let is_retry = current_attempt > 0;
238+
choose_query_client_weighted(&clients, &config, is_retry)?
239+
.search(Request::new(request.clone()))
240+
.await
241+
})
242+
.retry(self.backoff)
243+
.when(is_retryable_error)
244+
.await?
245+
};
246+
Ok(res.into_inner().try_into()?)
247+
}
248+
213249
pub async fn is_ready(&self) -> bool {
214250
!self.client_assigner.is_empty()
215251
}

rust/frontend/src/executor/local.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use chroma_system::ComponentHandle;
1212
use chroma_types::{
1313
operator::{
1414
CountResult, Filter, GetResult, KnnBatchResult, KnnProjectionOutput, KnnProjectionRecord,
15-
Projection, ProjectionRecord, RecordDistance,
15+
Limit, Projection, ProjectionRecord, RecordDistance, SearchResult,
1616
},
17-
plan::{Count, Get, Knn},
17+
plan::{Count, Get, Knn, Search},
1818
CollectionAndSegments, CollectionUuid, ExecutorError, HnswSpace, SegmentType,
1919
};
2020
use std::{
@@ -161,7 +161,10 @@ impl LocalExecutor {
161161
let filter_plan = Get {
162162
scan: plan.scan.clone(),
163163
filter: filter.clone(),
164-
limit: Default::default(),
164+
limit: Limit {
165+
skip: 0,
166+
fetch: None,
167+
},
165168
proj: Default::default(),
166169
};
167170

@@ -264,7 +267,10 @@ impl LocalExecutor {
264267
query_ids: Some(returned_user_ids),
265268
where_clause: None,
266269
},
267-
limit: Default::default(),
270+
limit: Limit {
271+
skip: 0,
272+
fetch: None,
273+
},
268274
proj: Projection {
269275
document: plan.proj.projection.document,
270276
embedding: false,
@@ -306,6 +312,12 @@ impl LocalExecutor {
306312
})
307313
}
308314

315+
pub async fn search(&mut self, _plan: Search) -> Result<SearchResult, ExecutorError> {
316+
Err(ExecutorError::NotImplemented(
317+
"Search operation is not implemented for local executor".to_string(),
318+
))
319+
}
320+
309321
pub async fn reset(&mut self) -> Result<(), Box<dyn ChromaError>> {
310322
self.hnsw_manager.reset().await.map_err(|err| err.boxed())?;
311323
Ok(())

rust/frontend/src/executor/mod.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use chroma_types::{
2-
operator::{CountResult, GetResult, KnnBatchResult},
3-
plan::{Count, Get, Knn},
2+
operator::{CountResult, GetResult, KnnBatchResult, SearchResult},
3+
plan::{Count, Get, Knn, Search},
44
ExecutorError, SegmentType,
55
};
66
use distributed::DistributedExecutor;
@@ -38,6 +38,12 @@ impl Executor {
3838
Executor::Local(local_executor) => local_executor.knn(plan).await,
3939
}
4040
}
41+
pub async fn search(&mut self, plan: Search) -> Result<SearchResult, ExecutorError> {
42+
match self {
43+
Executor::Distributed(distributed_executor) => distributed_executor.search(plan).await,
44+
Executor::Local(local_executor) => local_executor.search(plan).await,
45+
}
46+
}
4147
pub async fn is_ready(&self) -> bool {
4248
match self {
4349
Executor::Distributed(distributed_executor) => distributed_executor.is_ready().await,

0 commit comments

Comments
 (0)