27
27
Where ,
28
28
WhereDocumentOperator ,
29
29
WhereDocument ,
30
+ SparseVector ,
30
31
)
31
32
from inspect import signature
32
33
from tenacity import retry
45
46
"UpdateMetadata" ,
46
47
"SearchRecord" ,
47
48
"SearchResult" ,
49
+ "SparseVector" ,
50
+ "is_valid_sparse_vector" ,
51
+ "validate_sparse_vector" ,
48
52
]
49
53
META_KEY_CHROMA_DOCUMENT = "chroma:document"
50
54
T = TypeVar ("T" )
@@ -744,8 +748,69 @@ def validate_ids(ids: IDs) -> IDs:
744
748
return ids
745
749
746
750
751
+ def is_valid_sparse_vector (value : Any ) -> bool :
752
+ """Check if a value looks like a SparseVector (has indices and values keys)."""
753
+ return isinstance (value , dict ) and "indices" in value and "values" in value
754
+
755
+
756
+ def validate_sparse_vector (value : Any ) -> None :
757
+ """Validate that a value is a properly formed SparseVector.
758
+
759
+ Args:
760
+ value: The value to validate as a SparseVector
761
+
762
+ Raises:
763
+ ValueError: If the value is not a valid SparseVector
764
+ """
765
+ if not isinstance (value , dict ):
766
+ raise ValueError (f"Expected SparseVector to be a dict, got { type (value ).__name__ } " )
767
+
768
+ if "indices" not in value or "values" not in value :
769
+ raise ValueError ("SparseVector must have 'indices' and 'values' keys" )
770
+
771
+ indices = value .get ("indices" )
772
+ values = value .get ("values" )
773
+
774
+ # Validate indices
775
+ if not isinstance (indices , list ):
776
+ raise ValueError (
777
+ f"Expected SparseVector indices to be a list, got { type (indices ).__name__ } "
778
+ )
779
+
780
+ # Validate values
781
+ if not isinstance (values , list ):
782
+ raise ValueError (
783
+ f"Expected SparseVector values to be a list, got { type (values ).__name__ } "
784
+ )
785
+
786
+ # Check lengths match
787
+ if len (indices ) != len (values ):
788
+ raise ValueError (
789
+ f"SparseVector indices and values must have the same length, "
790
+ f"got { len (indices )} indices and { len (values )} values"
791
+ )
792
+
793
+ # Validate each index
794
+ for i , idx in enumerate (indices ):
795
+ if not isinstance (idx , int ):
796
+ raise ValueError (
797
+ f"SparseVector indices must be integers, got { type (idx ).__name__ } at position { i } "
798
+ )
799
+ if idx < 0 :
800
+ raise ValueError (
801
+ f"SparseVector indices must be non-negative, got { idx } at position { i } "
802
+ )
803
+
804
+ # Validate each value
805
+ for i , val in enumerate (values ):
806
+ if not isinstance (val , (int , float )):
807
+ raise ValueError (
808
+ f"SparseVector values must be numbers, got { type (val ).__name__ } at position { i } "
809
+ )
810
+
811
+
747
812
def validate_metadata (metadata : Metadata ) -> Metadata :
748
- """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools """
813
+ """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats, bools, or SparseVectors """
749
814
if not isinstance (metadata , dict ) and metadata is not None :
750
815
raise ValueError (
751
816
f"Expected metadata to be a dict or None, got { type (metadata ).__name__ } as metadata"
@@ -765,18 +830,24 @@ def validate_metadata(metadata: Metadata) -> Metadata:
765
830
raise TypeError (
766
831
f"Expected metadata key to be a str, got { key } which is a { type (key ).__name__ } "
767
832
)
833
+ # Check if value is a SparseVector
834
+ if is_valid_sparse_vector (value ):
835
+ try :
836
+ validate_sparse_vector (value )
837
+ except ValueError as e :
838
+ raise ValueError (f"Invalid SparseVector for key '{ key } ': { e } " )
768
839
# isinstance(True, int) evaluates to True, so we need to check for bools separately
769
- if not isinstance (value , bool ) and not isinstance (
840
+ elif not isinstance (value , bool ) and not isinstance (
770
841
value , (str , int , float , type (None ))
771
842
):
772
843
raise ValueError (
773
- f"Expected metadata value to be a str, int, float, bool, or None, got { value } which is a { type (value ).__name__ } "
844
+ f"Expected metadata value to be a str, int, float, bool, SparseVector, or None, got { value } which is a { type (value ).__name__ } "
774
845
)
775
846
return metadata
776
847
777
848
778
849
def validate_update_metadata (metadata : UpdateMetadata ) -> UpdateMetadata :
779
- """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools """
850
+ """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats, bools, or SparseVectors """
780
851
if not isinstance (metadata , dict ) and metadata is not None :
781
852
raise ValueError (
782
853
f"Expected metadata to be a dict or None, got { type (metadata )} "
@@ -788,12 +859,18 @@ def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata:
788
859
for key , value in metadata .items ():
789
860
if not isinstance (key , str ):
790
861
raise ValueError (f"Expected metadata key to be a str, got { key } " )
862
+ # Check if value is a SparseVector
863
+ if is_valid_sparse_vector (value ):
864
+ try :
865
+ validate_sparse_vector (value )
866
+ except ValueError as e :
867
+ raise ValueError (f"Invalid SparseVector for key '{ key } ': { e } " )
791
868
# isinstance(True, int) evaluates to True, so we need to check for bools separately
792
- if not isinstance (value , bool ) and not isinstance (
869
+ elif not isinstance (value , bool ) and not isinstance (
793
870
value , (str , int , float , type (None ))
794
871
):
795
872
raise ValueError (
796
- f"Expected metadata value to be a str, int, or float , got { value } "
873
+ f"Expected metadata value to be a str, int, float, bool, SparseVector, or None , got { value } "
797
874
)
798
875
return metadata
799
876
0 commit comments