|
4 | 4 | from django.core.exceptions import EmptyResultSet, FullResultSet |
5 | 5 | from django.db import DatabaseError, IntegrityError, NotSupportedError |
6 | 6 | from django.db.models.expressions import Case, Col, When |
| 7 | +from django.db.models.fields.related import ForeignKey |
7 | 8 | from django.db.models.functions import Mod |
8 | 9 | from django.db.models.lookups import Exact |
9 | 10 | from django.db.models.sql.constants import INNER |
@@ -180,15 +181,26 @@ def _get_reroot_replacements(expression): |
180 | 181 | lookup_pipeline = [] |
181 | 182 | lhs_fields = [] |
182 | 183 | rhs_fields = [] |
| 184 | + local_field = None |
| 185 | + foreign_field = None |
183 | 186 | # Add a join condition for each pair of joining fields. |
184 | 187 | for lhs, rhs in self.join_fields: |
185 | | - lhs, rhs = connection.ops.prepare_join_on_clause( |
| 188 | + lhs_prepared, rhs_prepared = connection.ops.prepare_join_on_clause( |
186 | 189 | self.parent_alias, lhs, compiler.collection_name, rhs |
187 | 190 | ) |
188 | | - lhs_fields.append(lhs.as_mql(compiler, connection, as_expr=True)) |
189 | | - # In the lookup stage, the reference to this column doesn't include the |
190 | | - # collection name. |
191 | | - rhs_fields.append(rhs.as_mql(compiler, connection, as_expr=True)) |
| 191 | + if ( |
| 192 | + (isinstance(lhs, ForeignKey) or isinstance(rhs, ForeignKey)) |
| 193 | + and lhs_prepared.is_simple_column |
| 194 | + and rhs_prepared.is_simple_column |
| 195 | + ): |
| 196 | + # The join can be made using localField and foreignField. |
| 197 | + local_field = lhs_prepared.as_mql(compiler, connection) |
| 198 | + foreign_field = rhs_prepared.as_mql(compiler, connection) |
| 199 | + else: |
| 200 | + lhs_fields.append(lhs_prepared.as_mql(compiler, connection, as_expr=True)) |
| 201 | + # In the lookup stage, the reference to this column doesn't include |
| 202 | + # the collection name. |
| 203 | + rhs_fields.append(rhs_prepared.as_mql(compiler, connection, as_expr=True)) |
192 | 204 | # Handle any join conditions besides matching field pairs. |
193 | 205 | extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) |
194 | 206 | extra_conditions = [] |
@@ -218,32 +230,47 @@ def _get_reroot_replacements(expression): |
218 | 230 | # self.table_name.field2 = parent_table.field2 |
219 | 231 | # AND |
220 | 232 | # ... |
221 | | - condition = { |
222 | | - "$expr": { |
223 | | - "$and": [ |
224 | | - {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) |
225 | | - ] |
226 | | - } |
227 | | - } |
| 233 | + all_conditions = [] |
| 234 | + if rhs_fields: |
| 235 | + all_conditions.append( |
| 236 | + { |
| 237 | + "$expr": { |
| 238 | + "$and": [ |
| 239 | + {"$eq": [f"$${parent_template}{i}", field]} |
| 240 | + for i, field in enumerate(rhs_fields) |
| 241 | + ] |
| 242 | + } |
| 243 | + } |
| 244 | + ) |
228 | 245 | if extra_conditions: |
229 | | - condition = {"$and": [condition, *extra_conditions]} |
230 | | - lookup_pipeline = [ |
231 | | - { |
232 | | - "$lookup": { |
233 | | - # The right-hand table to join. |
234 | | - "from": self.table_name, |
235 | | - # The pipeline variables to be matched in the pipeline's |
236 | | - # expression. |
237 | | - "let": { |
238 | | - f"{parent_template}{i}": parent_field |
239 | | - for i, parent_field in enumerate(lhs_fields) |
240 | | - }, |
241 | | - "pipeline": [{"$match": condition}], |
242 | | - # Rename the output as table_alias. |
243 | | - "as": self.table_alias, |
| 246 | + all_conditions.extend(extra_conditions) |
| 247 | + # Build matching pipeline |
| 248 | + num_conditions = len(all_conditions) |
| 249 | + if num_conditions == 0: |
| 250 | + pipeline = [] |
| 251 | + elif num_conditions == 1: |
| 252 | + pipeline = [{"$match": all_conditions[0]}] |
| 253 | + else: |
| 254 | + pipeline = [{"$match": {"$and": all_conditions}}] |
| 255 | + lookup = { |
| 256 | + # The right-hand table to join. |
| 257 | + "from": self.table_name, |
| 258 | + "pipeline": pipeline, |
| 259 | + # Rename the output as table_alias. |
| 260 | + "as": self.table_alias, |
| 261 | + } |
| 262 | + if local_field and foreign_field: |
| 263 | + lookup.update( |
| 264 | + { |
| 265 | + "localField": local_field, |
| 266 | + "foreignField": foreign_field, |
244 | 267 | } |
245 | | - }, |
246 | | - ] |
| 268 | + ) |
| 269 | + if lhs_fields: |
| 270 | + lookup["let"] = { |
| 271 | + f"{parent_template}{i}": parent_field for i, parent_field in enumerate(lhs_fields) |
| 272 | + } |
| 273 | + lookup_pipeline = [{"$lookup": lookup}] |
247 | 274 | # To avoid missing data when using $unwind, an empty collection is added if |
248 | 275 | # the join isn't an inner join. For inner joins, rows with empty arrays are |
249 | 276 | # removed, as $unwind unrolls or unnests the array and removes the row if |
|
0 commit comments