Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions mem0/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.client = chromadb.CloudClient(
api_key=api_key,
tenant=tenant,
database="mem0" # Use fixed database name for cloud
database="mem0", # Use fixed database name for cloud
)
else:
# Initialize local or server client
Expand Down Expand Up @@ -83,26 +83,37 @@ def _parse_output(self, data: Dict) -> List[OutputData]:
Returns:
List[OutputData]: Parsed output data.
"""
keys = ["ids", "distances", "metadatas"]
values = []
# Fast-path: Try to reduce the number of isinstance checks and temporary lists
ids = data.get("ids", [])
distances = data.get("distances", [])
metadatas = data.get("metadatas", [])

# If the first element is itself a list (nested result), unwrap it by taking that first inner list (matching original behavior)
if isinstance(ids, list) and ids and isinstance(ids[0], list):
ids = ids[0]
if isinstance(distances, list) and distances and isinstance(distances[0], list):
distances = distances[0]
if isinstance(metadatas, list) and metadatas and isinstance(metadatas[0], list):
metadatas = metadatas[0]

# Pre-calculate lengths ONCE; avoids repeated len() calls
ids_len = len(ids) if isinstance(ids, list) else 0
distances_len = len(distances) if isinstance(distances, list) else 0
metadatas_len = len(metadatas) if isinstance(metadatas, list) else 0
max_length = max(ids_len, distances_len, metadatas_len)

# Hoist .append to local for slight efficiency
result: List["OutputData"] = []
append = result.append

for key in keys:
value = data.get(key, [])
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
values.append(value)

ids, distances, metadatas = values
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
append(
OutputData(
id=ids[i] if i < ids_len else None,
score=distances[i] if i < distances_len else None,
payload=metadatas[i] if i < metadatas_len else None,
)
)
result.append(entry)

return result

Expand Down Expand Up @@ -247,16 +258,16 @@ def reset(self):
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
"""
Generate a properly formatted where clause for ChromaDB.

Args:
where (dict[str, any]): The filter conditions.

Returns:
dict[str, any]: Properly formatted where clause for ChromaDB.
"""
if where is None:
return {}

def convert_condition(key: str, value: any) -> dict:
"""Convert universal filter format to ChromaDB format."""
if value == "*":
Expand Down Expand Up @@ -292,9 +303,9 @@ def convert_condition(key: str, value: any) -> dict:
else:
# Simple equality
return {key: {"$eq": value}}

processed_filters = []

for key, value in where.items():
if key == "$or":
# Handle OR conditions
Expand All @@ -307,22 +318,22 @@ def convert_condition(key: str, value: any) -> dict:
or_condition.update(converted)
if or_condition:
or_conditions.append(or_condition)

if len(or_conditions) > 1:
processed_filters.append({"$or": or_conditions})
elif len(or_conditions) == 1:
processed_filters.append(or_conditions[0])

elif key == "$not":
# NOT conditions are not supported yet: ChromaDB has no direct NOT operator, so "$not" filters are silently skipped for now
continue

else:
# Regular condition
converted = convert_condition(key, value)
if converted:
processed_filters.append(converted)

# Return appropriate format based on number of conditions
if len(processed_filters) == 0:
return {}
Expand Down