Skip to content

Commit 34b90db

Browse files
committed
added tests for userspace type cast
1 parent 6d7fbd1 commit 34b90db

File tree

2 files changed

+88
-5
lines changed

2 files changed

+88
-5
lines changed

fastapi_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2+
import logging
23
from typing import (
34
Any,
45
Callable,
@@ -24,6 +25,8 @@
2425
from fastapi_jsonapi.splitter import SPLIT_REL
2526
from fastapi_jsonapi.utils.sqla import get_related_model_cls
2627

28+
log = logging.getLogger(__name__)
29+
2730
Filter = BinaryExpression
2831
Join = List[Any]
2932

@@ -61,7 +64,7 @@ def __init__(self, model: Type[TypeModel], filter_: dict, schema: Type[TypeSchem
6164
self.filter_ = filter_
6265
self.schema = schema
6366

64-
def create_filter(self, schema_field: ModelField, model_column, operator, value):
67+
def create_filter(self, schema_field: ModelField, model_column, operator, value): # noqa: PLR0912 temporary
6568
"""
6669
Create sqlalchemy filter
6770
:param schema_field:
@@ -101,6 +104,11 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
101104
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
102105

103106
if clear_value is None and userspace_types:
107+
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
108+
109+
cast_failed = object()
110+
clear_value = cast_failed
111+
104112
for i_type in types:
105113
try:
106114
if isinstance(value, list): # noqa: SIM108
@@ -110,6 +118,9 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
110118
except (TypeError, ValueError) as ex:
111119
errors.append(str(ex))
112120

121+
if clear_value is cast_failed:
122+
raise InvalidType(detail=f"Can't cast filter value `{value}` to user type. {', '.join(errors)}")
123+
113124
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
114125
if clear_value is None and not any(not i_f.required for i_f in fields):
115126
raise InvalidType(detail=", ".join(errors))

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
from collections import defaultdict
55
from itertools import chain, zip_longest
66
from json import dumps
7-
from typing import Dict, List, Optional
7+
from typing import Any, Dict, List, Optional
8+
from unittest.mock import Mock
89
from uuid import UUID, uuid4
910

1011
from fastapi import FastAPI, status
1112
from httpx import AsyncClient
1213
from pydantic import BaseModel, Field
13-
from pytest import fixture, mark, param # noqa PT013
14+
from pytest import fixture, mark, param, raises # noqa PT013
1415
from sqlalchemy import select
1516
from sqlalchemy.ext.asyncio import AsyncSession
1617

18+
from fastapi_jsonapi.data_layers.filtering.sqlalchemy import Node
19+
from fastapi_jsonapi.exceptions.json_api import InvalidType
1720
from fastapi_jsonapi.views.view_base import ViewBase
1821
from tests.common import is_postgres_tests
1922
from tests.fixtures.app import build_app_custom
@@ -1235,7 +1238,7 @@ class ContainsTimestampAttrsSchema(BaseModel):
12351238
)
12361239

12371240
create_timestamp = datetime.now(tz=timezone.utc)
1238-
create_user_body = {
1241+
create_body = {
12391242
"data": {
12401243
"attributes": {
12411244
"timestamp": create_timestamp.isoformat(),
@@ -1245,7 +1248,7 @@ class ContainsTimestampAttrsSchema(BaseModel):
12451248

12461249
async with AsyncClient(app=app, base_url="http://test") as client:
12471250
url = app.url_path_for(f"get_{resource_type}_list")
1248-
res = await client.post(url, json=create_user_body)
1251+
res = await client.post(url, json=create_body)
12491252
assert res.status_code == status.HTTP_201_CREATED, res.text
12501253
response_json = res.json()
12511254

@@ -2355,4 +2358,73 @@ async def test_sort(
23552358
}
23562359

23572360

2361+
# TODO: move to it's own test module
2362+
class TestSQLAFilteringModule:
2363+
def test_user_type_cast_success(self):
2364+
class UserType:
2365+
def __init__(self, *args, **kwargs):
2366+
self.value = "success"
2367+
2368+
class ModelSchema(BaseModel):
2369+
user_type: UserType
2370+
2371+
class Config:
2372+
arbitrary_types_allowed = True
2373+
2374+
node = Node(
2375+
model=Mock(),
2376+
filter_={
2377+
"name": "user_type",
2378+
"op": "eq",
2379+
"val": Any,
2380+
},
2381+
schema=ModelSchema,
2382+
)
2383+
2384+
model_column_mock = Mock()
2385+
model_column_mock.eq = lambda clear_value: clear_value
2386+
2387+
clear_value = node.create_filter(
2388+
schema_field=ModelSchema.__fields__["user_type"],
2389+
model_column=model_column_mock,
2390+
operator=Mock(),
2391+
value=Any,
2392+
)
2393+
assert isinstance(clear_value, UserType)
2394+
assert clear_value.value == "success"
2395+
2396+
def test_user_type_cast_fail(self):
2397+
class UserType:
2398+
def __init__(self, *args, **kwargs):
2399+
msg = "Cast failed"
2400+
raise ValueError(msg)
2401+
2402+
class ModelSchema(BaseModel):
2403+
user_type: UserType
2404+
2405+
class Config:
2406+
arbitrary_types_allowed = True
2407+
2408+
node = Node(
2409+
model=Mock(),
2410+
filter_=Mock(),
2411+
schema=ModelSchema,
2412+
)
2413+
2414+
with raises(InvalidType) as exc_info:
2415+
node.create_filter(
2416+
schema_field=ModelSchema.__fields__["user_type"],
2417+
model_column=Mock(),
2418+
operator=Mock(),
2419+
value=Any,
2420+
)
2421+
2422+
assert exc_info.value.as_dict == {
2423+
"detail": "Can't cast filter value `typing.Any` to user type. Cast failed",
2424+
"source": {"pointer": ""},
2425+
"status_code": status.HTTP_409_CONFLICT,
2426+
"title": "Invalid type.",
2427+
}
2428+
2429+
23582430
# todo: test errors

0 commit comments

Comments
 (0)