Skip to content

Commit 80df7fa

Browse files
committed
added test case and base logic
1 parent 7015aa1 commit 80df7fa

File tree

6 files changed

+131
-26
lines changed

6 files changed

+131
-26
lines changed

fastapi_jsonapi/querystring.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Helper to deal with querystring parameters according to jsonapi specification."""
2+
from collections import defaultdict
23
from functools import cached_property
34
from typing import (
45
TYPE_CHECKING,
@@ -22,17 +23,16 @@
2223
)
2324
from starlette.datastructures import QueryParams
2425

26+
from fastapi_jsonapi.api import RoutersJSONAPI
2527
from fastapi_jsonapi.exceptions import (
2628
BadRequest,
27-
InvalidField,
2829
InvalidFilters,
2930
InvalidInclude,
3031
InvalidSort,
3132
)
3233
from fastapi_jsonapi.schema import (
3334
get_model_field,
3435
get_relationships,
35-
get_schema_from_type,
3636
)
3737
from fastapi_jsonapi.splitter import SPLIT_REL
3838

@@ -97,7 +97,7 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]:
9797
:return: a dict of key / values items
9898
:raises BadRequest: if an error occurred while parsing the querystring.
9999
"""
100-
results: Dict[str, Union[List[str], str]] = {}
100+
results = defaultdict(set)
101101

102102
for raw_key, value in self.qs.multi_items():
103103
key = unquote(raw_key)
@@ -109,10 +109,7 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]:
109109
key_end = key.index("]")
110110
item_key = key[key_start:key_end]
111111

112-
if "," in value:
113-
results.update({item_key: value.split(",")})
114-
else:
115-
results.update({item_key: value})
112+
results[item_key].update(value.split(","))
116113
except Exception:
117114
msg = "Parse error"
118115
raise BadRequest(msg, parameter=key)
@@ -216,27 +213,28 @@ def fields(self) -> Dict[str, List[str]]:
216213
217214
:raises InvalidField: if result field not in schema.
218215
"""
219-
if self.request.method != "GET":
220-
msg = "attribute 'fields' allowed only for GET-method"
221-
raise InvalidField(msg)
222216
fields = self._get_key_values("fields")
223-
for key, value in fields.items():
224-
if not isinstance(value, list):
225-
value = [value] # noqa: PLW2901
226-
fields[key] = value
217+
for resource_type, field_names in fields.items():
227218
# TODO: we have registry for models (BaseModel)
228219
# TODO: create `type to schemas` registry
229-
schema: Type[BaseModel] = get_schema_from_type(key, self.app)
230-
for field in value:
231-
if field not in schema.__fields__:
232-
msg = "{schema} has no attribute {field}".format(
233-
schema=schema.__name__,
234-
field=field,
235-
)
236-
raise InvalidField(msg)
220+
221+
# schema: Type[BaseModel] = get_schema_from_type(key, self.app)
222+
self._get_schema(resource_type)
223+
224+
# for field_name in field_names:
225+
# if field_name not in schema.__fields__:
226+
# msg = "{schema} has no attribute {field}".format(
227+
# schema=schema.__name__,
228+
# field=field_name,
229+
# )
230+
# raise InvalidField(msg)
237231

238232
return fields
239233

234+
def _get_schema(self, resource_type: str) -> Type[BaseModel]:
235+
target_router = RoutersJSONAPI.all_jsonapi_routers[resource_type]
236+
return target_router.detail_response_schema
237+
240238
def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]:
241239
"""
242240
Return fields to sort by including sort name for SQLAlchemy and row sort parameter for other ORMs.

fastapi_jsonapi/schema_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def create_jsonapi_object_schemas(
484484
base_name: str = "",
485485
compute_included_schemas: bool = False,
486486
use_schema_cache: bool = True,
487+
exclude_attributes: Optional[List[str]] = None,
487488
) -> JSONAPIObjectSchemas:
488489
if use_schema_cache and schema in self.object_schemas_cache and includes is not_passed:
489490
return self.object_schemas_cache[schema]

fastapi_jsonapi/views/list_view.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
JSONAPIResultDetailSchema,
77
JSONAPIResultListSchema,
88
)
9+
from fastapi_jsonapi.views.utils import get_includes_indexes_by_type
910
from fastapi_jsonapi.views.view_base import ViewBase
1011

1112
if TYPE_CHECKING:
@@ -14,6 +15,31 @@
1415
logger = logging.getLogger(__name__)
1516

1617

18+
def calculate_include_fields(response, query_params, jsonapi) -> Dict:
19+
included = "included" in response.__fields__ and response.included or []
20+
21+
include_params = {
22+
field_name: {*response.__fields__[field_name].type_.__fields__.keys()}
23+
for field_name in response.__fields__
24+
if field_name
25+
}
26+
include_params["included"] = {}
27+
28+
includes_indexes_by_type = get_includes_indexes_by_type(included)
29+
30+
for resource_type, field_names in query_params.fields.items():
31+
if resource_type == jsonapi.type_:
32+
include_params["data"] = {"__all__": {"attributes": field_names, "id": {"id"}, "type": {"type"}}}
33+
continue
34+
35+
target_type_indexes = includes_indexes_by_type.get(resource_type)
36+
37+
if resource_type in includes_indexes_by_type and target_type_indexes:
38+
include_params["included"].update((idx, field_names) for idx in target_type_indexes)
39+
40+
return include_params
41+
42+
1743
class ListViewBase(ViewBase):
1844
def _calculate_total_pages(self, db_items_count: int) -> int:
1945
total_pages = 1
@@ -40,7 +66,17 @@ async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultList
4066
count, items_from_db = await dl.get_collection(qs=query_params)
4167
total_pages = self._calculate_total_pages(count)
4268

