Skip to content

Commit 65d9520

Browse files
authored
Merge pull request #59 from mts-ai/feature/fix-loading-multi-relationships
fix loading multi relationships
2 parents 2814145 + 56a4a58 commit 65d9520

File tree

3 files changed

+218
-21
lines changed

3 files changed

+218
-21
lines changed

fastapi_jsonapi/views/view_base.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import inspect
22
import logging
3+
from collections import defaultdict
34
from contextvars import ContextVar
45
from functools import partial
56
from typing import (
67
Any,
78
Callable,
9+
ClassVar,
810
Dict,
911
Iterable,
1012
List,
@@ -47,14 +49,17 @@
4749
included_object_schema_ctx_var: ContextVar[Type[TypeSchema]] = ContextVar("included_object_schema_ctx_var")
4850
relationship_info_ctx_var: ContextVar[RelationshipInfo] = ContextVar("relationship_info_ctx_var")
4951

52+
# TODO: just change state on `self`!! (refactor)
53+
included_objects_ctx_var: ContextVar[Dict[Tuple[str, str], TypeSchema]] = ContextVar("included_objects_ctx_var")
54+
5055

5156
class ViewBase:
5257
"""
5358
Views are inited for each request
5459
"""
5560

5661
data_layer_cls = BaseDataLayer
57-
method_dependencies: Dict[HTTPMethod, HTTPMethodConfig] = {}
62+
method_dependencies: ClassVar[Dict[HTTPMethod, HTTPMethodConfig]] = {}
5863

5964
def __init__(self, *, request: Request, jsonapi: RoutersJSONAPI, **options):
6065
self.request: Request = request
@@ -241,12 +246,12 @@ def prepare_data_for_relationship(
241246
def update_related_object(
242247
cls,
243248
relationship_data: Union[Dict[str, str], List[Dict[str, str]]],
244-
included_objects: Dict[Tuple[str, str], TypeSchema],
245249
cache_key: Tuple[str, str],
246250
related_field_name: str,
247251
):
248252
relationships_schema: Type[BaseModel] = relationships_schema_ctx_var.get()
249253
object_schema: Type[JSONAPIObjectSchema] = object_schema_ctx_var.get()
254+
included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get()
250255

251256
relationship_data_schema = get_related_schema(relationships_schema, related_field_name)
252257
parent_included_object = included_objects.get(cache_key)
@@ -257,12 +262,10 @@ def update_related_object(
257262
existing = existing.dict()
258263
new_relationships.update(existing)
259264
new_relationships.update(
260-
{
261-
**{
262-
related_field_name: relationship_data_schema(
263-
data=relationship_data,
264-
),
265-
},
265+
**{
266+
related_field_name: relationship_data_schema(
267+
data=relationship_data,
268+
),
266269
},
267270
)
268271
included_objects[cache_key] = object_schema.parse_obj(
@@ -274,17 +277,19 @@ def update_related_object(
274277
@classmethod
275278
def update_known_included(
276279
cls,
277-
included_objects: Dict[Tuple[str, str], TypeSchema],
278280
new_included: List[TypeSchema],
279281
):
282+
included_objects: Dict[Tuple[str, str], TypeSchema] = included_objects_ctx_var.get()
283+
280284
for included in new_included:
281-
included_objects[(included.id, included.type)] = included
285+
key = (included.id, included.type)
286+
if key not in included_objects:
287+
included_objects[key] = included
282288

283289
@classmethod
284290
def process_single_db_item_and_prepare_includes(
285291
cls,
286292
parent_db_item: TypeModel,
287-
included_objects: Dict[Tuple[str, str], TypeSchema],
288293
):
289294
previous_resource_type: str = previous_resource_type_ctx_var.get()
290295
related_field_name: str = related_field_name_ctx_var.get()
@@ -306,7 +311,6 @@ def process_single_db_item_and_prepare_includes(
306311
)
307312

308313
cls.update_known_included(
309-
included_objects=included_objects,
310314
new_included=new_included,
311315
)
312316
relationship_data_items.append(data_for_relationship)
@@ -318,7 +322,6 @@ def process_single_db_item_and_prepare_includes(
318322

319323
cls.update_related_object(
320324
relationship_data=relationship_data_items,
321-
included_objects=included_objects,
322325
cache_key=cache_key,
323326
related_field_name=related_field_name,
324327
)
@@ -329,14 +332,12 @@ def process_single_db_item_and_prepare_includes(
329332
def process_db_items_and_prepare_includes(
330333
cls,
331334
parent_db_items: List[TypeModel],
332-
included_objects: Dict[Tuple[str, str], TypeSchema],
333335
):
334336
next_current_db_item = []
335337

336338
for parent_db_item in parent_db_items:
337339
new_next_items = cls.process_single_db_item_and_prepare_includes(
338340
parent_db_item=parent_db_item,
339-
included_objects=included_objects,
340341
)
341342
next_current_db_item.extend(new_next_items)
342343
return next_current_db_item
@@ -347,18 +348,21 @@ def process_include_with_nested(
347348
current_db_item: Union[List[TypeModel], TypeModel],
348349
item_as_schema: TypeSchema,
349350
current_relation_schema: Type[TypeSchema],
351+
included_objects: Dict[Tuple[str, str], TypeSchema],
352+
requested_includes: Dict[str, Iterable[str]],
350353
) -> Tuple[Dict[str, TypeSchema], List[JSONAPIObjectSchema]]:
351354
root_item_key = (item_as_schema.id, item_as_schema.type)
352-
included_objects: Dict[Tuple[str, str], TypeSchema] = {
353-
root_item_key: item_as_schema,
354-
}
355+
356+
if root_item_key not in included_objects:
357+
included_objects[root_item_key] = item_as_schema
355358
previous_resource_type = item_as_schema.type
356359

360+
previous_related_field_name = previous_resource_type
357361
for related_field_name in include.split(SPLIT_REL):
358362
object_schemas = self.jsonapi.schema_builder.create_jsonapi_object_schemas(
359363
schema=current_relation_schema,
360-
includes=[related_field_name],
361-
compute_included_schemas=bool([related_field_name]),
364+
includes=requested_includes[previous_related_field_name],
365+
compute_included_schemas=True,
362366
)
363367
relationships_schema = object_schemas.relationships_schema
364368
schemas_include = object_schemas.can_be_included_schemas
@@ -380,16 +384,28 @@ def process_include_with_nested(
380384
related_field_name_ctx_var.set(related_field_name)
381385
relationship_info_ctx_var.set(relationship_info)
382386
included_object_schema_ctx_var.set(included_object_schema)
387+
included_objects_ctx_var.set(included_objects)
383388

384389
current_db_item = self.process_db_items_and_prepare_includes(
385390
parent_db_items=current_db_item,
386-
included_objects=included_objects,
387391
)
388392

389393
previous_resource_type = relationship_info.resource_type
394+
previous_related_field_name = related_field_name
390395

391396
return included_objects.pop(root_item_key), list(included_objects.values())
392397

398+
def prep_requested_includes(self, includes: Iterable[str]):
399+
requested_includes: Dict[str, set[str]] = defaultdict(set)
400+
default: str = self.jsonapi.type_
401+
for include in includes:
402+
prev = default
403+
for related_field_name in include.split(SPLIT_REL):
404+
requested_includes[prev].add(related_field_name)
405+
prev = related_field_name
406+
407+
return requested_includes
408+
393409
def process_db_object(
394410
self,
395411
includes: List[str],
@@ -404,12 +420,17 @@ def process_db_object(
404420
attributes=object_schemas.attributes_schema.from_orm(item),
405421
)
406422

423+
cache_included_objects: Dict[Tuple[str, str], TypeSchema] = {}
424+
requested_includes = self.prep_requested_includes(includes)
425+
407426
for include in includes:
408427
item_as_schema, new_included_objects = self.process_include_with_nested(
409428
include=include,
410429
current_db_item=item,
411430
item_as_schema=item_as_schema,
412431
current_relation_schema=item_schema,
432+
included_objects=cache_included_objects,
433+
requested_includes=requested_includes,
413434
)
414435

415436
included_objects.extend(new_included_objects)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ extend-ignore = [
219219
"RUF001", # String contains ambiguous unicode character {confusable} (did you mean {representant}?)
220220
"RUF002", # Docstring contains ambiguous unicode character {confusable} (did you mean {representant}?)
221221
"RUF003", # Comment contains ambiguous unicode character {confusable} (did you mean {representant}?)
222+
"PT006", # pytest parametrize tuple args
222223
]
223224

224225
[tool.ruff.per-file-ignores]

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections import defaultdict
23
from itertools import chain, zip_longest
34
from json import dumps
45
from typing import Dict, List
@@ -598,6 +599,180 @@ async def test_many_to_many_load_inner_includes_to_parents(
598599
assert ("child", ViewBase.get_db_item_id(child_4)) not in included_data
599600

600601

602+
class TestUserWithPostsWithInnerIncludes:
603+
@mark.parametrize(
604+
"include, expected_relationships_inner_relations, expect_user_include",
605+
[
606+
(
607+
["posts", "posts.user"],
608+
{"post": ["user"], "user": []},
609+
False,
610+
),
611+
(
612+
["posts", "posts.comments"],
613+
{"post": ["comments"], "post_comment": []},
614+
False,
615+
),
616+
(
617+
["posts", "posts.user", "posts.comments"],
618+
{"post": ["user", "comments"], "user": [], "post_comment": []},
619+
False,
620+
),
621+
(
622+
["posts", "posts.user", "posts.comments", "posts.comments.author"],
623+
{"post": ["user", "comments"], "post_comment": ["author"], "user": []},
624+
True,
625+
),
626+
],
627+
)
628+
async def test_get_users_with_posts_and_inner_includes(
629+
self,
630+
app: FastAPI,
631+
client: AsyncClient,
632+
user_1: User,
633+
user_2: User,
634+
user_1_posts: list[PostComment],
635+
user_1_post_for_comments: Post,
636+
user_2_comment_for_one_u1_post: PostComment,
637+
include: list[str],
638+
expected_relationships_inner_relations: dict[str, list[str]],
639+
expect_user_include: bool,
640+
):
641+
"""
642+
Test if requesting `posts.user` and `posts.comments`
643+
returns posts with both `user` and `comments`
644+
"""
645+
assert user_1_posts
646+
assert user_2_comment_for_one_u1_post.author_id == user_2.id
647+
include_param = ",".join(include)
648+
resource_type = "user"
649+
url = app.url_path_for(f"get_{resource_type}_list")
650+
url = f"{url}?filter[name]={user_1.name}&include={include_param}"
651+
response = await client.get(url)
652+
assert response.status_code == status.HTTP_200_OK, response.text
653+
response_json = response.json()
654+
655+
result_data = response_json["data"]
656+
657+
assert result_data == [
658+
{
659+
"id": str(user_1.id),
660+
"type": resource_type,
661+
"attributes": UserAttributesBaseSchema.from_orm(user_1).dict(),
662+
"relationships": {
663+
"posts": {
664+
"data": [
665+
# relationship info
666+
{"id": str(p.id), "type": "post"}
667+
# for every post
668+
for p in user_1_posts
669+
],
670+
},
671+
},
672+
},
673+
]
674+
included_data = response_json["included"]
675+
included_as_map = defaultdict(list)
676+
for item in included_data:
677+
included_as_map[item["type"]].append(item)
678+
679+
for item_type, items in included_as_map.items():
680+
expected_relationships = expected_relationships_inner_relations[item_type]
681+
for item in items:
682+
relationships = set(item.get("relationships", {}))
683+
assert relationships.intersection(expected_relationships) == set(
684+
expected_relationships,
685+
), f"Expected relationships {expected_relationships} not found in {item_type} {item['id']}"
686+
687+
expected_includes = self.prepare_expected_includes(
688+
user_1=user_1,
689+
user_2=user_2,
690+
user_1_posts=user_1_posts,
691+
user_2_comment_for_one_u1_post=user_2_comment_for_one_u1_post,
692+
)
693+
694+
for item_type, includes_names in expected_relationships_inner_relations.items():
695+
items = expected_includes[item_type]
696+
have_to_be_present = set(includes_names)
697+
for item in items: # type: dict
698+
item_relationships = item.get("relationships", {})
699+
for key in tuple(item_relationships.keys()):
700+
if key not in have_to_be_present:
701+
item_relationships.pop(key)
702+
if not item_relationships:
703+
item.pop("relationships", None)
704+
705+
for key in set(expected_includes).difference(expected_relationships_inner_relations):
706+
expected_includes.pop(key)
707+
708+
# XXX
709+
if not expect_user_include:
710+
expected_includes.pop("user", None)
711+
assert included_as_map == expected_includes
712+
713+
def prepare_expected_includes(
714+
self,
715+
user_1: User,
716+
user_2: User,
717+
user_1_posts: list[PostComment],
718+
user_2_comment_for_one_u1_post: PostComment,
719+
):
720+
expected_includes = {
721+
"post": [
722+
#
723+
{
724+
"id": str(p.id),
725+
"type": "post",
726+
"attributes": PostAttributesBaseSchema.from_orm(p).dict(),
727+
"relationships": {
728+
"user": {
729+
"data": {
730+
"id": str(user_1.id),
731+
"type": "user",
732+
},
733+
},
734+
"comments": {
735+
"data": [
736+
{
737+
"id": str(user_2_comment_for_one_u1_post.id),
738+
"type": "post_comment",
739+
},
740+
]
741+
if p.id == user_2_comment_for_one_u1_post.post_id
742+
else [],
743+
},
744+
},
745+
}
746+
#
747+
for p in user_1_posts
748+
],
749+
"post_comment": [
750+
{
751+
"id": str(user_2_comment_for_one_u1_post.id),
752+
"type": "post_comment",
753+
"attributes": PostCommentAttributesBaseSchema.from_orm(user_2_comment_for_one_u1_post).dict(),
754+
"relationships": {
755+
"author": {
756+
"data": {
757+
"id": str(user_2.id),
758+
"type": "user",
759+
},
760+
},
761+
},
762+
},
763+
],
764+
"user": [
765+
{
766+
"id": str(user_2.id),
767+
"type": "user",
768+
"attributes": UserAttributesBaseSchema.from_orm(user_2).dict(),
769+
},
770+
],
771+
}
772+
773+
return expected_includes
774+
775+
601776
async def test_method_not_allowed(app: FastAPI, client: AsyncClient):
602777
url = app.url_path_for("get_user_list")
603778
res = await client.put(url, json={})

0 commit comments

Comments
 (0)