Skip to content

Commit 56a4a58

Browse files
committed
refactor tests
1 parent ff859d1 commit 56a4a58

File tree

1 file changed

+40
-34
lines changed

1 file changed

+40
-34
lines changed

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -601,27 +601,27 @@ async def test_many_to_many_load_inner_includes_to_parents(
601601

602602
class TestUserWithPostsWithInnerIncludes:
603603
@mark.parametrize(
604-
"include, expected_relationships_post, case_name",
604+
"include, expected_relationships_inner_relations, expect_user_include",
605605
[
606606
(
607607
["posts", "posts.user"],
608-
["user"],
609-
"",
608+
{"post": ["user"], "user": []},
609+
False,
610610
),
611611
(
612612
["posts", "posts.comments"],
613-
["comments"],
614-
"",
613+
{"post": ["comments"], "post_comment": []},
614+
False,
615615
),
616616
(
617617
["posts", "posts.user", "posts.comments"],
618-
["user", "comments"],
619-
"case_1",
618+
{"post": ["user", "comments"], "user": [], "post_comment": []},
619+
False,
620620
),
621621
(
622622
["posts", "posts.user", "posts.comments", "posts.comments.author"],
623-
["user", "comments"],
624-
"case_2",
623+
{"post": ["user", "comments"], "post_comment": ["author"], "user": []},
624+
True,
625625
),
626626
],
627627
)
@@ -635,8 +635,8 @@ async def test_get_users_with_posts_and_inner_includes(
635635
user_1_post_for_comments: Post,
636636
user_2_comment_for_one_u1_post: PostComment,
637637
include: list[str],
638-
expected_relationships_post: list[str],
639-
case_name: bool,
638+
expected_relationships_inner_relations: dict[str, list[str]],
639+
expect_user_include: bool,
640640
):
641641
"""
642642
Test if requesting `posts.user` and `posts.comments`
@@ -672,45 +672,51 @@ async def test_get_users_with_posts_and_inner_includes(
672672
},
673673
]
674674
included_data = response_json["included"]
675+
included_as_map = defaultdict(list)
676+
for item in included_data:
677+
included_as_map[item["type"]].append(item)
675678

676-
included_posts = [item for item in included_data if item["type"] == "post"]
677-
for post in included_posts:
678-
post_relationships = set(post.get("relationships", {}))
679-
assert post_relationships.intersection(expected_relationships_post) == set(
680-
expected_relationships_post,
681-
), f"Expected relationships {expected_relationships_post} not found in post {post['id']}"
682-
683-
if not case_name:
684-
return
685-
included_as_map, expected_includes = self.prepare_expected_includes(
686-
included=included_data,
679+
for item_type, items in included_as_map.items():
680+
expected_relationships = expected_relationships_inner_relations[item_type]
681+
for item in items:
682+
relationships = set(item.get("relationships", {}))
683+
assert relationships.intersection(expected_relationships) == set(
684+
expected_relationships,
685+
), f"Expected relationships {expected_relationships} not found in {item_type} {item['id']}"
686+
687+
expected_includes = self.prepare_expected_includes(
687688
user_1=user_1,
688689
user_2=user_2,
689690
user_1_posts=user_1_posts,
690691
user_2_comment_for_one_u1_post=user_2_comment_for_one_u1_post,
691692
)
692693

693-
if case_name == "case_2":
694-
assert "user" in expected_includes
695-
elif case_name == "case_1":
694+
for item_type, includes_names in expected_relationships_inner_relations.items():
695+
items = expected_includes[item_type]
696+
have_to_be_present = set(includes_names)
697+
for item in items: # type: dict
698+
item_relationships = item.get("relationships", {})
699+
for key in tuple(item_relationships.keys()):
700+
if key not in have_to_be_present:
701+
item_relationships.pop(key)
702+
if not item_relationships:
703+
item.pop("relationships", None)
704+
705+
for key in set(expected_includes).difference(expected_relationships_inner_relations):
706+
expected_includes.pop(key)
707+
708+
# XXX
709+
if not expect_user_include:
696710
expected_includes.pop("user", None)
697-
for pc in expected_includes["post_comment"]:
698-
pc.pop("relationships", None)
699-
700711
assert included_as_map == expected_includes
701712

702713
def prepare_expected_includes(
703714
self,
704-
included: list[dict],
705715
user_1: User,
706716
user_2: User,
707717
user_1_posts: list[PostComment],
708718
user_2_comment_for_one_u1_post: PostComment,
709719
):
710-
included_as_map = defaultdict(list)
711-
for item in included:
712-
included_as_map[item["type"]].append(item)
713-
714720
expected_includes = {
715721
"post": [
716722
#
@@ -764,7 +770,7 @@ def prepare_expected_includes(
764770
],
765771
}
766772

767-
return included_as_map, expected_includes
773+
return expected_includes
768774

769775

770776
async def test_method_not_allowed(app: FastAPI, client: AsyncClient):

0 commit comments

Comments
 (0)