66from django .core .exceptions import EmptyResultSet , FullResultSet
77from django .db import NotSupportedError
88from django .db .models .expressions import (
9+ BaseExpression ,
910 Case ,
1011 Col ,
1112 ColPairs ,
1213 CombinedExpression ,
1314 Exists ,
1415 ExpressionList ,
1516 ExpressionWrapper ,
17+ Func ,
1618 NegatedExpression ,
1719 OrderBy ,
1820 RawSQL ,
2325 Value ,
2426 When ,
2527)
28+ from django .db .models .fields .json import KeyTransform
2629from django .db .models .sql import Query
2730
28- from django_mongodb_backend .query_utils import process_lhs
31+ from django_mongodb_backend .fields .array import Array
32+ from django_mongodb_backend .query_utils import is_direct_value , process_lhs
2933
3034
31- def case (self , compiler , connection ):
35+ def case (self , compiler , connection , as_path = False ):
3236 case_parts = []
3337 for case in self .cases :
3438 case_mql = {}
3539 try :
36- case_mql ["case" ] = case .as_mql (compiler , connection )
40+ case_mql ["case" ] = case .as_mql (compiler , connection , as_path = False )
3741 except EmptyResultSet :
3842 continue
3943 except FullResultSet :
@@ -45,12 +49,16 @@ def case(self, compiler, connection):
4549 default_mql = self .default .as_mql (compiler , connection )
4650 if not case_parts :
4751 return default_mql
48- return {
52+ expr = {
4953 "$switch" : {
5054 "branches" : case_parts ,
5155 "default" : default_mql ,
5256 }
5357 }
58+ if as_path :
59+ return {"$expr" : expr }
60+
61+ return expr
5462
5563
5664def col (self , compiler , connection , as_path = False ): # noqa: ARG001
@@ -76,34 +84,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7684 return f"{ prefix } { self .target .column } "
7785
7886
79- def col_pairs (self , compiler , connection ):
87+ def col_pairs (self , compiler , connection , as_path = False ):
8088 cols = self .get_cols ()
8189 if len (cols ) > 1 :
8290 raise NotSupportedError ("ColPairs is not supported." )
83- return cols [0 ].as_mql (compiler , connection )
91+ return cols [0 ].as_mql (compiler , connection , as_path = as_path )
8492
8593
86- def combined_expression (self , compiler , connection ):
94+ def combined_expression (self , compiler , connection , as_path = False ):
8795 expressions = [
88- self .lhs .as_mql (compiler , connection ),
89- self .rhs .as_mql (compiler , connection ),
96+ self .lhs .as_mql (compiler , connection , as_path = as_path ),
97+ self .rhs .as_mql (compiler , connection , as_path = as_path ),
9098 ]
9199 return connection .ops .combine_expression (self .connector , expressions )
92100
93101
94- def expression_wrapper (self , compiler , connection ):
95- return self .expression .as_mql (compiler , connection )
102+ def expression_wrapper (self , compiler , connection , as_path = False ):
103+ return self .expression .as_mql (compiler , connection , as_path = as_path )
96104
97105
98- def negated_expression (self , compiler , connection ):
99- return {"$not" : expression_wrapper (self , compiler , connection )}
106+ def negated_expression (self , compiler , connection , as_path = False ):
107+ return {"$not" : expression_wrapper (self , compiler , connection , as_path = as_path )}
100108
101109
102110def order_by (self , compiler , connection ):
103111 return self .expression .as_mql (compiler , connection )
104112
105113
106- def query (self , compiler , connection , get_wrapping_pipeline = None ):
114+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
107115 subquery_compiler = self .get_compiler (connection = connection )
108116 subquery_compiler .pre_sql_setup (with_col_aliases = False )
109117 field_name , expr = subquery_compiler .columns [0 ]
@@ -145,14 +153,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145153 # Erase project_fields since the required value is projected above.
146154 subquery .project_fields = None
147155 compiler .subqueries .append (subquery )
156+ if as_path :
157+ return f"{ table_output } .{ field_name } "
148158 return f"${ table_output } .{ field_name } "
149159
150160
151161def raw_sql (self , compiler , connection ): # noqa: ARG001
152162 raise NotSupportedError ("RawSQL is not supported on MongoDB." )
153163
154164
155- def ref (self , compiler , connection ): # noqa: ARG001
165+ def ref (self , compiler , connection , as_path = False ): # noqa: ARG001
156166 prefix = (
157167 f"{ self .source .alias } ."
158168 if isinstance (self .source , Col ) and self .source .alias != compiler .collection_name
@@ -162,32 +172,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162172 refs , _ = compiler .columns [self .ordinal - 1 ]
163173 else :
164174 refs = self .refs
165- return f"${ prefix } { refs } "
175+ if not as_path :
176+ prefix = f"${ prefix } "
177+ return f"{ prefix } { refs } "
166178
167179
168- def star (self , compiler , connection ): # noqa: ARG001
180+ def star (self , compiler , connection , ** extra ): # noqa: ARG001
169181 return {"$literal" : True }
170182
171183
172- def subquery (self , compiler , connection , get_wrapping_pipeline = None ):
173- return self .query .as_mql (compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
184+ def subquery (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
185+ expr = self .query .as_mql (
186+ compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_path = False
187+ )
188+ if as_path :
189+ return {"$expr" : expr }
190+ return expr
174191
175192
176- def exists (self , compiler , connection , get_wrapping_pipeline = None ):
193+ def exists (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
177194 try :
178- lhs_mql = subquery (self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
195+ lhs_mql = subquery (
196+ self ,
197+ compiler ,
198+ connection ,
199+ get_wrapping_pipeline = get_wrapping_pipeline ,
200+ as_path = as_path ,
201+ )
179202 except EmptyResultSet :
180203 return Value (False ).as_mql (compiler , connection )
181- return connection .mongo_operators ["isnull" ](lhs_mql , False )
204+ if as_path :
205+ return {"$expr" : connection .mongo_operators_match ["isnull" ](lhs_mql , False )}
206+ return connection .mongo_operators_expr ["isnull" ](lhs_mql , False )
182207
183208
184- def when (self , compiler , connection ):
185- return self .condition .as_mql (compiler , connection )
209+ def when (self , compiler , connection , as_path = False ):
210+ return self .condition .as_mql (compiler , connection , as_path = as_path )
186211
187212
188- def value (self , compiler , connection ): # noqa: ARG001
213+ def value (self , compiler , connection , as_path = False ): # noqa: ARG001
189214 value = self .value
190- if isinstance (value , (list , int )):
215+ if isinstance (value , (list , int )) and not as_path :
191216 # Wrap lists & numbers in $literal to prevent ambiguity when Value
192217 # appears in $project.
193218 return {"$literal" : value }
@@ -209,6 +234,36 @@ def value(self, compiler, connection): # noqa: ARG001
209234 return value
210235
211236
237+ @staticmethod
238+ def _is_constant_value (value ):
239+ if isinstance (value , list | Array ):
240+ iterable = value .get_source_expressions () if isinstance (value , Array ) else value
241+ return all (_is_constant_value (e ) for e in iterable )
242+ if is_direct_value (value ):
243+ return True
244+ return isinstance (value , Func | Value ) and not (
245+ value .contains_aggregate
246+ or value .contains_over_clause
247+ or value .contains_column_references
248+ or value .contains_subquery
249+ )
250+
251+
252+ @staticmethod
253+ def _is_simple_column (lhs ):
254+ while isinstance (lhs , KeyTransform ):
255+ if "." in getattr (lhs , "key_name" , "" ):
256+ return False
257+ lhs = lhs .lhs
258+ col = lhs .source if isinstance (lhs , Ref ) else lhs
259+ # Foreign columns from parent cannot be addressed as single match
260+ return isinstance (col , Col ) and col .alias is not None
261+
262+
263+ def _is_simple_expression (self ):
264+ return self .is_simple_column (self .lhs ) and self .is_constant_value (self .rhs )
265+
266+
212267def register_expressions ():
213268 Case .as_mql = case
214269 Col .as_mql = col
@@ -227,3 +282,6 @@ def register_expressions():
227282 Subquery .as_mql = subquery
228283 When .as_mql = when
229284 Value .as_mql = value
285+ BaseExpression .is_simple_expression = _is_simple_expression
286+ BaseExpression .is_simple_column = _is_simple_column
287+ BaseExpression .is_constant_value = _is_constant_value
0 commit comments