88from pydantic import BaseModel
99from sqlalchemy .exc import MissingGreenlet
1010from sqlalchemy .ext .asyncio import AsyncSession , AsyncSessionTransaction
11- from sqlalchemy .inspection import inspect
12- from sqlalchemy .orm import joinedload , selectinload
11+ from sqlalchemy .orm import joinedload , load_only , selectinload
1312from sqlalchemy .orm .attributes import InstrumentedAttribute
1413from sqlalchemy .orm .collections import InstrumentedList
1514from sqlalchemy .sql import Select
@@ -246,13 +245,45 @@ async def create_object(
246245 await self .after_create_object (obj , model_kwargs , view_kwargs )
247246 return obj
248247
249- def get_object_id_field_name (self ):
250- """
251- compound key may cause errors
248+ def get_fields_options (
249+ self ,
250+ resource_type : str ,
251+ qs : QueryStringManager ,
252+ required_to_load : Optional [set ] = None ,
253+ ) -> set :
254+ required_to_load = required_to_load or set ()
252255
253- :return:
254- """
255- return self .id_name_field or inspect (self .model ).primary_key [0 ].key
256+ if resource_type not in qs .fields :
257+ return set ()
258+
259+ # empty str means skip all attributes
260+ if "" not in qs .fields [resource_type ]:
261+ required_to_load .update (field_name for field_name in qs .fields [resource_type ])
262+
263+ return self .get_load_only_options (
264+ resource_type = resource_type ,
265+ field_names = required_to_load ,
266+ )
267+
268+ @staticmethod
269+ def get_load_only_options (
270+ resource_type : str ,
271+ field_names : Iterable [str ],
272+ ) -> set :
273+ model = models_storage .get_model (resource_type )
274+ options = {
275+ load_only (
276+ getattr (
277+ model ,
278+ models_storage .get_model_id_field_name (resource_type ),
279+ ),
280+ ),
281+ }
282+
283+ for field_name in field_names :
284+ options .add (load_only (getattr (model , field_name )))
285+
286+ return options
256287
257288 async def get_object (
258289 self ,
@@ -268,17 +299,24 @@ async def get_object(
268299 """
269300 await self .before_get_object (view_kwargs )
270301
271- filter_field = self .get_object_id_field ()
302+ filter_field = models_storage .get_object_id_field (self . resource_type )
272303 filter_value = self .prepare_id_value (filter_field , view_kwargs [self .url_id_field ])
273304
274- relation_join_objects : list = []
305+ options = set ()
275306 if qs is not None :
276- relation_join_objects = self .eagerload_includes (qs )
307+ options .update (self .eagerload_includes (qs ))
308+ options .update (
309+ self .get_fields_options (
310+ resource_type = self .resource_type ,
311+ qs = qs ,
312+ required_to_load = set (view_kwargs .get ("required_to_load" , set ())),
313+ ),
314+ )
277315
278- query = await self ._base_sql .query (
316+ query = self ._base_sql .query (
279317 model = self .model ,
280318 filters = [filter_field == filter_value ],
281- options = set ( relation_join_objects ) ,
319+ options = options ,
282320 stmt = self ._query ,
283321 )
284322 obj = await self ._base_sql .one_or_raise (
@@ -317,16 +355,16 @@ async def get_collection(
317355 for relationship_path in relationship_paths
318356 ]
319357
320- relation_join_objects : list = []
358+ options = self . get_fields_options ( self . resource_type , qs )
321359 if self .eagerload_includes_ :
322- relation_join_objects = self .eagerload_includes (qs )
360+ options . update ( self .eagerload_includes (qs ) )
323361
324- query = await self ._base_sql .query (
362+ query = self ._base_sql .query (
325363 model = self .model ,
326364 filters = self .get_filter_expressions (qs ),
327365 join = relationships_info ,
328366 number = qs .pagination .number ,
329- options = set ( relation_join_objects ) ,
367+ options = options ,
330368 order = self .get_sort_expressions (qs ),
331369 size = qs .pagination .size ,
332370 stmt = self ._query ,
@@ -535,7 +573,7 @@ async def get_related_objects(
535573 id_field = getattr (related_model , related_id_field )
536574 id_values = [self .prepare_id_value (id_field , id_ ) for id_ in ids ]
537575
538- query = await self ._base_sql .query (
576+ query = self ._base_sql .query (
539577 model = related_model ,
540578 filters = [id_field .in_ (id_values )],
541579 )
@@ -622,6 +660,13 @@ def eagerload_includes(
622660 current_resource_type = relationship_info .resource_type
623661 current_model = models_storage .get_model (current_resource_type )
624662
663+ relation_join_object = relation_join_object .options (
664+ * self .get_fields_options (
665+ resource_type = current_resource_type ,
666+ qs = qs ,
667+ ),
668+ )
669+
625670 relation_join_objects .append (relation_join_object )
626671
627672 return relation_join_objects
@@ -638,7 +683,7 @@ async def before_create_object(
638683 :param view_kwargs: kwargs from the resource view.
639684 """
640685 if (id_value := model_kwargs .get ("id" )) and self .auto_convert_id_to_column_type :
641- model_field = self .get_object_id_field ()
686+ model_field = models_storage .get_object_id_field (resource_type = self . resource_type )
642687 model_kwargs .update (id = self .prepare_id_value (model_field , id_value ))
643688
644689 async def after_create_object (
0 commit comments