1717from django .utils .functional import cached_property
1818from pymongo import ASCENDING , DESCENDING
1919
20+ from .expressions .search import SearchExpression , SearchVector
2021from .query import MongoQuery , wrap_database_errors
2122from .query_utils import is_direct_value
2223
@@ -35,6 +36,8 @@ def __init__(self, *args, **kwargs):
3536 # A list of OrderBy objects for this query.
3637 self .order_by_objs = None
3738 self .subqueries = []
39+ # Atlas search stage.
40+ self .search_pipeline = []
3841
3942 def _get_group_alias_column (self , expr , annotation_group_idx ):
4043 """Generate a dummy field for use in the ids fields in $group."""
@@ -58,6 +61,29 @@ def _get_column_from_expression(self, expr, alias):
5861 column_target .set_attributes_from_name (alias )
5962 return Col (self .collection_name , column_target )
6063
def _get_replace_expr(self, sub_expr, group, alias):
    """
    Register `sub_expr` in `group` under `alias` and return its replacement.

    `group` is mutated in place: `group[alias]` receives the MQL generated
    for `sub_expr`. The returned expression refers to a column named `alias`
    (a clone of `sub_expr.output_field`) so that later pipeline stages can
    read the precomputed value instead of re-evaluating `sub_expr`.
    """
    target_field = sub_expr.output_field.clone()
    target_field.db_column = alias
    target_field.set_attributes_from_name(alias)
    alias_column = Col(self.collection_name, target_field)
    # Not every expression type defines `distinct`, hence the getattr().
    if not getattr(sub_expr, "distinct", False):
        group[alias] = sub_expr.as_mql(self, self.connection)
        replacement = alias_column
    else:
        # Distinct values are requested: collect them with $addToSet to
        # deduplicate, then re-apply a copy of the expression over the
        # collected column.
        inner_mql = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
        group[alias] = {"$addToSet": inner_mql}
        replacement = sub_expr.copy()
        replacement.set_source_expressions([alias_column, None])
    if isinstance(sub_expr, Count):
        # Count must return 0 rather than null.
        replacement = Coalesce(replacement, 0)
    if isinstance(sub_expr, Variance):
        # Variance = StdDev^2
        replacement = Power(replacement, 2)
    return replacement
