@@ -518,6 +518,38 @@ async def delete_relationship(
518518 :param view_kwargs: kwargs from the resource view.
519519 """
520520
521+ def get_related_model_query_base (
522+ self ,
523+ related_model : Type [TypeModel ],
524+ ) -> "Select" :
525+ """
526+ :param related_model:
527+ :return:
528+ """
529+ return select (related_model )
530+
531+ def get_related_object_query (
532+ self ,
533+ related_model : Type [TypeModel ],
534+ related_id_field : str ,
535+ id_value : str ,
536+ ):
537+ id_field = getattr (related_model , related_id_field )
538+ id_value = self .prepare_id_value (id_field , id_value )
539+ stmt : "Select" = self .get_related_model_query_base (related_model )
540+ return stmt .where (id_field == id_value )
541+
542+ def get_related_objects_list_query (
543+ self ,
544+ related_model : Type [TypeModel ],
545+ related_id_field : str ,
546+ ids : list [str ],
547+ ) -> Tuple ["Select" , list [str ]]:
548+ id_field = getattr (related_model , related_id_field )
549+ prepared_ids = [self .prepare_id_value (id_field , _id ) for _id in ids ]
550+ stmt : "Select" = self .get_related_model_query_base (related_model )
551+ return stmt .where (id_field .in_ (prepared_ids )), prepared_ids
552+
521553 async def get_related_object (
522554 self ,
523555 related_model : Type [TypeModel ],
@@ -532,9 +564,12 @@ async def get_related_object(
532564 :param id_value: related object id value
533565 :return: a related SQLA ORM object
534566 """
535- id_field = getattr (related_model , related_id_field )
536- id_value = self .prepare_id_value (id_field , id_value )
537- stmt = select (related_model ).where (id_field == id_value )
567+ stmt = self .get_related_object_query (
568+ related_model = related_model ,
569+ related_id_field = related_id_field ,
570+ id_value = id_value ,
571+ )
572+
538573 try :
539574 related_object = (await self .session .execute (stmt )).scalar_one ()
540575 except NoResultFound :
@@ -556,9 +591,11 @@ async def get_related_objects_list(
556591 :param ids:
557592 :return:
558593 """
559- id_field = getattr (related_model , related_id_field )
560- ids = [self .prepare_id_value (id_field , _id ) for _id in ids ]
561- stmt = select (related_model ).where (id_field .in_ (ids ))
594+ stmt , ids = self .get_related_objects_list_query (
595+ related_model = related_model ,
596+ related_id_field = related_id_field ,
597+ ids = ids ,
598+ )
562599
563600 related_objects = (await self .session .execute (stmt )).scalars ().all ()
564601 object_ids = [getattr (obj , related_id_field ) for obj in related_objects ]
0 commit comments