Skip to content

Commit aa8a6e1

Browse files
committed
[CHORE]: Change store_tokens to include_tokens
1 parent dae1739 commit aa8a6e1

File tree

6 files changed

+1650
-511
lines changed

6 files changed

+1650
-511
lines changed

chromadb/utils/embedding_functions/chroma_bm25_embedding_function.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
2424

25+
2526
class _HashedToken:
2627
__slots__ = ("hash", "label")
2728

@@ -47,7 +48,7 @@ class ChromaBm25Config(TypedDict, total=False):
4748
avg_doc_length: float
4849
token_max_length: int
4950
stopwords: List[str]
50-
store_tokens: bool
51+
include_tokens: bool
5152

5253

5354
class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@@ -58,15 +59,15 @@ def __init__(
5859
avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
5960
token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
6061
stopwords: Optional[Iterable[str]] = None,
61-
store_tokens: bool = False,
62+
include_tokens: bool = False,
6263
) -> None:
6364
"""Initialize the BM25 sparse embedding function."""
6465

6566
self.k = float(k)
6667
self.b = float(b)
6768
self.avg_doc_length = float(avg_doc_length)
6869
self.token_max_length = int(token_max_length)
69-
self.store_tokens = bool(store_tokens)
70+
self.include_tokens = bool(include_tokens)
7071

7172
if stopwords is not None:
7273
self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
@@ -87,28 +88,30 @@ def _encode(self, text: str) -> SparseVector:
8788

8889
doc_len = float(len(tokens))
8990
counts = Counter(
90-
_HashedToken(self._hasher.hash(token), token if self.store_tokens else None)
91+
_HashedToken(
92+
self._hasher.hash(token), token if self.include_tokens else None
93+
)
9194
for token in tokens
9295
)
9396

9497
sorted_keys = sorted(counts.keys())
9598
indices: List[int] = []
9699
values: List[float] = []
97-
tokens: Optional[List[str]] = [] if self.store_tokens else None
100+
labels: Optional[List[str]] = [] if self.include_tokens else None
98101

99102
for key in sorted_keys:
100103
tf = float(counts[key])
101104
denominator = tf + self.k * (
102105
1 - self.b + (self.b * doc_len) / self.avg_doc_length
103106
)
104107
score = tf * (self.k + 1) / denominator
105-
108+
106109
indices.append(key.hash)
107110
values.append(score)
108-
if tokens is not None:
109-
tokens.append(key.label)
111+
if labels is not None and key.label is not None:
112+
labels.append(key.label)
110113

111-
return SparseVector(indices=indices, values=values, labels=tokens)
114+
return SparseVector(indices=indices, values=values, labels=labels)
112115

113116
def __call__(self, input: Documents) -> SparseVectors:
114117
sparse_vectors: SparseVectors = []
@@ -138,7 +141,7 @@ def build_from_config(
138141
avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
139142
token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
140143
stopwords=config.get("stopwords"),
141-
store_tokens=config.get("store_tokens", False),
144+
include_tokens=config.get("include_tokens", False),
142145
)
143146

144147
def get_config(self) -> Dict[str, Any]:
@@ -147,7 +150,7 @@ def get_config(self) -> Dict[str, Any]:
147150
"b": self.b,
148151
"avg_doc_length": self.avg_doc_length,
149152
"token_max_length": self.token_max_length,
150-
"store_tokens": self.store_tokens,
153+
"include_tokens": self.include_tokens,
151154
}
152155

153156
if self.stopwords is not None:
@@ -158,11 +161,18 @@ def get_config(self) -> Dict[str, Any]:
158161
def validate_config_update(
159162
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
160163
) -> None:
161-
mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords", "store_tokens"}
164+
mutable_keys = {
165+
"k",
166+
"b",
167+
"avg_doc_length",
168+
"token_max_length",
169+
"stopwords",
170+
"include_tokens",
171+
}
162172
for key in new_config:
163173
if key not in mutable_keys:
164174
raise ValueError(f"Updating '{key}' is not supported for {NAME}")
165175

166176
@staticmethod
167177
def validate_config(config: Dict[str, Any]) -> None:
168-
validate_config_schema(config, NAME)
178+
validate_config_schema(config, NAME)

