2828from ..query_utils import process_lhs
2929
3030
31- def case (self , compiler , connection ):
31+ # EXTRA IS TOTALLY IGNORED
32+ def case (self , compiler , connection , ** extra ): # noqa: ARG001
3233 case_parts = []
3334 for case in self .cases :
3435 case_mql = {}
@@ -53,7 +54,7 @@ def case(self, compiler, connection):
5354 }
5455
5556
56- def col (self , compiler , connection , as_path = False ): # noqa: ARG001
57+ def col (self , compiler , connection , as_path = False , as_expr = None ): # noqa: ARG001
5758 # If the column is part of a subquery and belongs to one of the parent
5859 # queries, it will be stored for reference using $let in a $lookup stage.
5960 # If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7172 # Add the column's collection's alias for columns in joined collections.
7273 has_alias = self .alias and self .alias != compiler .collection_name
7374 prefix = f"{ self .alias } ." if has_alias else ""
74- if not as_path :
75+ if not as_path or as_expr :
7576 prefix = f"${ prefix } "
7677 return f"{ prefix } { self .target .column } "
7778
@@ -83,16 +84,16 @@ def col_pairs(self, compiler, connection):
8384 return cols [0 ].as_mql (compiler , connection )
8485
8586
86- def combined_expression (self , compiler , connection ):
87+ def combined_expression (self , compiler , connection , ** extra ):
8788 expressions = [
88- self .lhs .as_mql (compiler , connection ),
89- self .rhs .as_mql (compiler , connection ),
89+ self .lhs .as_mql (compiler , connection , ** extra ),
90+ self .rhs .as_mql (compiler , connection , ** extra ),
9091 ]
9192 return connection .ops .combine_expression (self .connector , expressions )
9293
9394
94- def expression_wrapper (self , compiler , connection ):
95- return self .expression .as_mql (compiler , connection )
95+ def expression_wrapper (self , compiler , connection , ** extra ):
96+ return self .expression .as_mql (compiler , connection , ** extra )
9697
9798
9899def negated_expression (self , compiler , connection ):
@@ -103,7 +104,7 @@ def order_by(self, compiler, connection):
103104 return self .expression .as_mql (compiler , connection )
104105
105106
106- def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
107+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False , as_expr = None ):
107108 subquery_compiler = self .get_compiler (connection = connection )
108109 subquery_compiler .pre_sql_setup (with_col_aliases = False )
109110 field_name , expr = subquery_compiler .columns [0 ]
@@ -145,7 +146,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
145146 # Erase project_fields since the required value is projected above.
146147 subquery .project_fields = None
147148 compiler .subqueries .append (subquery )
148- if as_path :
149+ if as_path and not as_expr :
149150 return f"{ table_output } .{ field_name } "
150151 return f"${ table_output } .{ field_name } "
151152
@@ -200,7 +201,7 @@ def when(self, compiler, connection, **extra):
200201 return self .condition .as_mql (compiler , connection , ** extra )
201202
202203
203- def value (self , compiler , connection ): # noqa: ARG001
204+ def value (self , compiler , connection , ** extra ): # noqa: ARG001
204205 value = self .value
205206 if isinstance (value , (list , int )):
206207 # Wrap lists & numbers in $literal to prevent ambiguity when Value
0 commit comments