diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index e531999a..c2a47956 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -601,27 +601,27 @@ async def test_many_to_many_load_inner_includes_to_parents( class TestUserWithPostsWithInnerIncludes: @mark.parametrize( - "include, expected_relationships_post, case_name", + "include, expected_relationships_inner_relations, expect_user_include", [ ( ["posts", "posts.user"], - ["user"], - "", + {"post": ["user"], "user": []}, + False, ), ( ["posts", "posts.comments"], - ["comments"], - "", + {"post": ["comments"], "post_comment": []}, + False, ), ( ["posts", "posts.user", "posts.comments"], - ["user", "comments"], - "case_1", + {"post": ["user", "comments"], "user": [], "post_comment": []}, + False, ), ( ["posts", "posts.user", "posts.comments", "posts.comments.author"], - ["user", "comments"], - "case_2", + {"post": ["user", "comments"], "post_comment": ["author"], "user": []}, + True, ), ], ) @@ -635,8 +635,8 @@ async def test_get_users_with_posts_and_inner_includes( user_1_post_for_comments: Post, user_2_comment_for_one_u1_post: PostComment, include: list[str], - expected_relationships_post: list[str], - case_name: bool, + expected_relationships_inner_relations: dict[str, list[str]], + expect_user_include: bool, ): """ Test if requesting `posts.user` and `posts.comments` @@ -672,45 +672,51 @@ async def test_get_users_with_posts_and_inner_includes( }, ] included_data = response_json["included"] + included_as_map = defaultdict(list) + for item in included_data: + included_as_map[item["type"]].append(item) - included_posts = [item for item in included_data if item["type"] == "post"] - for post in included_posts: - post_relationships = set(post.get("relationships", {})) - assert post_relationships.intersection(expected_relationships_post) == set( - expected_relationships_post, - ), f"Expected relationships {expected_relationships_post} not found in post {post['id']}" - - if not case_name: - return - included_as_map, expected_includes = self.prepare_expected_includes( - included=included_data, + for item_type, items in included_as_map.items(): + expected_relationships = expected_relationships_inner_relations[item_type] + for item in items: + relationships = set(item.get("relationships", {})) + assert relationships.intersection(expected_relationships) == set( + expected_relationships, + ), f"Expected relationships {expected_relationships} not found in {item_type} {item['id']}" + + expected_includes = self.prepare_expected_includes( user_1=user_1, user_2=user_2, user_1_posts=user_1_posts, user_2_comment_for_one_u1_post=user_2_comment_for_one_u1_post, ) - if case_name == "case_2": - assert "user" in expected_includes - elif case_name == "case_1": + for item_type, includes_names in expected_relationships_inner_relations.items(): + items = expected_includes[item_type] + have_to_be_present = set(includes_names) + for item in items: # type: dict + item_relationships = item.get("relationships", {}) + for key in tuple(item_relationships.keys()): + if key not in have_to_be_present: + item_relationships.pop(key) + if not item_relationships: + item.pop("relationships", None) + + for key in set(expected_includes).difference(expected_relationships_inner_relations): + expected_includes.pop(key) + + # XXX + if not expect_user_include: expected_includes.pop("user", None) - for pc in expected_includes["post_comment"]: - pc.pop("relationships", None) - assert included_as_map == expected_includes def prepare_expected_includes( self, - included: list[dict], user_1: User, user_2: User, user_1_posts: list[PostComment], user_2_comment_for_one_u1_post: PostComment, ): - included_as_map = defaultdict(list) - for item in included: - included_as_map[item["type"]].append(item) - expected_includes = { "post": [ # @@ -764,7 +770,7 @@ def prepare_expected_includes( ], } - return included_as_map, expected_includes + return expected_includes async def test_method_not_allowed(app: FastAPI, client: AsyncClient):