Skip to content

Commit 47b3a9e

Browse files
committed
test: add unit tests for search operators
Add comprehensive unit tests for new operators introduced in search orchestration: - Rank operator: test KNN ranking, arithmetic ops, min/max functions - Sparse Log KNN: test sparse vector search with masks and overlaps - Sparse KNN Merge: test merging and overfetch scenarios - Select operator: test field selection and empty records
1 parent 73504bd commit 47b3a9e

File tree

4 files changed

+651
-0
lines changed

4 files changed

+651
-0
lines changed

rust/worker/src/execution/operators/rank.rs

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,215 @@ impl Operator<RankInput, RankOutput> for Rank {
206206
Ok(RankOutput { ranks })
207207
}
208208
}
209+
210+
// Unit tests for the `Rank` operator: plain KNN ranking, arithmetic
// combinators (summation, multiplication) and the minimum/maximum functions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing rank measures.
    ///
    /// Wide enough to absorb the rounding of the small decimal literals used
    /// in these tests, yet far tighter than the gap between any two distinct
    /// expected values (the closest pair below differs by 0.1).
    const MEASURE_TOLERANCE: f64 = 1e-6;

    /// Asserts that `actual` lies within [`MEASURE_TOLERANCE`] of `expected`.
    ///
    /// Measures produced by arithmetic (e.g. `0.8 + 0.4`) are not guaranteed
    /// to be bit-identical to the decimal literal written in the test, so an
    /// exact `assert_eq!` would make the test depend on float width and
    /// rounding. Accepting `impl Into<f64>` keeps the helper independent of
    /// whether the measure is stored in single or double precision.
    fn assert_measure_close(actual: impl Into<f64>, expected: f64) {
        let actual = actual.into();
        assert!(
            (actual - expected).abs() < MEASURE_TOLERANCE,
            "measure {} is not within {} of expected {}",
            actual,
            MEASURE_TOLERANCE,
            expected
        );
    }

    /// Builds the `Rank::Knn` expression that refers to `query`, with no
    /// default measure and ordinal ranking disabled.
    fn knn_rank(query: &KnnQuery) -> Rank {
        Rank::Knn {
            embedding: query.embedding.clone(),
            key: query.key.clone(),
            limit: query.limit,
            default: None,
            ordinal: false,
        }
    }

    /// Wraps precomputed KNN results in a `RankInput` backed by a fresh
    /// in-memory blockfile provider.
    fn rank_input(knn_results: HashMap<KnnQuery, Vec<RecordMeasure>>) -> RankInput {
        RankInput {
            knn_results,
            blockfile_provider: BlockfileProvider::new_memory(),
        }
    }

    /// A bare `Rank::Knn` returns the precomputed KNN measures for its query.
    #[tokio::test]
    async fn test_rank_with_knn_results() {
        let query = KnnQuery {
            embedding: chroma_types::operator::QueryVector::Dense(vec![0.1, 0.2, 0.3]),
            key: String::new(),
            limit: 3,
        };
        // Build the rank expression first so the query can be moved into the
        // result map without cloning.
        let rank = knn_rank(&query);

        let knn_results = HashMap::from([(
            query,
            vec![
                RecordMeasure {
                    offset_id: 1,
                    measure: 0.9,
                },
                RecordMeasure {
                    offset_id: 2,
                    measure: 0.7,
                },
                RecordMeasure {
                    offset_id: 3,
                    measure: 0.5,
                },
            ],
        )]);
        let input = rank_input(knn_results);

        let output = rank.run(&input).await.expect("Rank should succeed");
        assert_eq!(output.ranks.len(), 3);
        assert_eq!(output.ranks[0].offset_id, 1);
        assert_measure_close(output.ranks[0].measure, 0.9);
    }

    /// Summation across two KNN queries and multiplication by a constant
    /// combine the per-record measures arithmetically.
    #[tokio::test]
    async fn test_rank_arithmetic_operations() {
        // One dense and one sparse query, so the two KNN ranks resolve to
        // different entries of the result map.
        let dense_query = KnnQuery {
            embedding: chroma_types::operator::QueryVector::Dense(vec![0.1]),
            key: String::new(),
            limit: 2,
        };
        let sparse_query = KnnQuery {
            embedding: chroma_types::operator::QueryVector::Sparse(chroma_types::SparseVector {
                indices: vec![0],
                values: vec![1.0],
            }),
            key: "sparse".to_string(),
            limit: 2,
        };

        let summation = Rank::Summation(vec![knn_rank(&dense_query), knn_rank(&sparse_query)]);
        let scaled = Rank::Multiplication(vec![knn_rank(&dense_query), Rank::Value(0.5)]);

        let knn_results = HashMap::from([
            (
                dense_query,
                vec![
                    RecordMeasure {
                        offset_id: 1,
                        measure: 0.8,
                    },
                    RecordMeasure {
                        offset_id: 2,
                        measure: 0.6,
                    },
                ],
            ),
            (
                sparse_query,
                vec![
                    RecordMeasure {
                        offset_id: 1,
                        measure: 0.4,
                    },
                    RecordMeasure {
                        offset_id: 3,
                        measure: 0.2,
                    },
                ],
            ),
        ]);

        // Summation: record 1 appears in both result lists, 0.8 + 0.4 = 1.2.
        let input = rank_input(knn_results.clone());
        let output = summation.run(&input).await.expect("Rank should succeed");
        assert_eq!(output.ranks[0].offset_id, 1);
        assert_measure_close(output.ranks[0].measure, 1.2);

        // Multiplication with a constant: 0.8 * 0.5 = 0.4.
        let input = rank_input(knn_results);
        let output = scaled.run(&input).await.expect("Rank should succeed");
        assert_eq!(output.ranks[0].offset_id, 1);
        assert_measure_close(output.ranks[0].measure, 0.4);
    }

    /// `Maximum` and `Minimum` clamp each KNN measure against a constant.
    #[tokio::test]
    async fn test_rank_min_max_functions() {
        let query = KnnQuery {
            embedding: chroma_types::operator::QueryVector::Dense(vec![0.1]),
            key: String::new(),
            limit: 2,
        };

        let maximum = Rank::Maximum(vec![knn_rank(&query), Rank::Value(0.5)]);
        let minimum = Rank::Minimum(vec![knn_rank(&query), Rank::Value(0.5)]);

        let knn_results = HashMap::from([(
            query,
            vec![
                RecordMeasure {
                    offset_id: 1,
                    measure: 0.8,
                },
                RecordMeasure {
                    offset_id: 2,
                    measure: 0.3,
                },
            ],
        )]);

        // Maximum against the constant 0.5.
        let input = rank_input(knn_results.clone());
        let output = maximum.run(&input).await.expect("Rank should succeed");
        assert_eq!(output.ranks[0].offset_id, 1);
        assert_measure_close(output.ranks[0].measure, 0.8); // max(0.8, 0.5) = 0.8
        assert_eq!(output.ranks[1].offset_id, 2);
        assert_measure_close(output.ranks[1].measure, 0.5); // max(0.3, 0.5) = 0.5

        // Minimum against the constant 0.5.
        let input = rank_input(knn_results);
        let output = minimum.run(&input).await.expect("Rank should succeed");
        assert_eq!(output.ranks[0].offset_id, 1);
        assert_measure_close(output.ranks[0].measure, 0.5); // min(0.8, 0.5) = 0.5
        assert_eq!(output.ranks[1].offset_id, 2);
        assert_measure_close(output.ranks[1].measure, 0.3); // min(0.3, 0.5) = 0.3
    }
}

