@@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
129129 raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
130130
131131
132- def join (self , compiler , connection ):
133- lookup_pipeline = []
134- lhs_fields = []
135- rhs_fields = []
136- # Add a join condition for each pair of joining fields.
137- parent_template = "parent__field__"
138- for lhs , rhs in self .join_fields :
139- lhs , rhs = connection .ops .prepare_join_on_clause (
140- self .parent_alias , lhs , compiler .collection_name , rhs
141- )
142- lhs_fields .append (lhs .as_mql (compiler , connection ))
143- # In the lookup stage, the reference to this column doesn't include
144- # the collection name.
145- rhs_fields .append (rhs .as_mql (compiler , connection ))
146- # Handle any join conditions besides matching field pairs.
147- extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
148- if extra :
132+ def join (self , compiler , connection , pushed_expressions = None ):
133+ def _get_reroot_replacements (expressions ):
134+ if not expressions :
135+ return []
149136 columns = []
150- for expr in extra . leaves () :
137+ for expr in expressions :
151138 # Determine whether the column needs to be transformed or rerouted
152139 # as part of the subquery.
153140 for hand_side in ["lhs" , "rhs" ]:
@@ -165,18 +152,45 @@ def join(self, compiler, connection):
165152 # based on their rerouted positions in the join pipeline.
166153 replacements = {}
167154 for col , parent_pos in columns :
168- column_target = Col (compiler .collection_name , expr . output_field . __class__ () )
155+ column_target = Col (compiler .collection_name , col . target , col . output_field )
169156 if parent_pos is not None :
170157 target_col = f"${ parent_template } { parent_pos } "
171158 column_target .target .db_column = target_col
172159 column_target .target .set_attributes_from_name (target_col )
173160 else :
174161 column_target .target = col .target
175162 replacements [col ] = column_target
176- # Apply the transformed expressions in the extra condition.
163+ return replacements
164+
165+ lookup_pipeline = []
166+ lhs_fields = []
167+ rhs_fields = []
168+ # Add a join condition for each pair of joining fields.
169+ parent_template = "parent__field__"
170+ for lhs , rhs in self .join_fields :
171+ lhs , rhs = connection .ops .prepare_join_on_clause (
172+ self .parent_alias , lhs , compiler .collection_name , rhs
173+ )
174+ lhs_fields .append (lhs .as_mql (compiler , connection ))
175+ # In the lookup stage, the reference to this column doesn't include
176+ # the collection name.
177+ rhs_fields .append (rhs .as_mql (compiler , connection ))
178+ # Handle any join conditions besides matching field pairs.
179+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
180+
181+ if extra :
182+ replacements = _get_reroot_replacements (extra .leaves ())
177183 extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
178184 else :
179185 extra_condition = []
186+ if self .join_type == INNER :
187+ rerooted_replacement = _get_reroot_replacements (pushed_expressions )
188+ resolved_pushed_expressions = [
189+ expr .replace_expressions (rerooted_replacement ).as_mql (compiler , connection )
190+ for expr in pushed_expressions
191+ ]
192+ else :
193+ resolved_pushed_expressions = []
180194
181195 lookup_pipeline = [
182196 {
@@ -204,6 +218,7 @@ def join(self, compiler, connection):
204218 for i , field in enumerate (rhs_fields )
205219 ]
206220 + extra_condition
221+ + resolved_pushed_expressions
207222 }
208223 }
209224 }
0 commit comments