Skip to content

Commit 653c827

Browse files
CosmoVNatalia Grigoreva
authored andcommitted
updated fields requests logic
1 parent a511ab7 commit 653c827

File tree

13 files changed

+304
-105
lines changed

13 files changed

+304
-105
lines changed

fastapi_jsonapi/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
pagination_default_limit: Optional[int] = None,
7878
methods: Iterable[str] = (),
7979
ending_slash: bool = True,
80+
model_id_field_name: str = "id",
8081
) -> None:
8182
"""
8283
Initialize router items.
@@ -122,7 +123,7 @@ def __init__(
122123
msg = f"Resource type {self.type_!r} already registered"
123124
raise ValueError(msg)
124125
self.all_jsonapi_routers[self.type_] = self
125-
models_storage.add_model(resource_type, model)
126+
models_storage.add_model(resource_type, model, model_id_field_name)
126127

127128
self.pagination_default_size: Optional[int] = pagination_default_size
128129
self.pagination_default_number: Optional[int] = pagination_default_number

fastapi_jsonapi/data_layers/base.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __init__(
2626
model: Type[TypeModel],
2727
resource_type: str,
2828
url_id_field: str,
29-
id_name_field: Optional[str] = None,
3029
disable_collection_count: bool = False,
3130
default_collection_count: int = -1,
3231
**kwargs,
@@ -38,7 +37,6 @@ def __init__(
3837
:param schema:
3938
:param model:
4039
:param url_id_field:
41-
:param id_name_field:
4240
:param disable_collection_count:
4341
:param default_collection_count:
4442
:param resource_type: resource type
@@ -49,7 +47,6 @@ def __init__(
4947
self.model = model
5048
self.resource_type = resource_type
5149
self.url_id_field = url_id_field
52-
self.id_name_field = id_name_field
5350
self.disable_collection_count: bool = disable_collection_count
5451
self.default_collection_count: int = default_collection_count
5552
self.is_atomic = False
@@ -94,26 +91,6 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
9491
"""
9592
raise NotImplementedError
9693

97-
def get_object_id_field_name(self):
98-
"""
99-
compound key may cause errors
100-
101-
:return:
102-
"""
103-
return self.id_name_field
104-
105-
def get_object_id_field(self):
106-
id_name_field = self.get_object_id_field_name()
107-
try:
108-
return getattr(self.model, id_name_field)
109-
except AttributeError:
110-
msg = f"{self.model.__name__} has no attribute {id_name_field}"
111-
# TODO: any custom exception type?
112-
raise Exception(msg)
113-
114-
def get_object_id(self, obj: TypeModel):
115-
return getattr(obj, self.get_object_id_field_name())
116-
11794
async def get_object(self, view_kwargs: dict, qs: Optional[QueryStringManager] = None) -> TypeModel:
11895
"""
11996
Retrieve an object

fastapi_jsonapi/data_layers/sqla/base_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Literal, Optional, Type, Union
2+
from typing import Any, Iterable, Literal, Optional, Type, Union
33

44
from sqlalchemy import and_, delete, func, select
55
from sqlalchemy.exc import IntegrityError
@@ -165,22 +165,21 @@ async def one_or_raise(
165165
return result
166166

167167
@classmethod
168-
async def query(
168+
def query(
169169
cls,
170170
model: TypeModel,
171171
distinct_: bool = False,
172-
fields: Optional[list] = None,
173172
filters: Optional[list[Union[BinaryExpression, bool]]] = None,
174173
for_update: Optional[dict] = None,
175174
join: Optional[list[RelationshipInfo]] = None,
176175
number: Optional[int] = None,
177-
options: set = (),
176+
options: Iterable = (),
178177
order: Optional[Union[str, UnaryExpression]] = None,
179178
size: Optional[int] = None,
180179
stmt: Optional[Select] = None,
181180
) -> Select:
182181
if stmt is None:
183-
stmt = select(model) if fields is None else select(*fields)
182+
stmt = select(model)
184183

185184
if filters is not None:
186185
stmt = stmt.where(*filters)

fastapi_jsonapi/data_layers/sqla/orm.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from pydantic import BaseModel
99
from sqlalchemy.exc import MissingGreenlet
1010
from sqlalchemy.ext.asyncio import AsyncSession, AsyncSessionTransaction
11-
from sqlalchemy.inspection import inspect
12-
from sqlalchemy.orm import joinedload, selectinload
11+
from sqlalchemy.orm import joinedload, load_only, selectinload
1312
from sqlalchemy.orm.attributes import InstrumentedAttribute
1413
from sqlalchemy.orm.collections import InstrumentedList
1514
from sqlalchemy.sql import Select
@@ -246,13 +245,45 @@ async def create_object(
246245
await self.after_create_object(obj, model_kwargs, view_kwargs)
247246
return obj
248247

249-
def get_object_id_field_name(self):
250-
"""
251-
compound key may cause errors
248+
def get_fields_options(
249+
self,
250+
resource_type: str,
251+
qs: QueryStringManager,
252+
required_to_load: Optional[set] = None,
253+
) -> set:
254+
required_to_load = required_to_load or set()
252255

