diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 31f4b29ba..798c75d6d 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -8,14 +8,7 @@ MONGO_AGGREGATIONS = {Count: "sum"} -def aggregate( - self, - compiler, - connection, - operator=None, - resolve_inner_expression=False, - **extra_context, # noqa: ARG001 -): +def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False): if self.filter: node = self.copy() node.filter = None @@ -31,7 +24,7 @@ def aggregate( return {f"${operator}": lhs_mql} -def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 +def count(self, compiler, connection, resolve_inner_expression=False): """ When resolve_inner_expression=True, return the MQL that resolves as a value. This is used to count different elements, so the inner values are @@ -64,12 +57,12 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): +def stddev_variance(self, compiler, connection): if self.function.endswith("_SAMP"): operator = "stdDevSamp" elif self.function.endswith("_POP"): operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator, **extra_context) + return aggregate(self, compiler, connection, operator=operator) def register_aggregates(): diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index a220969cd..38f119524 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -227,13 +227,14 @@ class EmbeddedModelArrayFieldLessThanOrEqual( class EmbeddedModelArrayFieldTransform(Transform): field_class_name = "EmbeddedModelArrayField" + VIRTUAL_COLUMN_ITERABLE = "item" def __init__(self, field, *args, **kwargs): super().__init__(*args, **kwargs) # Lookups iterate over the array of embedded models. A virtual column # of the queried field's type represents each element. column_target = field.clone() - column_name = f"$item.{field.column}" + column_name = f"${self.VIRTUAL_COLUMN_ITERABLE}.{field.column}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name) self._lhs = Col(None, column_target) @@ -283,7 +284,7 @@ def as_mql(self, compiler, connection): { "$map": { "input": lhs_mql, - "as": "item", + "as": self.VIRTUAL_COLUMN_ITERABLE, "in": inner_lhs_mql, } }, diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py index 6325ca4fc..c3832675a 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py @@ -98,7 +98,7 @@ def __init__(self, field, *args, **kwargs): # Lookups iterate over the array of embedded models. A virtual column # of the queried field's type represents each element. column_target = field.clone() - column_name = f"$item.{field.column}" + column_name = f"${self.VIRTUAL_COLUMN_ITERABLE}.{field.column}" column_target.name = f"{field.name}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name)