@@ -601,27 +601,27 @@ async def test_many_to_many_load_inner_includes_to_parents(
601601
602602class TestUserWithPostsWithInnerIncludes :
603603 @mark .parametrize (
604- "include, expected_relationships_post, case_name " ,
604+ "include, expected_relationships_inner_relations, expect_user_include " ,
605605 [
606606 (
607607 ["posts" , "posts.user" ],
608- ["user" ],
609- "" ,
608+ { "post" : ["user" ], "user" : []} ,
609+ False ,
610610 ),
611611 (
612612 ["posts" , "posts.comments" ],
613- ["comments" ],
614- "" ,
613+ { "post" : ["comments" ], "post_comment" : []} ,
614+ False ,
615615 ),
616616 (
617617 ["posts" , "posts.user" , "posts.comments" ],
618- ["user" , "comments" ],
619- "case_1" ,
618+ { "post" : ["user" , "comments" ], "user" : [], "post_comment" : []} ,
619+ False ,
620620 ),
621621 (
622622 ["posts" , "posts.user" , "posts.comments" , "posts.comments.author" ],
623- ["user" , "comments" ],
624- "case_2" ,
623+ { "post" : ["user" , "comments" ], "post_comment" : [ "author" ], "user" : []} ,
624+ True ,
625625 ),
626626 ],
627627 )
@@ -635,8 +635,8 @@ async def test_get_users_with_posts_and_inner_includes(
635635 user_1_post_for_comments : Post ,
636636 user_2_comment_for_one_u1_post : PostComment ,
637637 include : list [str ],
638- expected_relationships_post : list [str ],
639- case_name : bool ,
638+ expected_relationships_inner_relations : dict [ str , list [str ] ],
639+ expect_user_include : bool ,
640640 ):
641641 """
642642 Test if requesting `posts.user` and `posts.comments`
@@ -672,45 +672,51 @@ async def test_get_users_with_posts_and_inner_includes(
672672 },
673673 ]
674674 included_data = response_json ["included" ]
675+ included_as_map = defaultdict (list )
676+ for item in included_data :
677+ included_as_map [item ["type" ]].append (item )
675678
676- included_posts = [item for item in included_data if item ["type" ] == "post" ]
677- for post in included_posts :
678- post_relationships = set (post .get ("relationships" , {}))
679- assert post_relationships .intersection (expected_relationships_post ) == set (
680- expected_relationships_post ,
681- ), f"Expected relationships { expected_relationships_post } not found in post { post ['id' ]} "
682-
683- if not case_name :
684- return
685- included_as_map , expected_includes = self .prepare_expected_includes (
686- included = included_data ,
679+ for item_type , items in included_as_map .items ():
680+ expected_relationships = expected_relationships_inner_relations [item_type ]
681+ for item in items :
682+ relationships = set (item .get ("relationships" , {}))
683+ assert relationships .intersection (expected_relationships ) == set (
684+ expected_relationships ,
685+ ), f"Expected relationships { expected_relationships } not found in { item_type } { item ['id' ]} "
686+
687+ expected_includes = self .prepare_expected_includes (
687688 user_1 = user_1 ,
688689 user_2 = user_2 ,
689690 user_1_posts = user_1_posts ,
690691 user_2_comment_for_one_u1_post = user_2_comment_for_one_u1_post ,
691692 )
692693
693- if case_name == "case_2" :
694- assert "user" in expected_includes
695- elif case_name == "case_1" :
694+ for item_type , includes_names in expected_relationships_inner_relations .items ():
695+ items = expected_includes [item_type ]
696+ have_to_be_present = set (includes_names )
697+ for item in items : # type: dict
698+ item_relationships = item .get ("relationships" , {})
699+ for key in tuple (item_relationships .keys ()):
700+ if key not in have_to_be_present :
701+ item_relationships .pop (key )
702+ if not item_relationships :
703+ item .pop ("relationships" , None )
704+
705+ for key in set (expected_includes ).difference (expected_relationships_inner_relations ):
706+ expected_includes .pop (key )
707+
708+ # XXX
709+ if not expect_user_include :
696710 expected_includes .pop ("user" , None )
697- for pc in expected_includes ["post_comment" ]:
698- pc .pop ("relationships" , None )
699-
700711 assert included_as_map == expected_includes
701712
702713 def prepare_expected_includes (
703714 self ,
704- included : list [dict ],
705715 user_1 : User ,
706716 user_2 : User ,
707717 user_1_posts : list [PostComment ],
708718 user_2_comment_for_one_u1_post : PostComment ,
709719 ):
710- included_as_map = defaultdict (list )
711- for item in included :
712- included_as_map [item ["type" ]].append (item )
713-
714720 expected_includes = {
715721 "post" : [
716722 #
@@ -764,7 +770,7 @@ def prepare_expected_includes(
764770 ],
765771 }
766772
767- return included_as_map , expected_includes
773+ return expected_includes
768774
769775
770776async def test_method_not_allowed (app : FastAPI , client : AsyncClient ):
0 commit comments