43-
return self._build_list_response(items_from_db, count, total_pages)
69+
response = self._build_list_response(items_from_db, count, total_pages)
70+
71+
if not query_params.fields:
72+
return response
73+
74+
include_params = calculate_include_fields(response, query_params, self.jsonapi)
75+
76+
if include_params:
77+
return response.dict(include=include_params)
78+
79+
return response
4480

4581
async def handle_post_resource_list(
4682
self,

fastapi_jsonapi/views/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections import defaultdict
12
from enum import Enum
23
from functools import cache
3-
from typing import Callable, Coroutine, Optional, Set, Type, Union
4+
from typing import Callable, Coroutine, Dict, List, Optional, Set, Type, Union
45

56
from pydantic import BaseModel
67

@@ -27,3 +28,12 @@ class Config:
2728
@property
2829
def handler(self) -> Optional[Union[Callable, Coroutine]]:
2930
return self.prepare_data_layer_kwargs
31+
32+
33+
def get_includes_indexes_by_type(included: List[Dict]) -> Dict[str, List[int]]:
34+
result = defaultdict(list)
35+
36+
for idx, item in enumerate(included, 1):
37+
result[item["type"]].append(idx)
38+
39+
return result

fastapi_jsonapi/views/view_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from fastapi_jsonapi.schema import (
3131
JSONAPIObjectSchema,
3232
JSONAPIResultListMetaSchema,
33+
JSONAPIResultListSchema,
3334
get_related_schema,
3435
)
3536
from fastapi_jsonapi.schema_base import BaseModel, RelationshipInfo
@@ -185,7 +186,12 @@ def _build_detail_response(self, db_item: TypeModel):
185186

186187
return detail_jsonapi_schema(data=result_object, **extras)
187188

188-
def _build_list_response(self, items_from_db: List[TypeModel], count: int, total_pages: int):
189+
def _build_list_response(
190+
self,
191+
items_from_db: List[TypeModel],
192+
count: int,
193+
total_pages: int,
194+
) -> JSONAPIResultListSchema:
189195
result_objects, object_schemas, extras = self._build_response(items_from_db, self.jsonapi.schema_list)
190196

191197
# we need to build a new schema here

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime, timezone
55
from itertools import chain, zip_longest
66
from json import dumps, loads
7-
from typing import Dict, List, Literal
7+
from typing import Dict, List, Literal, Set, Tuple
88
from uuid import UUID, uuid4
99

1010
import pytest
@@ -16,6 +16,7 @@
1616
from sqlalchemy import func, select
1717
from sqlalchemy.ext.asyncio import AsyncSession
1818
from sqlalchemy.orm import InstrumentedAttribute
19+
from starlette.datastructures import QueryParams
1920

2021
from fastapi_jsonapi.views.view_base import ViewBase
2122
from tests.common import is_postgres_tests
@@ -151,6 +152,59 @@ async def test_get_users_paginated(
151152
"meta": {"count": 2, "totalPages": 2},
152153
}
153154

155+
@mark.parametrize(
156+
"fields, expected_include",
157+
[
158+
param(
159+
[
160+
("fields[user]", "name,age"),
161+
],
162+
{"name", "age"},
163+
),
164+
param(
165+
[
166+
("fields[user]", "name,age"),
167+
("fields[user]", "email"),
168+
],
169+
{"name", "age", "email"},
170+
),
171+
],
172+
)
173+
async def test_select_custom_fields(
174+
self,
175+
app: FastAPI,
176+
client: AsyncClient,
177+
user_1: User,
178+
user_2: User,
179+
fields: List[Tuple[str, str]],
180+
expected_include: Set[str],
181+
):
182+
url = app.url_path_for("get_user_list")
183+
user_1, user_2 = sorted((user_1, user_2), key=lambda x: x.id)
184+
185+
params = QueryParams(fields)
186+
response = await client.get(url, params=str(params))
187+
188+
assert response.status_code == status.HTTP_200_OK, response.text
189+
response_data = response.json()
190+
191+
assert response_data == {
192+
"data": [
193+
{
194+
"attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include=expected_include),
195+
"id": str(user_1.id),
196+
"type": "user",
197+
},
198+
{
199+
"attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include=expected_include),
200+
"id": str(user_2.id),
201+
"type": "user",
202+
},
203+
],
204+
"jsonapi": {"version": "1.0"},
205+
"meta": {"count": 2, "total_pages": 1},
206+
}
207+
154208

155209
class TestCreatePostAndComments:
156210
async def test_get_posts_with_users(

0 commit comments

Comments
 (0)