chromadb/utils/embedding_functions/chroma_cloud_splade_embedding_function.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from chromadb.api.types import (
22
SparseEmbeddingFunction,
3+
SparseVector,
34
SparseVectors,
45
Documents,
56
)
@@ -21,7 +22,7 @@ def __init__(
2122
self,
2223
api_key_env_var: str = "CHROMA_API_KEY",
2324
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
24-
store_tokens: bool = False,
25+
include_tokens: bool = False,
2526
):
2627
"""
2728
Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -50,7 +51,7 @@ def __init__(
5051
f"or in any existing client instances"
5152
)
5253
self.model = model
53-
self.store_tokens = bool(store_tokens)
54+
self.include_tokens = bool(include_tokens)
5455
self._api_url = "https://embed.trychroma.com/embed_sparse"
5556
self._session = httpx.Client()
5657
self._session.headers.update(
@@ -89,7 +90,7 @@ def __call__(self, input: Documents) -> SparseVectors:
8990
"texts": list(input),
9091
"task": "",
9192
"target": "",
92-
"fetch_tokens": "true" if self.store_tokens is True else "false",
93+
"fetch_tokens": "true" if self.include_tokens is True else "false",
9394
}
9495

9596
try:
@@ -123,14 +124,14 @@ def _parse_response(self, response: Any) -> SparseVectors:
123124
if isinstance(emb, dict):
124125
indices = emb.get("indices", [])
125126
values = emb.get("values", [])
126-
raw_labels = emb.get("labels") if self.store_tokens else None
127+
raw_labels = emb.get("labels") if self.include_tokens else None
127128
labels: Optional[List[str]] = raw_labels if raw_labels else None
128129
else:
129130
# Already a SparseVector, extract its data
130-
assert(isinstance(emb, SparseVector))
131+
assert isinstance(emb, SparseVector)
131132
indices = emb.indices
132133
values = emb.values
133-
labels = emb.labels if self.store_tokens else None
134+
labels = emb.labels if self.include_tokens else None
134135

135136
normalized_vectors.append(
136137
normalize_sparse_vector(indices=indices, values=values, labels=labels)
@@ -155,23 +156,25 @@ def build_from_config(
155156
return ChromaCloudSpladeEmbeddingFunction(
156157
api_key_env_var=api_key_env_var,
157158
model=ChromaCloudSpladeEmbeddingModel(model),
158-
store_tokens=config.get("store_tokens", False),
159+
include_tokens=config.get("include_tokens", False),
159160
)
160161

161162
def get_config(self) -> Dict[str, Any]:
162163
return {
163164
"api_key_env_var": self.api_key_env_var,
164165
"model": self.model.value,
165-
"store_tokens": self.store_tokens,
166+
"include_tokens": self.include_tokens,
166167
}
167168

168169
def validate_config_update(
169170
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
170171
) -> None:
171-
immutable_keys = {"store_tokens", "model"}
172+
immutable_keys = {"include_tokens", "model"}
172173
for key in immutable_keys:
173174
if key in new_config and new_config[key] != old_config.get(key):
174-
raise ValueError(f"Updating '{key}' is not supported for chroma-cloud-splade")
175+
raise ValueError(
176+
f"Updating '{key}' is not supported for chroma-cloud-splade"
177+
)
175178

176179
@staticmethod
177180
def validate_config(config: Dict[str, Any]) -> None:

rust/chroma/src/embed/bm25.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ where
3131
H: TokenHasher,
3232
{
3333
/// Whether to store tokens in the created sparse vectors.
34-
pub store_tokens: bool,
34+
pub include_tokens: bool,
3535
/// Tokenizer for converting text into tokens.
3636
pub tokenizer: T,
3737
/// Hasher for converting tokens into u32 identifiers.
@@ -57,7 +57,7 @@ impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
5757
/// - hasher: Murmur3 with seed 0, abs() behavior
5858
pub fn default_murmur3_abs() -> Self {
5959
Self {
60-
store_tokens: true,
60+
include_tokens: true,
6161
tokenizer: Bm25Tokenizer::default(),
6262
hasher: Murmur3AbsHasher::default(),
6363
k: 1.2,
@@ -78,7 +78,7 @@ where
7878

7979
let doc_len = tokens.len() as f32;
8080

81-
if self.store_tokens {
81+
if self.include_tokens {
8282
let mut token_ids = Vec::with_capacity(tokens.len());
8383
for token in tokens {
8484
let id = self.hasher.hash(&token);

0 commit comments

Comments (0)