Skip to content

Commit 45d4bb9

Browse files
committed
Support sparse vector metadata value type
1 parent 700e575 commit 45d4bb9

File tree

3 files changed

+101
-11
lines changed

3 files changed

+101
-11
lines changed

chromadb/api/types.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Where,
2828
WhereDocumentOperator,
2929
WhereDocument,
30+
SparseVector,
3031
)
3132
from inspect import signature
3233
from tenacity import retry
@@ -45,6 +46,9 @@
4546
"UpdateMetadata",
4647
"SearchRecord",
4748
"SearchResult",
49+
"SparseVector",
50+
"is_valid_sparse_vector",
51+
"validate_sparse_vector",
4852
]
4953
META_KEY_CHROMA_DOCUMENT = "chroma:document"
5054
T = TypeVar("T")
@@ -744,8 +748,69 @@ def validate_ids(ids: IDs) -> IDs:
744748
return ids
745749

746750

751+
def is_valid_sparse_vector(value: Any) -> bool:
752+
"""Check if a value looks like a SparseVector (has indices and values keys)."""
753+
return isinstance(value, dict) and "indices" in value and "values" in value
754+
755+
756+
def validate_sparse_vector(value: Any) -> None:
757+
"""Validate that a value is a properly formed SparseVector.
758+
759+
Args:
760+
value: The value to validate as a SparseVector
761+
762+
Raises:
763+
ValueError: If the value is not a valid SparseVector
764+
"""
765+
if not isinstance(value, dict):
766+
raise ValueError(f"Expected SparseVector to be a dict, got {type(value).__name__}")
767+
768+
if "indices" not in value or "values" not in value:
769+
raise ValueError("SparseVector must have 'indices' and 'values' keys")
770+
771+
indices = value.get("indices")
772+
values = value.get("values")
773+
774+
# Validate indices
775+
if not isinstance(indices, list):
776+
raise ValueError(
777+
f"Expected SparseVector indices to be a list, got {type(indices).__name__}"
778+
)
779+
780+
# Validate values
781+
if not isinstance(values, list):
782+
raise ValueError(
783+
f"Expected SparseVector values to be a list, got {type(values).__name__}"
784+
)
785+
786+
# Check lengths match
787+
if len(indices) != len(values):
788+
raise ValueError(
789+
f"SparseVector indices and values must have the same length, "
790+
f"got {len(indices)} indices and {len(values)} values"
791+
)
792+
793+
# Validate each index
794+
for i, idx in enumerate(indices):
795+
if not isinstance(idx, int):
796+
raise ValueError(
797+
f"SparseVector indices must be integers, got {type(idx).__name__} at position {i}"
798+
)
799+
if idx < 0:
800+
raise ValueError(
801+
f"SparseVector indices must be non-negative, got {idx} at position {i}"
802+
)
803+
804+
# Validate each value
805+
for i, val in enumerate(values):
806+
if not isinstance(val, (int, float)):
807+
raise ValueError(
808+
f"SparseVector values must be numbers, got {type(val).__name__} at position {i}"
809+
)
810+
811+
747812
def validate_metadata(metadata: Metadata) -> Metadata:
748-
"""Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools"""
813+
"""Validates metadata to ensure it is a dictionary of strings to strings, ints, floats, bools, or SparseVectors"""
749814
if not isinstance(metadata, dict) and metadata is not None:
750815
raise ValueError(
751816
f"Expected metadata to be a dict or None, got {type(metadata).__name__} as metadata"
@@ -765,18 +830,24 @@ def validate_metadata(metadata: Metadata) -> Metadata:
765830
raise TypeError(
766831
f"Expected metadata key to be a str, got {key} which is a {type(key).__name__}"
767832
)
833+
# Check if value is a SparseVector
834+
if is_valid_sparse_vector(value):
835+
try:
836+
validate_sparse_vector(value)
837+
except ValueError as e:
838+
raise ValueError(f"Invalid SparseVector for key '{key}': {e}")
768839
# isinstance(True, int) evaluates to True, so we need to check for bools separately
769-
if not isinstance(value, bool) and not isinstance(
840+
elif not isinstance(value, bool) and not isinstance(
770841
value, (str, int, float, type(None))
771842
):
772843
raise ValueError(
773-
f"Expected metadata value to be a str, int, float, bool, or None, got {value} which is a {type(value).__name__}"
844+
f"Expected metadata value to be a str, int, float, bool, SparseVector, or None, got {value} which is a {type(value).__name__}"
774845
)
775846
return metadata
776847

777848

778849
def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
779-
"""Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools"""
850+
"""Validates metadata to ensure it is a dictionary of strings to strings, ints, floats, bools, or SparseVectors"""
780851
if not isinstance(metadata, dict) and metadata is not None:
781852
raise ValueError(
782853
f"Expected metadata to be a dict or None, got {type(metadata)}"
@@ -788,12 +859,18 @@ def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
788859
for key, value in metadata.items():
789860
if not isinstance(key, str):
790861
raise ValueError(f"Expected metadata key to be a str, got {key}")
862+
# Check if value is a SparseVector
863+
if is_valid_sparse_vector(value):
864+
try:
865+
validate_sparse_vector(value)
866+
except ValueError as e:
867+
raise ValueError(f"Invalid SparseVector for key '{key}': {e}")
791868
# isinstance(True, int) evaluates to True, so we need to check for bools separately
792-
if not isinstance(value, bool) and not isinstance(
869+
elif not isinstance(value, bool) and not isinstance(
793870
value, (str, int, float, type(None))
794871
):
795872
raise ValueError(
796-
f"Expected metadata value to be a str, int, or float, got {value}"
873+
f"Expected metadata value to be a str, int, float, bool, SparseVector, or None, got {value}"
797874
)
798875
return metadata
799876

chromadb/base_types.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
from typing import Dict, List, Mapping, Optional, Sequence, Union
2-
from typing_extensions import Literal
2+
from typing_extensions import Literal, TypedDict
33
import numpy as np
44
from numpy.typing import NDArray
55

6-
Metadata = Mapping[str, Optional[Union[str, int, float, bool]]]
7-
UpdateMetadata = Mapping[str, Union[int, float, str, bool, None]]
6+
7+
class SparseVector(TypedDict):
8+
"""Represents a sparse vector using parallel arrays for indices and values.
9+
10+
Attributes:
11+
indices: List of dimension indices (must be non-negative integers)
12+
values: List of values corresponding to each index
13+
"""
14+
indices: List[int]
15+
values: List[float]
16+
17+
18+
Metadata = Mapping[str, Optional[Union[str, int, float, bool, SparseVector]]]
19+
UpdateMetadata = Mapping[str, Union[int, float, str, bool, SparseVector, None]]
820
PyVector = Union[Sequence[float], Sequence[int]]
921
Vector = NDArray[Union[np.int32, np.float32]] # TODO: Specify that the vector is 1D
1022
# Metadata Query Grammar

chromadb/execution/expression/operator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, List, Dict, Set, Any
33

4-
from chromadb.api.types import Embeddings, IDs, Include
4+
from chromadb.api.types import Embeddings, IDs, Include, SparseVector
55
from chromadb.types import (
66
Collection,
77
RequestVersionContext,
@@ -232,11 +232,12 @@ def to_dict(self) -> Dict[str, Any]:
232232
@dataclass
233233
class SparseKnn(Rank):
234234
"""Sparse KNN ranking"""
235-
embedding: Dict[int, float] # Sparse vector: index -> value
235+
embedding: SparseVector # Sparse vector with indices and values
236236
key: str # No default for sparse KNN
237237
limit: int = 1024
238238

239239
def to_dict(self) -> Dict[str, Any]:
240+
# Convert SparseVector to the format expected by Rust API
240241
result = {"embedding": self.embedding, "key": self.key}
241242
if self.limit != 1024:
242243
result["limit"] = self.limit # type: ignore[assignment]

0 commit comments

Comments
 (0)