Skip to content

Commit cfe650f

Browse files
NatalyaGrigorevaNatalia Grigoreva
authored andcommitted
fix dict pydantic fields
1 parent 653c827 commit cfe650f

File tree

4 files changed

+64
-20
lines changed

4 files changed

+64
-20
lines changed

fastapi_jsonapi/signature.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
import logging
55
from enum import Enum
66
from inspect import Parameter
7-
from typing import Optional
7+
from typing import Any, Optional, Type, Union, get_args, get_origin
88

99
from fastapi import Query
1010

11+
# noinspection PyProtectedMember
12+
from fastapi._compat import field_annotation_is_scalar, field_annotation_is_sequence
13+
from fastapi.types import UnionType
14+
1115
# noinspection PyProtectedMember
1216
from pydantic.fields import FieldInfo
1317

@@ -17,6 +21,23 @@
1721
log = logging.getLogger(__name__)
1822

1923

24+
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
25+
origin = get_origin(annotation)
26+
if origin is Union or origin is UnionType:
27+
at_least_one_scalar_sequence = False
28+
for arg in get_args(annotation):
29+
if field_annotation_is_scalar_sequence(arg):
30+
at_least_one_scalar_sequence = True
31+
continue
32+
elif not field_annotation_is_scalar(arg):
33+
return False
34+
return at_least_one_scalar_sequence
35+
return (
36+
field_annotation_is_sequence(annotation)
37+
and all(field_annotation_is_scalar(sub_annotation) for sub_annotation in get_args(annotation))
38+
) or field_annotation_is_scalar(annotation)
39+
40+
2041
def create_filter_parameter(
2142
name: str,
2243
field: FieldInfo,
@@ -30,6 +51,9 @@ def create_filter_parameter(
3051
):
3152
default = Query(None, alias=query_filter_name, enum=list(field.annotation))
3253
type_field = str
54+
elif not field_annotation_is_scalar_sequence(field.annotation):
55+
default = Query(None, alias=query_filter_name)
56+
type_field = str
3357
else:
3458
default = Query(None, alias=query_filter_name)
3559
type_field = field.annotation

tests/fixtures/models/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
class Task(Base):
1010
__tablename__ = "tasks"
1111

12-
task_ids: Mapped[Optional[list]] = mapped_column(JSON, unique=False)
12+
task_ids_dict: Mapped[Optional[dict]] = mapped_column(JSON, unique=False)
13+
task_ids_list: Mapped[Optional[list]] = mapped_column(JSON, unique=False)

tests/fixtures/schemas/task.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,22 @@ class TaskBaseSchema(BaseModel):
1212
from_attributes=True,
1313
)
1414

15-
task_ids: Optional[list[str]] = None
15+
task_ids_dict: Optional[dict[str, list]] = None
16+
task_ids_list: Optional[list] = None
1617

1718
# noinspection PyMethodParameters
18-
@field_validator("task_ids", mode="before", check_fields=False)
19+
@field_validator("task_ids_dict", mode="before", check_fields=False)
1920
@classmethod
20-
def task_ids_validator(cls, value: Optional[list[str]]):
21+
def task_ids_dict_validator(cls, value: Optional[dict[str, list]]):
22+
"""
23+
return `{}`, if value is None both on get and on create
24+
"""
25+
return value or {}
26+
27+
# noinspection PyMethodParameters
28+
@field_validator("task_ids_list", mode="before", check_fields=False)
29+
@classmethod
30+
def task_ids_list_validator(cls, value: Optional[list]):
2131
"""
2232
return `[]`, if value is None both on get and on create
2333
"""

tests/test_api/test_validators.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
async def task_with_none_ids(
2525
async_session: AsyncSession,
2626
) -> Task:
27-
task = Task(task_ids=None)
27+
task = Task(
28+
task_ids_dict=None,
29+
task_ids_list=None,
30+
)
2831
async_session.add(task)
2932
await async_session.commit()
3033

@@ -44,7 +47,8 @@ async def test_base_model_validator_pre_true_get_one(
4447
resource_type: str,
4548
task_with_none_ids: Task,
4649
):
47-
assert task_with_none_ids.task_ids is None
50+
assert task_with_none_ids.task_ids_dict is None
51+
assert task_with_none_ids.task_ids_list is None
4852
url = app.url_path_for(f"get_{resource_type}_detail", obj_id=task_with_none_ids.id)
4953
res = await client.get(url)
5054
assert res.status_code == status.HTTP_200_OK, res.text
@@ -59,20 +63,22 @@ async def test_base_model_validator_pre_true_get_one(
5963
"meta": None,
6064
}
6165
assert attributes == {
62-
# not `None`! schema validator returns empty list `[]`
66+
# not `None`! schema validator returns empty dict `{}` and empty list `[]`
6367
# "task_ids": None,
64-
"task_ids": [],
68+
"task_ids_dict": {},
69+
"task_ids_list": [],
6570
}
6671
assert attributes == TaskBaseSchema.model_validate(task_with_none_ids).model_dump()
6772

68-
async def test_base_model_model_validator_get_list(
73+
async def test_base_model_model_validator_get_list_and_dict(
6974
self,
7075
app: FastAPI,
7176
client: AsyncClient,
7277
resource_type: str,
7378
task_with_none_ids: Task,
7479
):
75-
assert task_with_none_ids.task_ids is None
80+
assert task_with_none_ids.task_ids_dict is None
81+
assert task_with_none_ids.task_ids_list is None
7682
url = app.url_path_for(f"get_{resource_type}_list")
7783
res = await client.get(url)
7884
assert res.status_code == status.HTTP_200_OK, res.text
@@ -83,9 +89,10 @@ async def test_base_model_model_validator_get_list(
8389
"id": f"{task_with_none_ids.id}",
8490
"type": resource_type,
8591
"attributes": {
86-
# not `None`! schema validator returns empty list `[]`
92+
# not `None`! schema validator returns empty dict `{}` and empty list `[]`
8793
# "task_ids": None,
88-
"task_ids": [],
94+
"task_ids_dict": {},
95+
"task_ids_list": [],
8996
},
9097
},
9198
],
@@ -109,8 +116,9 @@ async def test_base_model_model_validator_create(
109116
"data": {
110117
"type": resource_type,
111118
"attributes": {
112-
# should be converted to [] by schema on create
113-
"task_ids": None,
119+
# should be converted to [] and {} by schema on create
120+
"task_ids_dict": None,
121+
"task_ids_list": None,
114122
},
115123
},
116124
}
@@ -121,16 +129,17 @@ async def test_base_model_model_validator_create(
121129
task_id = response_data["data"].pop("id")
122130
task = await async_session.get(Task, int(task_id))
123131
assert isinstance(task, Task)
124-
assert task.task_ids == []
125-
# we sent request with `None`, but value in db is `[]`
132+
# we sent request with `None`, but value in db is `[]` and `{}`
126133
# because validator converted data before object creation
127-
assert task.task_ids == []
134+
assert task.task_ids_dict == {}
135+
assert task.task_ids_list == []
128136
assert response_data == {
129137
"data": {
130138
"type": resource_type,
131139
"attributes": {
132-
# should be empty list
133-
"task_ids": [],
140+
# should be empty list and empty dict
141+
"task_ids_dict": {},
142+
"task_ids_list": [],
134143
},
135144
},
136145
"jsonapi": {"version": "1.0"},

0 commit comments

Comments
 (0)