Skip to content

Commit 238cf7c

Browse files
authored
Merge pull request #62 from mts-ai/fix-type-cast-in-filters
Fix type cast in filters
2 parents 65d9520 + e5dfddb commit 238cf7c

File tree

7 files changed

+314
-13
lines changed

7 files changed

+314
-13
lines changed

fastapi_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
11
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2-
from typing import Any, List, Tuple, Type, Union
3-
4-
from pydantic import BaseModel
2+
import logging
3+
from typing import (
4+
Any,
5+
Callable,
6+
Dict,
7+
List,
8+
Optional,
9+
Tuple,
10+
Type,
11+
Union,
12+
)
13+
14+
from pydantic import BaseConfig, BaseModel
515
from pydantic.fields import ModelField
16+
from pydantic.validators import _VALIDATORS, find_validators
617
from sqlalchemy import and_, not_, or_
718
from sqlalchemy.orm import InstrumentedAttribute, aliased
819
from sqlalchemy.sql.elements import BinaryExpression
920

1021
from fastapi_jsonapi.data_layers.shared import create_filters_or_sorts
1122
from fastapi_jsonapi.data_typing import TypeModel, TypeSchema
1223
from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType
24+
from fastapi_jsonapi.exceptions.json_api import HTTPException
1325
from fastapi_jsonapi.schema import get_model_field, get_relationships
1426
from fastapi_jsonapi.splitter import SPLIT_REL
1527
from fastapi_jsonapi.utils.sqla import get_related_model_cls
1628

29+
log = logging.getLogger(__name__)
30+
1731
Filter = BinaryExpression
1832
Join = List[Any]
1933

@@ -22,6 +36,11 @@
2236
List[Join],
2337
]
2438

39+
# The mapping with validators using by to cast raw value to instance of target type
40+
REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS)
41+
42+
cast_failed = object()
43+
2544

2645
def create_filters(model: Type[TypeModel], filter_info: Union[list, dict], schema: Type[TypeSchema]):
2746
"""
@@ -48,6 +67,21 @@ def __init__(self, model: Type[TypeModel], filter_: dict, schema: Type[TypeSchem
4867
self.filter_ = filter_
4968
self.schema = schema
5069

70+
def _cast_value_with_scheme(self, field_types: List[ModelField], value: Any) -> Tuple[Any, List[str]]:
71+
errors: List[str] = []
72+
casted_value = cast_failed
73+
74+
for field_type in field_types:
75+
try:
76+
if isinstance(value, list): # noqa: SIM108
77+
casted_value = [field_type(item) for item in value]
78+
else:
79+
casted_value = field_type(value)
80+
except (TypeError, ValueError) as ex:
81+
errors.append(str(ex))
82+
83+
return casted_value, errors
84+
5185
def create_filter(self, schema_field: ModelField, model_column, operator, value):
5286
"""
5387
Create sqlalchemy filter
@@ -78,19 +112,94 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
78112
types = [i.type_ for i in fields]
79113
clear_value = None
80114
errors: List[str] = []
81-
for i_type in types:
82-
try:
83-
if isinstance(value, list): # noqa: SIM108
84-
clear_value = [i_type(item) for item in value]
85-
else:
86-
clear_value = i_type(value)
87-
except (TypeError, ValueError) as ex:
88-
errors.append(str(ex))
115+
116+
pydantic_types, userspace_types = self._separate_types(types)
117+
118+
if pydantic_types:
119+
if isinstance(value, list):
120+
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
121+
else:
122+
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
123+
124+
if clear_value is None and userspace_types:
125+
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
126+
127+
clear_value, errors = self._cast_value_with_scheme(types, value)
128+
129+
if clear_value is cast_failed:
130+
raise InvalidType(
131+
detail=f"Can't cast filter value `{value}` to arbitrary type.",
132+
errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors],
133+
)
134+
89135
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
90136
if clear_value is None and not any(not i_f.required for i_f in fields):
91137
raise InvalidType(detail=", ".join(errors))
92138
return getattr(model_column, self.operator)(clear_value)
93139

