44from django .db import models
55from django .db .models import Field
66from django .db .models .expressions import Col
7- from django .db .models .lookups import Transform
7+ from django .db .models .lookups import Lookup , Transform
88
99from .. import forms
1010from ..query_utils import process_lhs , process_rhs
1111from . import EmbeddedModelField
1212from .array import ArrayField
13- from .embedded_model import EMFExact
13+ from .embedded_model import EMFExact , EMFMixin
1414
1515
1616class EmbeddedModelArrayField (ArrayField ):
@@ -63,17 +63,8 @@ def get_transform(self, name):
6363 return KeyTransformFactory (name , self )
6464
6565
66- class ProcessRHSMixin :
67- def process_rhs (self , compiler , connection ):
68- if isinstance (self .lhs , KeyTransform ):
69- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
70- else :
71- get_db_prep_value = self .lhs .output_field .get_db_prep_value
72- return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
73-
74-
7566@EmbeddedModelArrayField .register_lookup
76- class EMFArrayExact (EMFExact , ProcessRHSMixin ):
67+ class EMFArrayExact (EMFExact ):
7768 def as_mql (self , compiler , connection ):
7869 lhs_mql = process_lhs (self , compiler , connection )
7970 value = process_rhs (self , compiler , connection )
@@ -116,12 +107,29 @@ def as_mql(self, compiler, connection):
116107
117108
118109@EmbeddedModelArrayField .register_lookup
119- class ArrayOverlap (EMFExact , ProcessRHSMixin ):
110+ class ArrayOverlap (EMFMixin , Lookup ):
120111 lookup_name = "overlap"
112+ get_db_prep_lookup_value_is_iterable = True
113+
114+ def process_rhs (self , compiler , connection ):
115+ values = self .rhs
116+ if self .get_db_prep_lookup_value_is_iterable :
117+ values = [values ]
118+ # Compute how to serialize each value based on the query target.
119+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
120+ # field of the subfield. Otherwise, use the base field of the array itself.
121+ if isinstance (self .lhs , KeyTransform ):
122+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
123+ else :
124+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
125+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
121126
122127 def as_mql (self , compiler , connection ):
123128 lhs_mql = process_lhs (self , compiler , connection )
124129 values = process_rhs (self , compiler , connection )
130+ # Querying a subfield within the array elements (via nested KeyTransform).
131+ # Replicates MongoDB's implicit ANY-match by mapping over the array and applying
132+ # `$in` on the subfield.
125133 if isinstance (self .lhs , KeyTransform ):
126134 lhs_mql , inner_lhs_mql = lhs_mql
127135 return {
@@ -140,11 +148,12 @@ def as_mql(self, compiler, connection):
140148 }
141149 conditions = []
142150 inner_lhs_mql = "$$item"
151+ # Querying full embedded documents in the array.
152+ # Builds `$or` conditions and maps them over the array to match any full document.
143153 for value in values :
144- if isinstance (value , models .Model ):
145- value , emf_data = self .model_to_dict (value )
146- # Get conditions for any nested EmbeddedModelFields.
147- conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
154+ value , emf_data = self .model_to_dict (value )
155+ # Get conditions for any nested EmbeddedModelFields.
156+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
148157 return {
149158 "$anyElementTrue" : {
150159 "$ifNull" : [
0 commit comments