diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 63818a5bae..d17d876d2e 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -51,7 +51,7 @@ def __init__( self.client = chromadb.CloudClient( api_key=api_key, tenant=tenant, - database="mem0" # Use fixed database name for cloud + database="mem0", # Use fixed database name for cloud ) else: # Initialize local or server client @@ -83,26 +83,37 @@ def _parse_output(self, data: Dict) -> List[OutputData]: Returns: List[OutputData]: Parsed output data. """ - keys = ["ids", "distances", "metadatas"] - values = [] + # Fast-path: Try to reduce the number of isinstance checks and temporary lists + ids = data.get("ids", []) + distances = data.get("distances", []) + metadatas = data.get("metadatas", []) + + # If the first element is itself a list, flatten it (matching original behavior) + if isinstance(ids, list) and ids and isinstance(ids[0], list): + ids = ids[0] + if isinstance(distances, list) and distances and isinstance(distances[0], list): + distances = distances[0] + if isinstance(metadatas, list) and metadatas and isinstance(metadatas[0], list): + metadatas = metadatas[0] + + # Pre-calculate lengths ONCE; avoids repeated len() calls + ids_len = len(ids) if isinstance(ids, list) else 0 + distances_len = len(distances) if isinstance(distances, list) else 0 + metadatas_len = len(metadatas) if isinstance(metadatas, list) else 0 + max_length = max(ids_len, distances_len, metadatas_len) + + # Hoist .append to local for slight efficiency + result: List["OutputData"] = [] + append = result.append - for key in keys: - value = data.get(key, []) - if isinstance(value, list) and value and isinstance(value[0], list): - value = value[0] - values.append(value) - - ids, distances, metadatas = values - max_length = max(len(v) for v in values if isinstance(v, list) and v is not None) - - result = [] for i in range(max_length): - entry = OutputData( - id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None, - score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None), - payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None), + append( + OutputData( + id=ids[i] if i < ids_len else None, + score=distances[i] if i < distances_len else None, + payload=metadatas[i] if i < metadatas_len else None, + ) ) - result.append(entry) return result @@ -247,16 +258,16 @@ def reset(self): def _generate_where_clause(where: dict[str, any]) -> dict[str, any]: """ Generate a properly formatted where clause for ChromaDB. - + Args: where (dict[str, any]): The filter conditions. - + Returns: dict[str, any]: Properly formatted where clause for ChromaDB. """ if where is None: return {} - + def convert_condition(key: str, value: any) -> dict: """Convert universal filter format to ChromaDB format.""" if value == "*": @@ -292,9 +303,9 @@ def convert_condition(key: str, value: any) -> dict: else: # Simple equality return {key: {"$eq": value}} - + processed_filters = [] - + for key, value in where.items(): if key == "$or": # Handle OR conditions @@ -307,22 +318,22 @@ def convert_condition(key: str, value: any) -> dict: or_condition.update(converted) if or_condition: or_conditions.append(or_condition) - + if len(or_conditions) > 1: processed_filters.append({"$or": or_conditions}) elif len(or_conditions) == 1: processed_filters.append(or_conditions[0]) - + elif key == "$not": # Handle NOT conditions - ChromaDB doesn't have direct NOT, so we'll skip for now continue - + else: # Regular condition converted = convert_condition(key, value) if converted: processed_filters.append(converted) - + # Return appropriate format based on number of conditions if len(processed_filters) == 0: return {}