253-
:return:
254-
"""
255-
return self.id_name_field or inspect(self.model).primary_key[0].key
256+
if resource_type not in qs.fields:
257+
return set()
258+
259+
# empty str means skip all attributes
260+
if "" not in qs.fields[resource_type]:
261+
required_to_load.update(field_name for field_name in qs.fields[resource_type])
262+
263+
return self.get_load_only_options(
264+
resource_type=resource_type,
265+
field_names=required_to_load,
266+
)
267+
268+
@staticmethod
269+
def get_load_only_options(
270+
resource_type: str,
271+
field_names: Iterable[str],
272+
) -> set:
273+
model = models_storage.get_model(resource_type)
274+
options = {
275+
load_only(
276+
getattr(
277+
model,
278+
models_storage.get_model_id_field_name(resource_type),
279+
),
280+
),
281+
}
282+
283+
for field_name in field_names:
284+
options.add(load_only(getattr(model, field_name)))
285+
286+
return options
256287

257288
async def get_object(
258289
self,
@@ -268,17 +299,24 @@ async def get_object(
268299
"""
269300
await self.before_get_object(view_kwargs)
270301

271-
filter_field = self.get_object_id_field()
302+
filter_field = models_storage.get_object_id_field(self.resource_type)
272303
filter_value = self.prepare_id_value(filter_field, view_kwargs[self.url_id_field])
273304

274-
relation_join_objects: list = []
305+
options = set()
275306
if qs is not None:
276-
relation_join_objects = self.eagerload_includes(qs)
307+
options.update(self.eagerload_includes(qs))
308+
options.update(
309+
self.get_fields_options(
310+
resource_type=self.resource_type,
311+
qs=qs,
312+
required_to_load=set(view_kwargs.get("required_to_load", set())),
313+
),
314+
)
277315

