@@ -123,25 +123,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
123123 raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
124124
125125
126- def join (self , compiler , connection ):
127- lookup_pipeline = []
128- lhs_fields = []
129- rhs_fields = []
130- # Add a join condition for each pair of joining fields.
131- parent_template = "parent__field__"
132- for lhs , rhs in self .join_fields :
133- lhs , rhs = connection .ops .prepare_join_on_clause (
134- self .parent_alias , lhs , compiler .collection_name , rhs
135- )
136- lhs_fields .append (lhs .as_mql (compiler , connection ))
137- # In the lookup stage, the reference to this column doesn't include
138- # the collection name.
139- rhs_fields .append (rhs .as_mql (compiler , connection ))
140- # Handle any join conditions besides matching field pairs.
141- extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
142- if extra :
126+ def join (self , compiler , connection , pushed_expressions = None ):
127+ def _get_reroot_replacements (expressions ):
128+ if not expressions :
129+ return []
143130 columns = []
144- for expr in extra . leaves () :
131+ for expr in expressions :
145132 # Determine whether the column needs to be transformed or rerouted
146133 # as part of the subquery.
147134 for hand_side in ["lhs" , "rhs" ]:
@@ -159,18 +146,45 @@ def join(self, compiler, connection):
159146 # based on their rerouted positions in the join pipeline.
160147 replacements = {}
161148 for col , parent_pos in columns :
162- column_target = Col (compiler .collection_name , expr . output_field . __class__ () )
149+ column_target = Col (compiler .collection_name , col . target , col . output_field )
163150 if parent_pos is not None :
164151 target_col = f"${ parent_template } { parent_pos } "
165152 column_target .target .db_column = target_col
166153 column_target .target .set_attributes_from_name (target_col )
167154 else :
168155 column_target .target = col .target
169156 replacements [col ] = column_target
170- # Apply the transformed expressions in the extra condition.
157+ return replacements
158+
159+ lookup_pipeline = []
160+ lhs_fields = []
161+ rhs_fields = []
162+ # Add a join condition for each pair of joining fields.
163+ parent_template = "parent__field__"
164+ for lhs , rhs in self .join_fields :
165+ lhs , rhs = connection .ops .prepare_join_on_clause (
166+ self .parent_alias , lhs , compiler .collection_name , rhs
167+ )
168+ lhs_fields .append (lhs .as_mql (compiler , connection ))
169+ # In the lookup stage, the reference to this column doesn't include
170+ # the collection name.
171+ rhs_fields .append (rhs .as_mql (compiler , connection ))
172+ # Handle any join conditions besides matching field pairs.
173+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
174+
175+ if extra :
176+ replacements = _get_reroot_replacements (extra .leaves ())
171177 extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
172178 else :
173179 extra_condition = []
180+ if self .join_type == INNER :
181+ rerooted_replacement = _get_reroot_replacements (pushed_expressions )
182+ resolved_pushed_expressions = [
183+ expr .replace_expressions (rerooted_replacement ).as_mql (compiler , connection )
184+ for expr in pushed_expressions
185+ ]
186+ else :
187+ resolved_pushed_expressions = []
174188
175189 lookup_pipeline = [
176190 {
@@ -198,6 +212,7 @@ def join(self, compiler, connection):
198212 for i , field in enumerate (rhs_fields )
199213 ]
200214 + extra_condition
215+ + resolved_pushed_expressions
201216 }
202217 }
203218 }
0 commit comments