140+
def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
141+
"""
142+
Separates the types into two kinds. The first are those for which
143+
there are already validators defined by pydantic - str, int, datetime
144+
and some other built-in types. The second are all other types for which
145+
the `arbitrary_types_allowed` config is applied when defining the pydantic model
146+
"""
147+
pydantic_types = [
148+
# skip format
149+
type_
150+
for type_ in types
151+
if type_ in REGISTERED_PYDANTIC_TYPES
152+
]
153+
userspace_types = [
154+
# skip format
155+
type_
156+
for type_ in types
157+
if type_ not in REGISTERED_PYDANTIC_TYPES
158+
]
159+
return pydantic_types, userspace_types
160+
161+
def _cast_value_with_pydantic(
162+
self,
163+
types: List[Type],
164+
value: Any,
165+
) -> Tuple[Optional[Any], List[str]]:
166+
result_value, errors = None, []
167+
168+
for type_to_cast in types:
169+
for validator in find_validators(type_to_cast, BaseConfig):
170+
try:
171+
result_value = validator(value)
172+
return result_value, errors
173+
except Exception as ex:
174+
errors.append(str(ex))
175+
176+
return None, errors
177+
178+
def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
179+
type_cast_failed = False
180+
failed_values = []
181+
182+
result_values: List[Any] = []
183+
errors: List[str] = []
184+
185+
for value in values:
186+
casted_value, cast_errors = self._cast_value_with_pydantic(types, value)
187+
errors.extend(cast_errors)
188+
189+
if casted_value is None:
190+
type_cast_failed = True
191+
failed_values.append(value)
192+
193+
continue
194+
195+
result_values.append(casted_value)
196+
197+
if type_cast_failed:
198+
msg = f"Can't parse items {failed_values} of value {values}"
199+
raise InvalidFilters(msg)
200+
201+
return result_values, errors
202+
94203
def resolve(self) -> FilterAndJoins: # noqa: PLR0911
95204
"""Create filter for a particular node of the filter tree"""
96205
if "or" in self.filter_:

tests/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@ def sqla_uri():
88
db_dir = Path(__file__).resolve().parent
99
testing_db_url = f"sqlite+aiosqlite:///{db_dir}/db.sqlite3"
1010
return testing_db_url
11+
12+
13+
def is_postgres_tests() -> bool:
14+
return "postgres" in sqla_uri()

tests/models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import TYPE_CHECKING, Dict, List, Optional
22
from uuid import UUID
33

4-
from sqlalchemy import JSON, Column, ForeignKey, Index, Integer, String, Text
4+
from sqlalchemy import JSON, Column, DateTime, ForeignKey, Index, Integer, String, Text
55
from sqlalchemy.ext.declarative import declarative_base
66
from sqlalchemy.orm import declared_attr, relationship
77
from sqlalchemy.types import CHAR, TypeDecorator
@@ -296,3 +296,8 @@ class SelfRelationship(Base):
296296
)
297297
# parent = relationship("SelfRelationship", back_populates="s")
298298
self_relationship = relationship("SelfRelationship", remote_side=[id])
299+
300+
301+
class ContainsTimestamp(Base):
302+
id = Column(Integer, primary_key=True)
303+
timestamp = Column(DateTime(True), nullable=False)

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import json
12
import logging
23
from collections import defaultdict
4+
from datetime import datetime, timezone
35
from itertools import chain, zip_longest
46
from json import dumps
57
from typing import Dict, List
@@ -8,15 +10,18 @@
810
from fastapi import FastAPI, status
911
from httpx import AsyncClient
1012
from pydantic import BaseModel, Field
11-
from pytest import fixture, mark, param # noqa PT013
13+
from pytest import fixture, mark, param, raises # noqa PT013
14+
from sqlalchemy import select
1215
from sqlalchemy.ext.asyncio import AsyncSession
1316

