Skip to content

Commit f07fb90

Browse files
committed
In progress
1 parent 24090a7 commit f07fb90

File tree

8 files changed

+149
-117
lines changed

8 files changed

+149
-117
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/types/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ validator = { workspace = true }
2222
regex = { workspace = true }
2323
regex-syntax = { workspace = true }
2424
utoipa = { workspace = true }
25+
sprs = { workspace = true }
2526

2627
# (Cross-crate testing dependencies)
2728
proptest = { workspace = true, optional = true }

rust/types/src/metadata.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use chroma_error::{ChromaError, ErrorCodes};
22
use serde::{Deserialize, Serialize};
33
use serde_json::{Number, Value};
4+
use sprs::CsVec;
45
use std::{
56
cmp::Ordering,
67
collections::{HashMap, HashSet},
@@ -84,6 +85,23 @@ impl From<SparseVector> for chroma_proto::SparseVector {
8485
}
8586
}
8687

88+
/// Convert SparseVector to sprs::CsVec for efficient sparse operations
89+
impl From<&SparseVector> for CsVec<f32> {
90+
fn from(sparse: &SparseVector) -> Self {
91+
let (indices, values) = sparse
92+
.iter()
93+
.map(|(index, value)| (index as usize, value))
94+
.unzip();
95+
CsVec::new(u32::MAX as usize, indices, values)
96+
}
97+
}
98+
99+
impl From<SparseVector> for CsVec<f32> {
100+
fn from(sparse: SparseVector) -> Self {
101+
(&sparse).into()
102+
}
103+
}
104+
87105
#[cfg(feature = "pyo3")]
88106
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
89107
type Target = pyo3::PyAny;

rust/worker/src/execution/operators/score.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ impl Operator<ScoreInput, ScoreOutput> for Score {
184184
ranks: &input.ranks,
185185
};
186186
let score_domain = score_provider.eval(self.clone());
187-
Ok(ScoreOutput {
188-
scores: score_domain
189-
.support
190-
.into_iter()
191-
.map(|(offset_id, measure)| RecordDistance { offset_id, measure })
192-
.collect(),
193-
})
187+
let mut scores = score_domain
188+
.support
189+
.into_iter()
190+
.map(|(offset_id, measure)| RecordDistance { offset_id, measure })
191+
.collect::<Vec<_>>();
192+
scores.sort_unstable();
193+
Ok(ScoreOutput { scores })
194194
}
195195
}

rust/worker/src/execution/operators/sparse_knn.rs

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use std::{
2-
cmp::Reverse,
3-
collections::{BinaryHeap, HashMap},
4-
};
1+
use std::{cmp::Reverse, collections::BinaryHeap};
52

63
use async_trait::async_trait;
74
use chroma_blockstore::provider::BlockfileProvider;
@@ -16,7 +13,7 @@ use chroma_segment::{
1613
use chroma_system::Operator;
1714
use chroma_types::{
1815
operator::{Rank, RecordDistance},
19-
MaterializedLogOperation, MetadataValue, Segment,
16+
MaterializedLogOperation, MetadataValue, Segment, SparseVector,
2017
};
2118
use sprs::CsVec;
2219
use thiserror::Error;
@@ -64,7 +61,7 @@ impl ChromaError for SparseKnnError {
6461

6562
#[derive(Clone, Debug)]
6663
pub struct SparseKnn {
67-
pub embedding: HashMap<u32, f32>,
64+
pub embedding: SparseVector,
6865
pub key: String,
6966
pub limit: u32,
7067
}
@@ -74,17 +71,8 @@ impl Operator<SparseKnnInput, SparseKnnOutput> for SparseKnn {
7471
type Error = SparseKnnError;
7572

7673
async fn run(&self, input: &SparseKnnInput) -> Result<SparseKnnOutput, SparseKnnError> {
77-
let mut query_raw_vector = self
78-
.embedding
79-
.iter()
80-
.map(|(index, val)| (*index as usize, *val))
81-
.collect::<Vec<_>>();
82-
query_raw_vector.sort_unstable_by_key(|(index, _)| *index);
83-
let (query_raw_indexes, query_raw_values) = query_raw_vector
84-
.into_iter()
85-
.unzip::<usize, f32, Vec<_>, Vec<_>>();
86-
let query_sparse_verctor =
87-
CsVec::new(u32::MAX as usize, query_raw_indexes, query_raw_values);
74+
// Convert SparseVector to sprs::CsVec
75+
let query_sparse_verctor: CsVec<f32> = (&self.embedding).into();
8876
let record_segment_reader = match RecordSegmentReader::from_segment(
8977
&input.record_segment,
9078
&input.blockfile_provider,
@@ -118,16 +106,8 @@ impl Operator<SparseKnnInput, SparseKnnOutput> for SparseKnn {
118106
else {
119107
continue;
120108
};
121-
let mut log_raw_vector = sparse_vector
122-
.iter()
123-
.map(|(index, val)| (*index as usize, *val))
124-
.collect::<Vec<_>>();
125-
log_raw_vector.sort_unstable_by_key(|(index, _)| *index);
126-
let (log_raw_indexes, log_raw_values) = log_raw_vector
127-
.into_iter()
128-
.unzip::<usize, f32, Vec<_>, Vec<_>>();
129-
let log_sparse_verctor =
130-
CsVec::new(u32::MAX as usize, log_raw_indexes, log_raw_values);
109+
// Convert SparseVector to sprs::CsVec
110+
let log_sparse_verctor: CsVec<f32> = sparse_vector.into();
131111
let score = query_sparse_verctor.dot(&log_sparse_verctor);
132112
if (min_heap.len() as u32) < self.limit {
133113
min_heap.push(Reverse(RecordDistance {
@@ -170,7 +150,11 @@ impl Operator<SparseKnnInput, SparseKnnOutput> for SparseKnn {
170150

171151
let sorted_compact_records = sparse_reader
172152
.wand(
173-
self.embedding.clone(),
153+
self.embedding
154+
.indices
155+
.iter()
156+
.copied()
157+
.zip(self.embedding.values.iter().copied()),
174158
self.limit,
175159
input.mask.compact_offset_ids.clone(),
176160
)

rust/worker/src/execution/orchestration/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ mod compact;
77
pub mod get;
88
pub mod knn;
99
pub mod knn_filter;
10-
pub mod retrieve;
10+
pub mod search;

0 commit comments

Comments
 (0)