Skip to content

Commit 99059fd

Browse files
committed
[ENH]: Allow specifiying multiple filter keys in get_statistics
1 parent cb626b1 commit 99059fd

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

chromadb/test/distributed/test_statistics_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
234234

235235
# Get statistics filtered by "category" key only
236236
category_stats = get_statistics(
237-
collection, "key_filter_test_statistics", key="category"
237+
collection, "key_filter_test_statistics", keys=["category"]
238238
)
239239
assert "category" in category_stats["statistics"]
240240
assert "score" not in category_stats["statistics"]
@@ -246,7 +246,9 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
246246
assert category_stats["summary"]["total_count"] == 3
247247

248248
# Get statistics filtered by "score" key only
249-
score_stats = get_statistics(collection, "key_filter_test_statistics", key="score")
249+
score_stats = get_statistics(
250+
collection, "key_filter_test_statistics", keys=["score"]
251+
)
250252
assert "score" in score_stats["statistics"]
251253
assert "category" not in score_stats["statistics"]
252254
assert "active" not in score_stats["statistics"]

chromadb/utils/statistics.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
>>> print(stats)
2727
"""
2828

29-
from typing import TYPE_CHECKING, Optional, Dict, Any, cast
29+
from typing import TYPE_CHECKING, Optional, Dict, Any, List, cast
3030
from collections import defaultdict
3131

3232
from chromadb.api.types import Where
@@ -121,7 +121,9 @@ def detach_statistics_function(
121121

122122

123123
def get_statistics(
124-
collection: "Collection", stats_collection_name: str, key: Optional[str] = None
124+
collection: "Collection",
125+
stats_collection_name: str,
126+
keys: Optional[List[str]] = None,
125127
) -> Dict[str, Any]:
126128
"""Get the current statistics for a collection.
127129
@@ -131,8 +133,8 @@ def get_statistics(
131133
Args:
132134
collection: The collection to get statistics for
133135
stats_collection_name: Name of the statistics collection to read from.
134-
key: Optional metadata key to filter statistics for. If provided,
135-
only returns statistics for that specific key.
136+
keys: Optional list of metadata keys to filter statistics for. If provided,
137+
only returns statistics for those specific keys.
136138
137139
Returns:
138140
Dict[str, Any]: A dictionary with the structure:
@@ -174,7 +176,19 @@ def get_statistics(
174176
"total_count": 2
175177
}
176178
}
179+
180+
Raises:
181+
ValueError: If more than 30 keys are provided in the keys filter.
177182
"""
183+
# Validate keys count to avoid issues with large $in queries
184+
MAX_KEYS = 30
185+
if keys is not None and len(keys) > MAX_KEYS:
186+
raise ValueError(
187+
f"Too many keys provided: {len(keys)}. "
188+
f"Maximum allowed is {MAX_KEYS} keys per request. "
189+
"Consider calling get_statistics multiple times with smaller key batches."
190+
)
191+
178192
# Import here to avoid circular dependency
179193
from chromadb.api.models.Collection import Collection
180194

@@ -198,11 +212,9 @@ def get_statistics(
198212
summary: Dict[str, Any] = {}
199213

200214
offset = 0
201-
# When filtering by key, also include "summary" entries to get total_count
215+
# When filtering by keys, also include "summary" entries to get total_count
202216
where_filter: Optional[Where] = (
203-
cast(Where, {"$or": [{"key": key}, {"key": "summary"}]})
204-
if key is not None
205-
else None
217+
cast(Where, {"key": {"$in": keys + ["summary"]}}) if keys is not None else None
206218
)
207219

208220
while True:

0 commit comments

Comments
 (0)