@@ -315,92 +315,62 @@ def raw_mql(self, pipeline, using=None):
315315
316316class MongoRawQuery (RawQuery ):
317317 def __init__ (self , pipeline , using , model ):
318+ self .pipeline = pipeline
318319 super ().__init__ (sql = None , using = using )
319320 self .model = model
320- self .pipeline = pipeline
321-
322- def __iter__ (self ):
323- self ._execute_query ()
324- return self .cursor
325321
326322 def _execute_query (self ):
327323 connection = connections [self .using ]
328324 collection = connection .get_collection (self .model ._meta .db_table )
329325 self .cursor = collection .aggregate (self .pipeline )
330326
331- def get_columns (self ):
332- return [f .column for f in self .model ._meta .fields ]
333-
334327 def __str__ (self ):
335- return "%s" % self .pipeline
328+ return str ( self .pipeline )
336329
337330
338331class MongoRawQuerySet (RawQuerySet ):
339- def __init__ (
340- self ,
341- pipeline ,
342- model = None ,
343- query = None ,
344- translations = None ,
345- using = None ,
346- hints = None ,
347- ):
348- super ().__init__ (
349- pipeline ,
350- model = model ,
351- query = query ,
352- using = using ,
353- hints = hints ,
354- translations = translations ,
355- )
356- self .query = query or MongoRawQuery (pipeline , using = self .db , model = self .model )
332+ def __init__ (self , pipeline , model = None , using = None ):
333+ super ().__init__ (pipeline , model = model , using = using )
334+ self .query = MongoRawQuery (pipeline , using = self .db , model = self .model )
335+ # Override the superclass's columns property which relies on PEP 249's
336+ # cursor.description. Instead, RawModelIterable will set the columns
337+ # based on the keys in the first result.
338+ self .columns = None
357339
358340 def iterator (self ):
359341 yield from MongoRawModelIterable (self )
360342
361- def resolve_model_init_order (self , columns ):
362- """Resolve the init field names and value positions."""
363- model_init_fields = [f for f in self .model ._meta .fields if f .column in columns ]
364- annotation_fields = [
365- (column , pos ) for pos , column in enumerate (columns ) if column not in self .model_fields
366- ]
367- model_init_order = [columns .index (f .column ) for f in model_init_fields ]
368- model_init_names = [f .attname for f in model_init_fields ]
369- return model_init_names , model_init_order , annotation_fields
370-
371343
372344class MongoRawModelIterable (RawModelIterable ):
373- """
374- Iterable that yields a model instance for each row from a raw queryset.
375- """
376-
377345 def __iter__ (self ):
378- # Cache some things for performance reasons outside the loop.
346+ """
347+ This is mostly copied from the superclass except for the part that
348+ sets self.queryset.columns from the first document.
349+ """
379350 db = self .queryset .db
380351 query = self .queryset .query
381352 connection = connections [db ]
382353 compiler = connection .ops .compiler ("SQLCompiler" )(query , connection , db )
383354 query_iterator = iter (query )
384- # Get the columns from the first result.
385- try :
386- first_item = next (query_iterator )
387- except StopIteration :
388- # No results.
389- query .cursor .close ()
390- return
391- columns = list (first_item .keys ())
392- # Reset the iterator to include the first item.
393- query_iterator = self ._make_result (chain ([first_item ], query_iterator ))
394355 try :
356+ # Get the columns from the first result.
357+ try :
358+ first_item = next (query_iterator )
359+ except StopIteration :
360+ # No results.
361+ return
362+ self .queryset .columns = list (first_item .keys ())
363+ # Reset the iterator to include the first item.
364+ query_iterator = self ._make_result (chain ([first_item ], query_iterator ))
395365 (
396366 model_init_names ,
397367 model_init_pos ,
398368 annotation_fields ,
399- ) = self .queryset .resolve_model_init_order (columns )
369+ ) = self .queryset .resolve_model_init_order ()
400370 model_cls = self .queryset .model
401371 if model_cls ._meta .pk .attname not in model_init_names :
402372 raise FieldDoesNotExist ("Raw query must include the primary key" )
403- fields = [self .queryset .model_fields .get (c ) for c in columns ]
373+ fields = [self .queryset .model_fields .get (c ) for c in self . queryset . columns ]
404374 converters = compiler .get_converters (
405375 [f .get_col (f .model ._meta .db_table ) if f else None for f in fields ]
406376 )
@@ -415,9 +385,7 @@ def __iter__(self):
415385 setattr (instance , column , values [pos ])
416386 yield instance
417387 finally :
418- # Done iterating the Query. If it has its own cursor, close it.
419- if hasattr (query , "cursor" ) and query .cursor :
420- query .cursor .close ()
388+ query .cursor .close ()
421389
422390 def _make_result (self , query ):
423391 for result in query :
0 commit comments