Skip to content

Commit 4afa332

Browse files
committed
[ENH] Implement retrieve orchestrator
1 parent 5a95d37 commit 4afa332

File tree

13 files changed

+1208
-79
lines changed

13 files changed

+1208
-79
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ humantime = { version = "2.2.0" }
6969
petgraph = { version = "0.8.1" }
7070
base64 = "0.22"
7171
tikv-jemallocator = { version = "0.6.0", features = ["profiling"] }
72+
sprs = "0.11.3"
7273

7374
chroma-benchmark = { path = "rust/benchmark" }
7475
chroma-blockstore = { path = "rust/blockstore" }

rust/worker/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ num_cpus = { workspace = true }
4444
flatbuffers = { workspace = true }
4545
tantivy = { workspace = true }
4646
clap = { workspace = true }
47+
sprs = { workspace = true }
4748

4849
chroma-blockstore = { workspace = true }
4950
chroma-cache = { workspace = true }

rust/worker/benches/query.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use chroma_benchmark::{
66
datasets::sift::Sift1MData,
77
};
88
use chroma_config::{registry::Registry, Configurable};
9-
use chroma_segment::test::TestDistributedSegment;
9+
use chroma_segment::{
10+
distributed_hnsw::{DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentReader},
11+
test::TestDistributedSegment,
12+
};
1013
use chroma_system::{ComponentHandle, Dispatcher, Orchestrator, System};
1114
use chroma_types::operator::{Knn, KnnProjection};
1215
use criterion::{criterion_group, criterion_main, Criterion};
@@ -29,12 +32,10 @@ fn trivial_knn_filter(
2932
dispatcher_handle: ComponentHandle<Dispatcher>,
3033
) -> KnnFilterOrchestrator {
3134
let blockfile_provider = test_segments.blockfile_provider.clone();
32-
let hnsw_provider = test_segments.hnsw_provider.clone();
3335
let collection_uuid = test_segments.collection.collection_id;
3436
KnnFilterOrchestrator::new(
3537
blockfile_provider,
3638
dispatcher_handle,
37-
hnsw_provider,
3839
1000,
3940
test_segments.into(),
4041
empty_fetch_log(collection_uuid),
@@ -47,12 +48,10 @@ fn always_true_knn_filter(
4748
dispatcher_handle: ComponentHandle<Dispatcher>,
4849
) -> KnnFilterOrchestrator {
4950
let blockfile_provider = test_segments.blockfile_provider.clone();
50-
let hnsw_provider = test_segments.hnsw_provider.clone();
5151
let collection_uuid = test_segments.collection.collection_id;
5252
KnnFilterOrchestrator::new(
5353
blockfile_provider,
5454
dispatcher_handle,
55-
hnsw_provider,
5655
1000,
5756
test_segments.into(),
5857
empty_fetch_log(collection_uuid),
@@ -65,12 +64,10 @@ fn always_false_knn_filter(
6564
dispatcher_handle: ComponentHandle<Dispatcher>,
6665
) -> KnnFilterOrchestrator {
6766
let blockfile_provider = test_segments.blockfile_provider.clone();
68-
let hnsw_provider = test_segments.hnsw_provider.clone();
6967
let collection_uuid = test_segments.collection.collection_id;
7068
KnnFilterOrchestrator::new(
7169
blockfile_provider,
7270
dispatcher_handle,
73-
hnsw_provider,
7471
1000,
7572
test_segments.into(),
7673
empty_fetch_log(collection_uuid),
@@ -81,13 +78,15 @@ fn always_false_knn_filter(
8178
fn knn(
8279
test_segments: &TestDistributedSegment,
8380
dispatcher_handle: ComponentHandle<Dispatcher>,
81+
hnsw_reader: Option<DistributedHNSWSegmentReader>,
8482
knn_filter_output: KnnFilterOutput,
8583
query: Vec<f32>,
8684
) -> KnnOrchestrator {
8785
KnnOrchestrator::new(
8886
test_segments.blockfile_provider.clone(),
8987
dispatcher_handle.clone(),
9088
1000,
89+
hnsw_reader,
9190
knn_filter_output.clone(),
9291
Knn {
9392
embedding: query,
@@ -126,6 +125,21 @@ fn bench_query(criterion: &mut Criterion) {
126125
let runtime = tokio_multi_thread();
127126
let test_segments = runtime.block_on(sift1m_segments());
128127

128+
let hnsw_reader = match runtime.block_on(DistributedHNSWSegmentReader::from_segment(
129+
&test_segments.collection.clone(),
130+
&test_segments.vector_segment,
131+
test_segments
132+
.collection
133+
.dimension
134+
.expect("Collection dimension should be non-zero") as usize,
135+
test_segments.hnsw_provider.clone(),
136+
)) {
137+
Ok(hnsw_reader) => Some(*hnsw_reader),
138+
Err(err) if matches!(*err, DistributedHNSWSegmentFromSegmentError::Uninitialized) => None,
139+
140+
Err(err) => panic!("{err}"),
141+
};
142+
129143
let config = RootConfig::default();
130144
let system = System::default();
131145
let registry = Registry::new();
@@ -159,6 +173,7 @@ fn bench_query(criterion: &mut Criterion) {
159173
knn(
160174
&test_segments,
161175
dispatcher_handle.clone(),
176+
hnsw_reader.clone(),
162177
knn_filter_output.clone(),
163178
query.clone(),
164179
),
@@ -183,6 +198,7 @@ fn bench_query(criterion: &mut Criterion) {
183198
knn(
184199
&test_segments,
185200
dispatcher_handle.clone(),
201+
hnsw_reader.clone(),
186202
knn_filter_output.clone(),
187203
query.clone(),
188204
),
@@ -207,6 +223,7 @@ fn bench_query(criterion: &mut Criterion) {
207223
knn(
208224
&test_segments,
209225
dispatcher_handle.clone(),
226+
hnsw_reader.clone(),
210227
knn_filter_output.clone(),
211228
query.clone(),
212229
),

rust/worker/src/execution/operators/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,7 @@ pub mod prefetch_segment;
2121
pub mod projection;
2222
pub mod purge_dirty_log;
2323
pub mod repair_log_offsets;
24+
pub mod reverse_project;
25+
pub mod score;
2426
pub mod source_record_segment;
27+
pub mod sparse_knn;
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use std::collections::HashMap;
2+
3+
use async_trait::async_trait;
4+
use chroma_blockstore::provider::BlockfileProvider;
5+
use chroma_segment::{
6+
blockfile_record::{RecordSegmentReader, RecordSegmentReaderCreationError},
7+
types::{materialize_logs, LogMaterializerError},
8+
};
9+
use chroma_system::Operator;
10+
use chroma_types::{
11+
operator::{KnnProjectionOutput, Rank, RecordDistance},
12+
Segment,
13+
};
14+
use futures::future::try_join_all;
15+
use tracing::{Instrument, Span};
16+
17+
use crate::execution::operators::{fetch_log::FetchLogOutput, projection::ProjectionError};
18+
19+
#[derive(Clone, Debug)]
20+
pub struct ReverseProjectionInput {
21+
pub logs: FetchLogOutput,
22+
pub blockfile_provider: BlockfileProvider,
23+
pub record_segment: Segment,
24+
pub projection_outputs: HashMap<Rank, KnnProjectionOutput>,
25+
}
26+
27+
#[derive(Clone, Debug)]
28+
pub struct ReverseProjectionOutput {
29+
pub rank_records: HashMap<Rank, Vec<RecordDistance>>,
30+
}
31+
32+
// NOTE: This is a temporary operator that aims to reverse
33+
// the projection by converting user id to offset id
34+
#[derive(Clone, Debug)]
35+
pub struct ReverseProjection {}
36+
37+
#[async_trait]
38+
impl Operator<ReverseProjectionInput, ReverseProjectionOutput> for ReverseProjection {
39+
type Error = ProjectionError;
40+
41+
async fn run(
42+
&self,
43+
input: &ReverseProjectionInput,
44+
) -> Result<ReverseProjectionOutput, ProjectionError> {
45+
tracing::trace!(
46+
"Reversing projection on {} ranks",
47+
input.projection_outputs.len()
48+
);
49+
let record_segment_reader = match RecordSegmentReader::from_segment(
50+
&input.record_segment,
51+
&input.blockfile_provider,
52+
)
53+
.await
54+
{
55+
Ok(reader) => Ok(Some(reader)),
56+
Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => {
57+
Ok(None)
58+
}
59+
Err(e) => Err(*e),
60+
}?;
61+
let materialized_logs = materialize_logs(&record_segment_reader, input.logs.clone(), None)
62+
.instrument(tracing::trace_span!(parent: Span::current(), "Materialize logs"))
63+
.await?;
64+
let borrowed_logs = materialized_logs.iter().collect::<Vec<_>>();
65+
66+
let hydrated_futures = borrowed_logs.iter().map(|log| async {
67+
let hydrated_log = log.hydrate(record_segment_reader.as_ref()).await?;
68+
<Result<_, LogMaterializerError>>::Ok((
69+
hydrated_log.get_user_id(),
70+
hydrated_log.get_offset_id(),
71+
))
72+
});
73+
let log_user_id_to_offset_id = try_join_all(hydrated_futures)
74+
.await?
75+
.into_iter()
76+
.collect::<HashMap<_, _>>();
77+
let mut rank_records = HashMap::with_capacity(input.projection_outputs.len());
78+
79+
for (rank, projection_output) in &input.projection_outputs {
80+
let resolve_futures = projection_output.records.iter().map(|record| async {
81+
match log_user_id_to_offset_id.get(record.record.id.as_str()) {
82+
Some(&offset_id) => Ok(RecordDistance {
83+
offset_id,
84+
measure: record.distance.unwrap_or_default(),
85+
}),
86+
None => {
87+
if let Some(reader) = &record_segment_reader {
88+
match reader.get_offset_id_for_user_id(&record.record.id).await? {
89+
Some(offset_id) => Ok(RecordDistance {
90+
offset_id,
91+
measure: record.distance.unwrap_or_default(),
92+
}),
93+
None => Err(ProjectionError::RecordSegmentPhantomRecord(u32::MAX)),
94+
}
95+
} else {
96+
Err(ProjectionError::RecordSegmentUninitialized)
97+
}
98+
}
99+
}
100+
});
101+
rank_records.insert(rank.clone(), try_join_all(resolve_futures).await?);
102+
}
103+
104+
Ok(ReverseProjectionOutput { rank_records })
105+
}
106+
}

0 commit comments

Comments
 (0)