1417
from fastapi_jsonapi.views.view_base import ViewBase
18+
from tests.common import is_postgres_tests
1519
from tests.fixtures.app import build_app_custom
1620
from tests.fixtures.entities import build_workplace, create_user
1721
from tests.misc.utils import fake
1822
from tests.models import (
1923
Computer,
24+
ContainsTimestamp,
2025
IdCast,
2126
Post,
2227
PostComment,
@@ -1215,6 +1220,100 @@ async def test_create_with_relationship_to_the_same_table(self):
12151220
"meta": None,
12161221
}
12171222

1223+
async def test_create_with_timestamp_and_fetch(self, async_session: AsyncSession):
1224+
resource_type = "contains_timestamp_model"
1225+
1226+
class ContainsTimestampAttrsSchema(BaseModel):
1227+
timestamp: datetime
1228+
1229+
app = build_app_custom(
1230+
model=ContainsTimestamp,
1231+
schema=ContainsTimestampAttrsSchema,
1232+
schema_in_post=ContainsTimestampAttrsSchema,
1233+
schema_in_patch=ContainsTimestampAttrsSchema,
1234+
resource_type=resource_type,
1235+
)
1236+
1237+
create_timestamp = datetime.now(tz=timezone.utc)
1238+
create_body = {
1239+
"data": {
1240+
"attributes": {
1241+
"timestamp": create_timestamp.isoformat(),
1242+
},
1243+
},
1244+
}
1245+
1246+
async with AsyncClient(app=app, base_url="http://test") as client:
1247+
url = app.url_path_for(f"get_{resource_type}_list")
1248+
res = await client.post(url, json=create_body)
1249+
assert res.status_code == status.HTTP_201_CREATED, res.text
1250+
response_json = res.json()
1251+
1252+
assert (entity_id := response_json["data"]["id"])
1253+
assert response_json == {
1254+
"meta": None,
1255+
"jsonapi": {"version": "1.0"},
1256+
"data": {
1257+
"type": resource_type,
1258+
"attributes": {"timestamp": create_timestamp.isoformat()},
1259+
"id": entity_id,
1260+
},
1261+
}
1262+
1263+
stms = select(ContainsTimestamp).where(ContainsTimestamp.id == int(entity_id))
1264+
(await async_session.execute(stms)).scalar_one()
1265+
1266+
expected_response_timestamp = create_timestamp.replace(tzinfo=None).isoformat()
1267+
if is_postgres_tests():
1268+
expected_response_timestamp = create_timestamp.replace().isoformat()
1269+
1270+
params = {
1271+
"filter": json.dumps(
1272+
[
1273+
{
1274+
"name": "timestamp",
1275+
"op": "eq",
1276+
"val": create_timestamp.isoformat(),
1277+
},
1278+
],
1279+
),
1280+
}
1281+
1282+
# successfully filtered
1283+
res = await client.get(url, params=params)
1284+
assert res.status_code == status.HTTP_200_OK, res.text
1285+
assert res.json() == {
1286+
"meta": {"count": 1, "totalPages": 1},
1287+
"jsonapi": {"version": "1.0"},
1288+
"data": [
1289+
{
1290+
"type": resource_type,
1291+
"attributes": {"timestamp": expected_response_timestamp},
1292+
"id": entity_id,
1293+
},
1294+
],
1295+
}
1296+
1297+
# check filter really work
1298+
params = {
1299+
"filter": json.dumps(
1300+
[
1301+
{
1302+
"name": "timestamp",
1303+
"op": "eq",
1304+
"val": datetime.now(tz=timezone.utc).isoformat(),
1305+
},
1306+
],
1307+
),
1308+
}
1309+
res = await client.get(url, params=params)
1310+
assert res.status_code == status.HTTP_200_OK, res.text
1311+
assert res.json() == {
1312+
"meta": {"count": 0, "totalPages": 1},
1313+
"jsonapi": {"version": "1.0"},
1314+
"data": [],
1315+
}
1316+
12181317

12191318
class TestPatchObjects:
12201319
async def test_patch_object(

tests/test_data_layers/__init__.py

Whitespace-only changes.

tests/test_data_layers/test_filtering/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)