86+
6187 def _prepare_expressions_for_pipeline (self , expression , target , annotation_group_idx ):
6288 """
6389 Prepare expressions for the aggregation pipeline.
@@ -81,29 +107,51 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
81107 alias = (
82108 f"__aggregation{ next (annotation_group_idx )} " if sub_expr != expression else target
83109 )
84- column_target = sub_expr .output_field .clone ()
85- column_target .db_column = alias
86- column_target .set_attributes_from_name (alias )
87- inner_column = Col (self .collection_name , column_target )
88- if sub_expr .distinct :
89- # If the expression should return distinct values, use
90- # $addToSet to deduplicate.
91- rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
92- group [alias ] = {"$addToSet" : rhs }
93- replacing_expr = sub_expr .copy ()
94- replacing_expr .set_source_expressions ([inner_column , None ])
95- else :
96- group [alias ] = sub_expr .as_mql (self , self .connection )
97- replacing_expr = inner_column
98- # Count must return 0 rather than null.
99- if isinstance (sub_expr , Count ):
100- replacing_expr = Coalesce (replacing_expr , 0 )
101- # Variance = StdDev^2
102- if isinstance (sub_expr , Variance ):
103- replacing_expr = Power (replacing_expr , 2 )
104- replacements [sub_expr ] = replacing_expr
110+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , group , alias )
105111 return replacements , group
106112
113+ def _prepare_search_expressions_for_pipeline (self , expression , search_idx , replacements ):
114+ """
115+ Collect and prepare unique search expressions for inclusion in an
116+ aggregation pipeline.
117+
118+ Iterate over all search sub-expressions of the given expression.
119+ Assigning a unique alias to each and map them to their replacement
120+ expressions.
121+ """
122+ searches = {}
123+ for sub_expr in self ._get_search_expressions (expression ):
124+ if sub_expr not in replacements :
125+ alias = f"__search_expr.search{ next (search_idx )} "
126+ replacements [sub_expr ] = self ._get_replace_expr (sub_expr , searches , alias )
127+
128+ def _prepare_search_query_for_aggregation_pipeline (self , order_by ):
129+ """
130+ Prepare expressions for the search pipeline.
131+
132+ Handle the computation of search functions used by various expressions.
133+ Separate and create intermediate columns, and replace nodes to simulate
134+ a search operation.
135+
136+ To apply operations over the $search or $searchVector stages, compute
137+ the $search or $vectorSearch first, then apply additional operations in
138+ a subsequent stage by replacing the aggregate expressions with a new
139+ document field prefixed by `__search_expr.search#`.
140+ """
141+ replacements = {}
142+ annotation_group_idx = itertools .count (start = 1 )
143+ for expr in self .query .annotation_select .values ():
144+ self ._prepare_search_expressions_for_pipeline (expr , annotation_group_idx , replacements )
145+ for expr , _ in order_by :
146+ self ._prepare_search_expressions_for_pipeline (expr , annotation_group_idx , replacements )
147+ self ._prepare_search_expressions_for_pipeline (
148+ self .having , annotation_group_idx , replacements
149+ )
150+ self ._prepare_search_expressions_for_pipeline (
151+ self .get_where (), annotation_group_idx , replacements
152+ )
153+ return replacements
154+
107155 def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
108156 """Prepare annotations for the aggregation pipeline."""
109157 replacements = {}
@@ -208,9 +256,67 @@ def _build_aggregation_pipeline(self, ids, group):
208256 pipeline .append ({"$unset" : "_id" })
209257 return pipeline
210258
259+ def _compound_searches_queries (self , search_replacements ):
260+ """
261+ Build a query pipeline from a mapping of search expressions to result
262+ columns.
263+
264+ Currently only a single $search or $vectorSearch expression is
265+ supported. Combining multiple search expressions raises ValueError.
266+
267+ This method will eventually support hybrid search by allowing the
268+ combination of $search and $vectorSearch operations.
269+ """
270+ if not search_replacements :
271+ return []
272+ if len (search_replacements ) > 1 :
273+ has_search = any (not isinstance (search , SearchVector ) for search in search_replacements )
274+ has_vector_search = any (
275+ isinstance (search , SearchVector ) for search in search_replacements
276+ )
277+ if has_search and has_vector_search :
278+ raise ValueError (
279+ "Cannot combine a `$vectorSearch` with a `$search` operator. "
280+ "If you need to combine them, consider restructuring your query logic or "
281+ "running them as separate queries."
282+ )
283+ if has_vector_search :
284+ raise ValueError (
285+ "Cannot combine two `$vectorSearch` operator. "
286+ "If you need to combine them, consider restructuring your query logic or "
287+ "running them as separate queries."
288+ )
289+ raise ValueError (
290+ "Only one $search operation is allowed per query. "
291+ f"Received { len (search_replacements )} search expressions. "
292+ "To combine multiple search expressions, use either a CompoundExpression for "
293+ "fine-grained control or CombinedSearchExpression for simple logical combinations."
294+ )
295+ pipeline = []
296+ for search , result_col in search_replacements .items ():
297+ score_function = (
298+ "vectorSearchScore" if isinstance (search , SearchVector ) else "searchScore"
299+ )
300+ pipeline .extend (
301+ [
302+ search .as_mql (self , self .connection ),
303+ {
304+ "$addFields" : {
305+ result_col .as_mql (self , self .connection , as_path = True ): {
306+ "$meta" : score_function
307+ }
308+ }
309+ },
310+ ]
311+ )
312+ return pipeline
313+
211314 def pre_sql_setup (self , with_col_aliases = False ):
212315 extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
213- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
316+ search_replacements = self ._prepare_search_query_for_aggregation_pipeline (order_by )
317+ group , group_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
318+ all_replacements = {** search_replacements , ** group_replacements }
319+ self .search_pipeline = self ._compound_searches_queries (search_replacements )
214320 # query.group_by is either:
215321 # - None: no GROUP BY
216322 # - True: group by select fields
@@ -235,6 +341,8 @@ def pre_sql_setup(self, with_col_aliases=False):
235341 for target , expr in self .query .annotation_select .items ()
236342 }
237343 self .order_by_objs = [expr .replace_expressions (all_replacements ) for expr , _ in order_by ]
344+ if (where := self .get_where ()) and search_replacements :
345+ self .set_where (where .replace_expressions (search_replacements ))
238346 return extra_select , order_by , group_by
239347
240348 def execute_sql (
@@ -573,10 +681,16 @@ def get_lookup_pipeline(self):
573681 return result
574682
def _get_aggregate_expressions(self, expr):
    """Yield the Aggregate expressions found in `expr`'s expression tree."""
    return self._get_all_expressions_of_type(expr, target_type=Aggregate)
685+
def _get_search_expressions(self, expr):
    """Yield the SearchExpression nodes found in `expr`'s expression tree."""
    return self._get_all_expressions_of_type(expr, target_type=SearchExpression)
688+
689+ def _get_all_expressions_of_type (self , expr , target_type ):
576690 stack = [expr ]
577691 while stack :
578692 expr = stack .pop ()
579- if isinstance (expr , Aggregate ):
693+ if isinstance (expr , target_type ):
580694 yield expr
581695 elif hasattr (expr , "get_source_expressions" ):
582696 stack .extend (expr .get_source_expressions ())
@@ -645,6 +759,9 @@ def _get_ordering(self):
def get_where(self):
    """
    Return the compiler's own `where` attribute if one has been set,
    otherwise the `where` of the underlying query.
    """
    query_where = self.query.where
    return getattr(self, "where", query_where)
647761
def set_where(self, value):
    """Store `value` as this compiler's `where` attribute."""
    setattr(self, "where", value)
764+
648765 def explain_query (self ):
649766 # Validate format (none supported) and options.
650767 options = self .connection .ops .explain_query_prefix (
0 commit comments