rust/worker/src/execution/operators/select.rs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,124 @@ impl Operator<SelectInput, SelectOutput> for Select {
195195
})
196196
}
197197
}
198+
199+
// Unit tests for the `Select` operator: projecting individual fields, all
// fields at once, and an empty record list.
#[cfg(test)]
mod tests {
    use super::*;
    use chroma_log::test::{upsert_generator, LoadFromGenerator, LogGenerator};
    use chroma_segment::test::TestDistributedSegment;
    use chroma_system::Operator;
    use std::collections::HashSet;

    /// Runs a `Select` projecting `fields` over `input`, panicking if the
    /// operator reports an error.
    async fn run_select(fields: HashSet<SelectField>, input: &SelectInput) -> SelectOutput {
        let operator = Select { fields };
        operator.run(input).await.expect("Select should succeed")
    }

    /// Builds a segment populated from `upsert_generator` together with a
    /// `SelectInput` that selects three records (offsets 1, 5 and 8) and
    /// carries an extra chunk of generated logs.
    ///
    /// The segment is returned alongside the input so the caller can keep it
    /// alive for the duration of the test.
    async fn setup_select_input() -> (TestDistributedSegment, SelectInput) {
        let mut segment = TestDistributedSegment::new().await;
        segment.populate_with_generator(10, upsert_generator).await;

        let ranked: Vec<RecordMeasure> = [(1, 0.9), (5, 0.7), (8, 0.5)]
            .into_iter()
            .map(|(offset_id, measure)| RecordMeasure { offset_id, measure })
            .collect();

        let select_input = SelectInput {
            records: ranked,
            logs: upsert_generator.generate_chunk(11..=15),
            blockfile_provider: segment.blockfile_provider.clone(),
            record_segment: segment.record_segment.clone(),
        };

        (segment, select_input)
    }

    /// Selecting only the score yields the id and score, with every other
    /// optional field left unset.
    #[tokio::test]
    async fn test_select_with_score_only() {
        let (_segment_guard, input) = setup_select_input().await;

        let output = run_select(HashSet::from([SelectField::Score]), &input).await;
        assert_eq!(output.records.len(), 3);

        // The id is populated even though it was not requested explicitly.
        let first = &output.records[0];
        assert!(!first.id.is_empty());
        assert_eq!(first.score, Some(0.9));
        assert!(first.document.is_none());
        assert!(first.embedding.is_none());
        assert!(first.metadata.is_none());

        // The remaining scores come through unchanged and in input order.
        let trailing_scores: Vec<_> = output.records[1..].iter().map(|r| r.score).collect();
        assert_eq!(trailing_scores, vec![Some(0.7), Some(0.5)]);
    }

    /// Selecting every field populates all optional parts of each record.
    #[tokio::test]
    async fn test_select_with_all_fields() {
        let (_segment_guard, input) = setup_select_input().await;

        let everything = HashSet::from([
            SelectField::Document,
            SelectField::Embedding,
            SelectField::Metadata,
            SelectField::Score,
        ]);
        let output = run_select(everything, &input).await;

        assert_eq!(output.records.len(), 3);
        for selected in output.records.iter() {
            assert!(!selected.id.is_empty());
            assert!(selected.document.is_some());
            assert!(selected.embedding.is_some());
            assert!(selected.metadata.is_some());
            assert!(selected.score.is_some());
        }
    }

    /// An empty record list produces an empty output, even when logs are
    /// supplied.
    #[tokio::test]
    async fn test_select_empty_records() {
        let segment = TestDistributedSegment::new().await;

        let input = SelectInput {
            records: Vec::new(),
            logs: upsert_generator.generate_chunk(1..=5),
            blockfile_provider: segment.blockfile_provider,
            record_segment: segment.record_segment,
        };

        let output = run_select(HashSet::from([SelectField::Score]), &input).await;
        assert!(output.records.is_empty());
    }
}

0 commit comments

Comments
 (0)