278-
query = await self._base_sql.query(
316+
query = self._base_sql.query(
279317
model=self.model,
280318
filters=[filter_field == filter_value],
281-
options=set(relation_join_objects),
319+
options=options,
282320
stmt=self._query,
283321
)
284322
obj = await self._base_sql.one_or_raise(
@@ -317,16 +355,16 @@ async def get_collection(
317355
for relationship_path in relationship_paths
318356
]
319357

320-
relation_join_objects: list = []
358+
options = self.get_fields_options(self.resource_type, qs)
321359
if self.eagerload_includes_:
322-
relation_join_objects = self.eagerload_includes(qs)
360+
options.update(self.eagerload_includes(qs))
323361

324-
query = await self._base_sql.query(
362+
query = self._base_sql.query(
325363
model=self.model,
326364
filters=self.get_filter_expressions(qs),
327365
join=relationships_info,
328366
number=qs.pagination.number,
329-
options=set(relation_join_objects),
367+
options=options,
330368
order=self.get_sort_expressions(qs),
331369
size=qs.pagination.size,
332370
stmt=self._query,
@@ -535,7 +573,7 @@ async def get_related_objects(
535573
id_field = getattr(related_model, related_id_field)
536574
id_values = [self.prepare_id_value(id_field, id_) for id_ in ids]
537575

538-
query = await self._base_sql.query(
576+
query = self._base_sql.query(
539577
model=related_model,
540578
filters=[id_field.in_(id_values)],
541579
)
@@ -622,6 +660,13 @@ def eagerload_includes(
622660
current_resource_type = relationship_info.resource_type
623661
current_model = models_storage.get_model(current_resource_type)
624662

663+
relation_join_object = relation_join_object.options(
664+
*self.get_fields_options(
665+
resource_type=current_resource_type,
666+
qs=qs,
667+
),
668+
)
669+
625670
relation_join_objects.append(relation_join_object)
626671

627672
return relation_join_objects
@@ -638,7 +683,7 @@ async def before_create_object(
638683
:param view_kwargs: kwargs from the resource view.
639684
"""
640685
if (id_value := model_kwargs.get("id")) and self.auto_convert_id_to_column_type:
641-
model_field = self.get_object_id_field()
686+
model_field = models_storage.get_object_id_field(resource_type=self.resource_type)
642687
model_kwargs.update(id=self.prepare_id_value(model_field, id_value))
643688

644689
async def after_create_object(

fastapi_jsonapi/models_storage.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Callable, Type
2+
from typing import Any, Callable, Type
33

44
from fastapi_jsonapi.data_typing import TypeModel
55
from fastapi_jsonapi.exceptions import BadRequest, InternalServerError
@@ -11,17 +11,51 @@ class ModelsStorage:
1111
relationship_search_handlers: dict[str, Callable[[str, Type[TypeModel], str], Type[TypeModel]]]
1212

1313
def __init__(self):
14-
self._data: dict[str, TypeModel] = {}
14+
self._models: dict[str, Type[TypeModel]] = {}
15+
self._id_field_names: dict[str, str] = {}
1516
self.relationship_search_handlers = {}
1617

17-
def add_model(self, resource_type: str, model: Type[TypeModel]):
18-
self._data[resource_type] = model
18+
def add_model(self, resource_type: str, model: Type[TypeModel], id_field_name: str):
19+
self._models[resource_type] = model
20+
self._id_field_names[resource_type] = id_field_name
1921

2022
def get_model(self, resource_type: str) -> Type[TypeModel]:
2123
try:
22-
return self._data[resource_type]
24+
return self._models[resource_type]
2325
except KeyError:
24-
raise InternalServerError(detail=f"Not found model for resource_type {resource_type!r}.")
26+
raise InternalServerError(
27+
detail=f"Not found model for resource_type {resource_type!r}.",
28+
)
29+
30+
def get_model_id_field_name(self, resource_type: str) -> str:
31+
try:
32+
return self._id_field_names[resource_type]
33+
except KeyError:
34+
raise InternalServerError(
35+
detail=f"Not found model id field name for resource_type {resource_type!r}.",
36+
)
37+
38+
def get_object_id_field(self, resource_type: str) -> Any:
39+
model = self.get_model(resource_type)
40+
id_field_name = self.get_model_id_field_name(resource_type)
41+
42+
try:
43+
return getattr(model, id_field_name)
44+
except AttributeError:
45+
raise InternalServerError(
46+
detail=f"Can't get object id field. The model {model.__name__!r} has no attribute {id_field_name!r}",
47+
)
48+
49+
def get_object_id(self, db_object: TypeModel, resource_type: str) -> Any:
50+
id_field_name = self.get_model_id_field_name(resource_type)
51+
52+
try:
53+
return getattr(db_object, id_field_name)
54+
except AttributeError:
55+
model = self.get_model(resource_type)
56+
raise InternalServerError(
57+
detail=f"Can't get object id. The model {model.__name__!r} has no attribute {id_field_name!r}.",
58+
)
2559

2660
def register_search_handler(self, orm_mode: str, handler: Callable[[str, Type[TypeModel], str], Type[TypeModel]]):
2761
self.relationship_search_handlers[orm_mode] = handler
@@ -40,7 +74,7 @@ def search_relationship_model(
4074
except KeyError:
4175
raise InternalServerError(
4276
detail=f"Not found orm handler for {self._orm_mode!r}. "
43-
f"Please register this with SchemasStorage.register_search_handler",
77+
f"Please register this with SchemasStorage.register_search_handler.",
4478
)
4579

4680
return orm_handler(resource_type, model, field_name)

fastapi_jsonapi/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ class SchemasInfoDTO:
182182

183183
relationships_info: dict[str, tuple[RelationshipInfo, Any]]
184184

185+
field_schemas: dict[str, Type[BaseModel]]
186+
187+
model_validators: dict
188+
185189

186190
def get_model_field(schema: Type["TypeSchema"], field: str) -> str:
187191
"""

fastapi_jsonapi/schema_builder.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ def build_schema_in(
132132
source_schema=schema,
133133
data_schema=object_jsonapi_schema,
134134
attributes_schema=dto.attributes_schema,
135+
field_schemas=dto.field_schemas,
135136
relationships_info=dto.relationships_info,
137+
model_validators=dto.model_validators,
136138
)
137139

138140
wrapped_object_jsonapi_schema = create_model(
@@ -230,7 +232,7 @@ def _get_info_from_schema_for_building(
230232
if related_schema := get_schema_from_field_annotation(field):
231233
included_schemas.append((name, related_schema, relationship_info.resource_type))
232234
elif name == "id":
233-
id_validators = extract_validators(
235+
id_validators, _ = extract_validators(
234236
model=schema,
235237
include_for_field_names={"id"},
236238
)
@@ -245,13 +247,27 @@ def _get_info_from_schema_for_building(
245247
from_attributes=True,
246248
)
247249

250+
field_validators, model_validators = extract_validators(schema, exclude_for_field_names={"id"})
248251
attributes_schema = create_model(
249252
f"{base_name}AttributesJSONAPI",
250253
**attributes_schema_fields,
251254
__config__=model_config,
252-
__validators__=extract_validators(schema, exclude_for_field_names={"id"}),
255+
__validators__={**field_validators, **model_validators},
253256
)
254257

258+
field_schemas = {}
259+
for field_name, field in attributes_schema_fields.items():
260+
field_validators, _ = extract_validators(
261+
schema,
262+
include_for_field_names={field_name},
263+
)
264+
field_schemas[field_name] = create_model(
265+
f"{base_name}{field_name.title()}AttributeJSONAPI",
266+
**{field_name: field},
267+
__config__=model_config,
268+
__validators__=field_validators,
269+
)
270+
255271
relationships_schema = create_model(
256272
f"{base_name}RelationshipsJSONAPI",
257273
**relationships_schema_fields,
@@ -265,6 +281,8 @@ def _get_info_from_schema_for_building(
265281
relationships_info=relationships_info,
266282
has_required_relationship=has_required_relationship,
267283
included_schemas=included_schemas,
284+
field_schemas=field_schemas,
285+
model_validators=model_validators,
268286
)
269287

270288
@classmethod
@@ -440,7 +458,9 @@ def create_jsonapi_object_schemas(
440458
source_schema=schema,
441459
data_schema=relationship_less_object_jsonapi_schema,
442460
attributes_schema=dto.attributes_schema,
461+
field_schemas=dto.field_schemas,
443462
relationships_info=dto.relationships_info,
463+
model_validators=dto.model_validators,
444464
)
445465

446466
can_be_included_schemas = {}

0 commit comments

Comments
 (0)