1818 Union ,
1919)
2020
21+ from redisvl .redis .utils import convert_bytes , make_dict
2122from redisvl .utils .utils import deprecated_argument , deprecated_function , sync_wrapper
2223
2324if TYPE_CHECKING :
3940 SchemaValidationError ,
4041)
4142from redisvl .index .storage import BaseStorage , HashStorage , JsonStorage
42- from redisvl .query import BaseQuery , BaseVectorQuery , CountQuery , FilterQuery
43+ from redisvl .query import (
44+ AggregationQuery ,
45+ BaseQuery ,
46+ BaseVectorQuery ,
47+ CountQuery ,
48+ FilterQuery ,
49+ HybridQuery ,
50+ )
4351from redisvl .query .filter import FilterExpression
4452from redisvl .redis .connection import (
4553 RedisConnectionFactory ,
@@ -138,6 +146,34 @@ def _process(doc: "Document") -> Dict[str, Any]:
138146 return [_process (doc ) for doc in results .docs ]
139147
140148
149+ def process_aggregate_results (
150+ results : "AggregateResult" , query : AggregationQuery , storage_type : StorageType
151+ ) -> List [Dict [str , Any ]]:
152+ """Convert an aggregate reslt object into a list of document dictionaries.
153+
154+ This function processes results from Redis, handling different storage
155+ types and query types. For JSON storage with empty return fields, it
156+ unpacks the JSON object while retaining the document ID. The 'payload'
157+ field is also removed from all resulting documents for consistency.
158+
159+ Args:
160+ results (AggregarteResult): The aggregart results from Redis.
161+ query (AggregationQuery): The aggregation query object used for the aggregation.
162+ storage_type (StorageType): The storage type of the search
163+ index (json or hash).
164+
165+ Returns:
166+ List[Dict[str, Any]]: A list of processed document dictionaries.
167+ """
168+
169+ def _process (row ):
170+ result = make_dict (convert_bytes (row ))
171+ result .pop ("__score" , None )
172+ return result
173+
174+ return [_process (r ) for r in results .rows ]
175+
176+
141177class BaseSearchIndex :
142178 """Base search engine class"""
143179
@@ -650,6 +686,17 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
650686 return convert_bytes (obj [0 ])
651687 return None
652688
689+ def _aggregate (self , aggregation_query : AggregationQuery ) -> List [Dict [str , Any ]]:
690+ """Execute an aggretation query and processes the results."""
691+ results = self .aggregate (
692+ aggregation_query , query_params = aggregation_query .params # type: ignore[attr-defined]
693+ )
694+ return process_aggregate_results (
695+ results ,
696+ query = aggregation_query ,
697+ storage_type = self .schema .index .storage_type ,
698+ )
699+
653700 def aggregate (self , * args , ** kwargs ) -> "AggregateResult" :
654701 """Perform an aggregation operation against the index.
655702
@@ -772,14 +819,14 @@ def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
772819 results = self .search (query .query , query_params = query .params )
773820 return process_results (results , query = query , schema = self .schema )
774821
775- def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
822+ def query (self , query : Union [ BaseQuery , AggregationQuery ] ) -> List [Dict [str , Any ]]:
776823 """Execute a query on the index.
777824
778- This method takes a BaseQuery object directly, runs the search , and
825+ This method takes a BaseQuery or AggregationQuery object directly , and
779826 handles post-processing of the search.
780827
781828 Args:
782- query (BaseQuery): The query to run.
829+ query (Union[ BaseQuery, AggregateQuery] ): The query to run.
783830
784831 Returns:
785832 List[Result]: A list of search results.
@@ -797,7 +844,10 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
797844 results = index.query(query)
798845
799846 """
800- return self ._query (query )
847+ if isinstance (query , AggregationQuery ):
848+ return self ._aggregate (query )
849+ else :
850+ return self ._query (query )
801851
802852 def paginate (self , query : BaseQuery , page_size : int = 30 ) -> Generator :
803853 """Execute a given query against the index and return results in
@@ -1303,6 +1353,19 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
13031353 return convert_bytes (obj [0 ])
13041354 return None
13051355
1356+ async def _aggregate (
1357+ self , aggregation_query : AggregationQuery
1358+ ) -> List [Dict [str , Any ]]:
1359+ """Execute an aggretation query and processes the results."""
1360+ results = await self .aggregate (
1361+ aggregation_query , query_params = aggregation_query .params # type: ignore[attr-defined]
1362+ )
1363+ return process_aggregate_results (
1364+ results ,
1365+ query = aggregation_query ,
1366+ storage_type = self .schema .index .storage_type ,
1367+ )
1368+
13061369 async def aggregate (self , * args , ** kwargs ) -> "AggregateResult" :
13071370 """Perform an aggregation operation against the index.
13081371
@@ -1426,14 +1489,16 @@ async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14261489 results = await self .search (query .query , query_params = query .params )
14271490 return process_results (results , query = query , schema = self .schema )
14281491
1429- async def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
1492+ async def query (
1493+ self , query : Union [BaseQuery , AggregationQuery ]
1494+ ) -> List [Dict [str , Any ]]:
14301495 """Asynchronously execute a query on the index.
14311496
1432- This method takes a BaseQuery object directly, runs the search, and
1433- handles post-processing of the search.
1497+ This method takes a BaseQuery or AggregationQuery object directly, runs
1498+ the search, and handles post-processing of the search.
14341499
14351500 Args:
1436- query (BaseQuery): The query to run.
1501+ query (Union[ BaseQuery, AggregateQuery] ): The query to run.
14371502
14381503 Returns:
14391504 List[Result]: A list of search results.
@@ -1450,7 +1515,10 @@ async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14501515
14511516 results = await index.query(query)
14521517 """
1453- return await self ._query (query )
1518+ if isinstance (query , AggregationQuery ):
1519+ return await self ._aggregate (query )
1520+ else :
1521+ return await self ._query (query )
14541522
14551523 async def paginate (self , query : BaseQuery , page_size : int = 30 ) -> AsyncGenerator :
14561524 """Execute a given query against the index and return results in
0 commit comments