Skip to content

Commit 409a3b0

Browse files
committed
[ENH]: Allow specifying multiple filter keys in get_statistics
1 parent 41e5172 commit 409a3b0

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

chromadb/test/distributed/test_statistics_wrapper.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import time
77
from typing import Any
88

9+
import pytest
10+
911
from chromadb.api.client import Client as ClientCreator
1012
from chromadb.base_types import SparseVector
1113
from chromadb.config import System
@@ -234,7 +236,7 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
234236

235237
# Get statistics filtered by "category" key only
236238
category_stats = get_statistics(
237-
collection, "key_filter_test_statistics", key="category"
239+
collection, "key_filter_test_statistics", keys=["category"]
238240
)
239241
assert "category" in category_stats["statistics"]
240242
assert "score" not in category_stats["statistics"]
@@ -246,7 +248,9 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
246248
assert category_stats["summary"]["total_count"] == 3
247249

248250
# Get statistics filtered by "score" key only
249-
score_stats = get_statistics(collection, "key_filter_test_statistics", key="score")
251+
score_stats = get_statistics(
252+
collection, "key_filter_test_statistics", keys=["score"]
253+
)
250254
assert "score" in score_stats["statistics"]
251255
assert "category" not in score_stats["statistics"]
252256
assert "active" not in score_stats["statistics"]
@@ -260,6 +264,30 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
260264
detach_statistics_function(collection, delete_stats_collection=True)
261265

262266

267+
def test_statistics_wrapper_key_filter_too_many_keys(basic_http_client: System) -> None:
268+
"""Test that get_statistics raises ValueError when more than 30 keys are provided"""
269+
client = ClientCreator.from_system(basic_http_client)
270+
client.reset()
271+
272+
collection = client.create_collection(name="too_many_keys_test")
273+
274+
# Enable statistics
275+
attach_statistics_function(collection, "too_many_keys_test_statistics")
276+
277+
# Generate more than 30 keys
278+
too_many_keys = [f"key_{i}" for i in range(31)]
279+
280+
# Should raise ValueError when more than 30 keys are provided
281+
with pytest.raises(ValueError) as exc_info:
282+
get_statistics(collection, "too_many_keys_test_statistics", keys=too_many_keys)
283+
284+
assert "Too many keys provided: 31" in str(exc_info.value)
285+
assert "Maximum allowed is 30" in str(exc_info.value)
286+
287+
# Cleanup
288+
detach_statistics_function(collection, delete_stats_collection=True)
289+
290+
263291
# commenting out for now as waiting for query cache invalidation slows down the test suite
264292
def test_statistics_wrapper_incremental_updates(basic_http_client: System) -> None:
265293
"""Test that statistics are updated incrementally"""

chromadb/utils/statistics.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from typing import TYPE_CHECKING, Optional, Dict, Any, cast
3030
from collections import defaultdict
3131

32-
from chromadb.api.types import Where
32+
from chromadb.api.types import OneOrMany, Where, maybe_cast_one_to_many
3333

3434
if TYPE_CHECKING:
3535
from chromadb.api.models.Collection import Collection
@@ -121,7 +121,9 @@ def detach_statistics_function(
121121

122122

123123
def get_statistics(
124-
collection: "Collection", stats_collection_name: str, key: Optional[str] = None
124+
collection: "Collection",
125+
stats_collection_name: str,
126+
keys: Optional[OneOrMany[str]] = None,
125127
) -> Dict[str, Any]:
126128
"""Get the current statistics for a collection.
127129
@@ -131,8 +133,9 @@ def get_statistics(
131133
Args:
132134
collection: The collection to get statistics for
133135
stats_collection_name: Name of the statistics collection to read from.
134-
key: Optional metadata key to filter statistics for. If provided,
135-
only returns statistics for that specific key.
136+
keys: Optional metadata key(s) to filter statistics for. Can be a single key
137+
string or a list of keys. If provided, only returns statistics for
138+
those specific keys.
136139
137140
Returns:
138141
Dict[str, Any]: A dictionary with the structure:
@@ -174,7 +177,22 @@ def get_statistics(
174177
"total_count": 2
175178
}
176179
}
180+
181+
Raises:
182+
ValueError: If more than 30 keys are provided in the keys filter.
177183
"""
184+
# Normalize keys to list
185+
keys_list = maybe_cast_one_to_many(keys)
186+
187+
# Validate keys count to avoid issues with large $in queries
188+
MAX_KEYS = 30
189+
if keys_list is not None and len(keys_list) > MAX_KEYS:
190+
raise ValueError(
191+
f"Too many keys provided: {len(keys_list)}. "
192+
f"Maximum allowed is {MAX_KEYS} keys per request. "
193+
"Consider calling get_statistics multiple times with smaller key batches."
194+
)
195+
178196
# Import here to avoid circular dependency
179197
from chromadb.api.models.Collection import Collection
180198

@@ -198,10 +216,10 @@ def get_statistics(
198216
summary: Dict[str, Any] = {}
199217

200218
offset = 0
201-
# When filtering by key, also include "summary" entries to get total_count
219+
# When filtering by keys, also include "summary" entries to get total_count
202220
where_filter: Optional[Where] = (
203-
cast(Where, {"$or": [{"key": key}, {"key": "summary"}]})
204-
if key is not None
221+
cast(Where, {"key": {"$in": keys_list + ["summary"]}})
222+
if keys_list is not None
205223
else None
206224
)
207225

0 commit comments

Comments
 (0)