1616from sqlalchemy .sql import Select , column , distinct
1717
1818from fastapi_jsonapi import BadRequest
19- from fastapi_jsonapi .common import get_relationship_info_from_field_metadata
2019from fastapi_jsonapi .data_layers .base import BaseDataLayer
2120from fastapi_jsonapi .data_layers .sqla .query_building import (
2221 build_filter_expressions ,
2322 build_sort_expressions ,
2423 prepare_relationships_info ,
24+ relationships_info_storage ,
2525)
2626from fastapi_jsonapi .data_typing import TypeModel , TypeSchema
2727from fastapi_jsonapi .exceptions import (
3232 RelatedObjectNotFound ,
3333 RelationNotFound ,
3434)
35+ from fastapi_jsonapi .models_storage import models_storage
3536from fastapi_jsonapi .querystring import PaginationQueryStringManager , QueryStringManager
3637from fastapi_jsonapi .schema import (
3738 BaseJSONAPIItemInSchema ,
3839 BaseJSONAPIRelationshipDataToManySchema ,
3940 BaseJSONAPIRelationshipDataToOneSchema ,
40- get_model_field ,
41- get_related_schema ,
4241)
43- from fastapi_jsonapi .splitter import SPLIT_REL
42+ from fastapi_jsonapi .schemas_storage import schemas_storage
4443from fastapi_jsonapi .types_metadata import RelationshipInfo
4544
4645log = logging .getLogger (__name__ )
@@ -56,6 +55,7 @@ def __init__(
5655 schema : Type [TypeSchema ],
5756 model : Type [TypeModel ],
5857 session : AsyncSession ,
58+ resource_type : str ,
5959 disable_collection_count : bool = False ,
6060 default_collection_count : int = - 1 ,
6161 id_name_field : Optional [str ] = None ,
@@ -82,6 +82,7 @@ def __init__(
8282 super ().__init__ (
8383 schema = schema ,
8484 model = model ,
85+ resource_type = resource_type ,
8586 url_id_field = url_id_field ,
8687 id_name_field = id_name_field ,
8788 disable_collection_count = disable_collection_count ,
@@ -227,20 +228,17 @@ async def apply_relationships(
227228 if relationships is None :
228229 return
229230
230- schema_fields = self .schema .model_fields or {}
231231 for relation_name , relationship_in in relationships :
232232 if relationship_in is None :
233233 continue
234234
235- field = schema_fields .get (relation_name )
236- if field is None :
237- # should not happen if schema is built properly
238- # there may be an error if schema and schema_in are different
239- log .warning ("Field for %s in schema %s not found" , relation_name , self .schema .__name__ )
240- continue
241-
242- relationship_info : Optional [RelationshipInfo ] = get_relationship_info_from_field_metadata (field )
235+ relationship_info = schemas_storage .get_relationship (
236+ resource_type = self .resource_type ,
237+ operation_type = action_trigger ,
238+ field_name = relation_name ,
239+ )
243240 if relationship_info is None :
241+ log .warning ("Not found relationship %s for resource_type %s" , relation_name , self .resource_type )
244242 continue
245243
246244 related_model = getattr (type (obj ), relation_name ).property .mapper .class_
@@ -413,7 +411,7 @@ async def update_object(
413411 msg ,
414412 pointer = "/data" ,
415413 meta = {
416- "type" : self .type_ ,
414+ "type" : self .resource_type ,
417415 "id" : view_kwargs .get (self .url_id_field ),
418416 },
419417 )
@@ -427,7 +425,7 @@ async def update_object(
427425 detail = err_message ,
428426 pointer = "/data" ,
429427 meta = {
430- "type" : self .type_ ,
428+ "type" : self .resource_type ,
431429 "id" : view_kwargs .get (self .url_id_field ),
432430 },
433431 )
@@ -459,7 +457,7 @@ async def delete_object(self, obj: TypeModel, view_kwargs: dict):
459457 detail = err_message ,
460458 pointer = "/data" ,
461459 meta = {
462- "type" : self .type_ ,
460+ "type" : self .resource_type ,
463461 "id" : view_kwargs .get (self .url_id_field ),
464462 },
465463 )
@@ -476,7 +474,7 @@ async def delete_objects(self, objects: list[TypeModel], view_kwargs: dict):
476474 except DBAPIError as e :
477475 await self .session .rollback ()
478476 raise InternalServerError (
479- detail = f"Got an error { e .__class__ .__name__ } during delete data from DB: { e !s} " ,
477+ detail = f"Got an error { e .__class__ .__name__ } during delete data from DB: { e !s} . " ,
480478 )
481479
482480 await self .after_delete_objects (objects , view_kwargs )
@@ -583,7 +581,7 @@ async def delete_relationship(
583581
584582 def get_related_model_query_base (
585583 self ,
586- related_model : Type [ TypeModel ] ,
584+ related_model : TypeModel ,
587585 ) -> Select :
588586 """
589587 Prepare sql query (statement) to fetch related model
@@ -595,7 +593,7 @@ def get_related_model_query_base(
595593
596594 def get_related_object_query (
597595 self ,
598- related_model : Type [ TypeModel ] ,
596+ related_model : TypeModel ,
599597 related_id_field : str ,
600598 id_value : str ,
601599 ):
@@ -606,7 +604,7 @@ def get_related_object_query(
606604
607605 def get_related_objects_list_query (
608606 self ,
609- related_model : Type [ TypeModel ] ,
607+ related_model : TypeModel ,
610608 related_id_field : str ,
611609 ids : list [str ],
612610 ) -> tuple [Select , list [str ]]:
@@ -617,7 +615,7 @@ def get_related_objects_list_query(
617615
618616 async def get_related_object (
619617 self ,
620- related_model : Type [ TypeModel ] ,
618+ related_model : TypeModel ,
621619 related_id_field : str ,
622620 id_value : str ,
623621 ) -> TypeModel :
@@ -645,7 +643,7 @@ async def get_related_object(
645643
646644 async def get_related_objects_list (
647645 self ,
648- related_model : Type [ TypeModel ] ,
646+ related_model : TypeModel ,
649647 related_id_field : str ,
650648 ids : list [str ],
651649 ) -> list [TypeModel ]:
@@ -678,17 +676,25 @@ async def get_related_objects_list(
678676
679677 def apply_filters_and_sorts (self , query : Select , qs : QueryStringManager ):
680678 filters , sorts = qs .filters , qs .sorts
681- relationships_info = prepare_relationships_info (self .model , self .schema , filters , sorts )
682679
683- for info in relationships_info .values ():
680+ relationship_paths = prepare_relationships_info (
681+ model = self .model ,
682+ schema = self .schema ,
683+ resource_type = self .resource_type ,
684+ filter_info = filters ,
685+ sorting_info = sorts ,
686+ )
687+
688+ for relationship_path in relationship_paths :
689+ info = relationships_info_storage .get_info (self .resource_type , relationship_path )
684690 query = query .join (info .aliased_model , info .join_column )
685691
686692 if filters :
687693 filter_expressions = build_filter_expressions (
688694 filter_item = {"and" : filters },
689695 target_model = self .model ,
690696 target_schema = self .schema ,
691- relationships_info = relationships_info ,
697+ entrypoint_resource_type = self . resource_type ,
692698 )
693699 query = query .where (filter_expressions )
694700
@@ -697,7 +703,7 @@ def apply_filters_and_sorts(self, query: Select, qs: QueryStringManager):
697703 sort_items = sorts ,
698704 target_model = self .model ,
699705 target_schema = self .schema ,
700- relationships_info = relationships_info ,
706+ entrypoint_resource_type = self . resource_type ,
701707 )
702708 query = query .order_by (* sort_expressions )
703709
@@ -731,16 +737,25 @@ def eagerload_includes(self, query: Select, qs: QueryStringManager) -> Select:
731737 for include in qs .include :
732738 relation_join_object = None
733739
734- current_schema = self .schema
735740 current_model = self .model
736- for related_field_name in include .split (SPLIT_REL ):
737- try :
738- field_name_to_load = get_model_field (current_schema , related_field_name )
739- except Exception as e :
740- msg = f"{ e } "
741- raise InvalidInclude (msg )
742-
743- field_to_load : InstrumentedAttribute = getattr (current_model , field_name_to_load )
741+ current_resource_type = self .resource_type
742+
743+ for related_field_name in include .split ("." ):
744+ relationship_info = schemas_storage .get_relationship (
745+ resource_type = current_resource_type ,
746+ operation_type = "get" ,
747+ field_name = related_field_name ,
748+ )
749+ if relationship_info is None :
750+ msg = (
751+ f"Not found relationship { related_field_name !r} from include { include !r} "
752+ f"for resource_type { current_resource_type !r} ."
753+ )
754+ raise InvalidInclude (
755+ msg ,
756+ )
757+
758+ field_to_load : InstrumentedAttribute = getattr (current_model , related_field_name )
744759 is_many = field_to_load .property .uselist
745760 if relation_join_object is None :
746761 relation_join_object = selectinload (field_to_load ) if is_many else joinedload (field_to_load )
@@ -749,11 +764,8 @@ def eagerload_includes(self, query: Select, qs: QueryStringManager) -> Select:
749764 else :
750765 relation_join_object = relation_join_object .joinedload (field_to_load )
751766
752- current_schema = get_related_schema (current_schema , related_field_name )
753-
754- # the first entity is Mapper,
755- # the second entity is DeclarativeMeta
756- current_model = field_to_load .property .entity .entity
767+ current_resource_type = relationship_info .resource_type
768+ current_model = models_storage .get_model (current_resource_type )
757769
758770 query = query .options (relation_join_object )
759771
0 commit comments