@@ -132,7 +132,7 @@ def get_transform(self, name):
132132 if transform :
133133 return transform
134134 field = self .embedded_model ._meta .get_field (name )
135- return KeyTransformFactory ( name , field )
135+ return EmbeddedModelTransformFactory ( field )
136136
137137 def validate (self , value , model_instance ):
138138 super ().validate (value , model_instance )
@@ -156,39 +156,41 @@ def formfield(self, **kwargs):
156156 )
157157
158158
159- class KeyTransform (Transform ):
160- def __init__ (self , key_name , ref_field , * args , ** kwargs ):
159+ class EmbeddedModelTransform (Transform ):
160+ def __init__ (self , field , * args , ** kwargs ):
161161 super ().__init__ (* args , ** kwargs )
162- self .key_name = str (key_name )
163- self .ref_field = ref_field
162+ # The field is referenced in this class as self.field since
163+ # BaseExpression.field returns self.output_field, which returns
164+ # self._field).
165+ self ._field = field
164166
165167 def get_lookup (self , name ):
166- return self .ref_field .get_lookup (name )
168+ return self .field .get_lookup (name )
167169
168170 def get_transform (self , name ):
169171 """
170172 Validate that `name` is either a field of an embedded model or a
171173 lookup on an embedded model's field.
172174 """
173- if transform := self .ref_field .get_transform (name ):
175+ if transform := self .field .get_transform (name ):
174176 return transform
175- suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
177+ suggested_lookups = difflib .get_close_matches (name , self .field .get_lookups ())
176178 if suggested_lookups :
177179 suggested_lookups = " or " .join (suggested_lookups )
178180 suggestion = f", perhaps you meant { suggested_lookups } ?"
179181 else :
180182 suggestion = "."
181183 raise FieldDoesNotExist (
182184 f"Unsupported lookup '{ name } ' for "
183- f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
185+ f"{ self .field .__class__ .__name__ } '{ self .field .name } '"
184186 f"{ suggestion } "
185187 )
186188
187189 def as_mql (self , compiler , connection , as_path = False ):
188190 previous = self
189191 columns = []
190- while isinstance (previous , KeyTransform ):
191- columns .insert (0 , previous .ref_field .column )
192+ while isinstance (previous , EmbeddedModelTransform ):
193+ columns .insert (0 , previous .field .column )
192194 previous = previous .lhs
193195 if as_path :
194196 mql = previous .as_mql (compiler , connection , as_path = True )
@@ -201,13 +203,12 @@ def as_mql(self, compiler, connection, as_path=False):
201203
202204 @property
203205 def output_field (self ):
204- return self .ref_field
206+ return self ._field
205207
206208
207- class KeyTransformFactory :
208- def __init__ (self , key_name , ref_field ):
209- self .key_name = key_name
210- self .ref_field = ref_field
209+ class EmbeddedModelTransformFactory :
210+ def __init__ (self , field ):
211+ self .field = field
211212
212213 def __call__ (self , * args , ** kwargs ):
213- return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
214+ return EmbeddedModelTransform (self .field , * args , ** kwargs )
0 commit comments