Skip to content

Commit 3fae8f0

Browse files
committed
rebase
1 parent 5e25556 commit 3fae8f0

File tree

2 files changed

+198
-46
lines changed

2 files changed

+198
-46
lines changed

fastapi_jsonapi/data_layers/filtering/sqlalchemy.py

Lines changed: 180 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2+
import inspect
3+
import logging
24
from typing import (
35
Any,
46
Callable,
57
Dict,
68
List,
9+
Optional,
710
Set,
11+
Tuple,
812
Type,
913
Union,
1014
)
1115

12-
from pydantic import BaseModel
16+
from pydantic import BaseConfig, BaseModel
1317
from pydantic.fields import ModelField
18+
from pydantic.validators import _VALIDATORS, find_validators
1419
from sqlalchemy import and_, not_, or_
1520
from sqlalchemy.orm import aliased
1621
from sqlalchemy.orm.attributes import InstrumentedAttribute
@@ -19,14 +24,22 @@
1924

2025
from fastapi_jsonapi.data_typing import TypeModel, TypeSchema
2126
from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType
27+
from fastapi_jsonapi.exceptions.json_api import HTTPException
2228
from fastapi_jsonapi.schema import get_model_field, get_relationships
2329

30+
log = logging.getLogger(__name__)
31+
2432
RELATIONSHIP_SPLITTER = "."
2533

34+
# The mapping with validators using by to cast raw value to instance of target type
35+
REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS)
36+
37+
cast_failed = object()
38+
2639
RelationshipPath = str
2740

2841

29-
class RelationshipInfo(BaseModel):
42+
class RelationshipFilteringInfo(BaseModel):
3043
target_schema: Type[TypeSchema]
3144
model: Type[TypeModel]
3245
aliased_model: AliasedClass
@@ -36,6 +49,129 @@ class Config:
3649
arbitrary_types_allowed = True
3750

3851

52+
def check_can_be_none(fields: list[ModelField]) -> bool:
53+
"""
54+
Return True if None is possible value for target field
55+
"""
56+
return any(field_item.allow_none for field_item in fields)
57+
58+
59+
def separate_types(types: List[Type]) -> Tuple[List[Type], List[Type]]:
60+
"""
61+
Separates the types into two kinds.
62+
63+
The first are those for which there are already validators
64+
defined by pydantic - str, int, datetime and some other built-in types.
65+
The second are all other types for which the `arbitrary_types_allowed`
66+
config is applied when defining the pydantic model
67+
"""
68+
pydantic_types = [
69+
# skip format
70+
type_
71+
for type_ in types
72+
if type_ in REGISTERED_PYDANTIC_TYPES
73+
]
74+
userspace_types = [
75+
# skip format
76+
type_
77+
for type_ in types
78+
if type_ not in REGISTERED_PYDANTIC_TYPES
79+
]
80+
return pydantic_types, userspace_types
81+
82+
83+
def validator_requires_model_field(validator: Callable) -> bool:
84+
"""
85+
Check if validator accepts the `field` param
86+
87+
:param validator:
88+
:return:
89+
"""
90+
signature = inspect.signature(validator)
91+
parameters = signature.parameters
92+
93+
if "field" not in parameters:
94+
return False
95+
96+
field_param = parameters["field"]
97+
field_type = field_param.annotation
98+
99+
return field_type == "ModelField" or field_type is ModelField
100+
101+
102+
def cast_value_with_pydantic(
103+
types: List[Type],
104+
value: Any,
105+
schema_field: ModelField,
106+
) -> Tuple[Optional[Any], List[str]]:
107+
result_value, errors = None, []
108+
109+
for type_to_cast in types:
110+
for validator in find_validators(type_to_cast, BaseConfig):
111+
args = [value]
112+
# TODO: some other way to get all the validator's dependencies?
113+
if validator_requires_model_field(validator):
114+
args.append(schema_field)
115+
try:
116+
result_value = validator(*args)
117+
except Exception as ex:
118+
errors.append(str(ex))
119+
else:
120+
return result_value, errors
121+
122+
return None, errors
123+
124+
125+
def cast_iterable_with_pydantic(
126+
types: List[Type],
127+
values: List,
128+
schema_field: ModelField,
129+
) -> Tuple[List, List[str]]:
130+
type_cast_failed = False
131+
failed_values = []
132+
133+
result_values: List[Any] = []
134+
errors: List[str] = []
135+
136+
for value in values:
137+
casted_value, cast_errors = cast_value_with_pydantic(
138+
types,
139+
value,
140+
schema_field,
141+
)
142+
errors.extend(cast_errors)
143+
144+
if casted_value is None:
145+
type_cast_failed = True
146+
failed_values.append(value)
147+
148+
continue
149+
150+
result_values.append(casted_value)
151+
152+
if type_cast_failed:
153+
msg = f"Can't parse items {failed_values} of value {values}"
154+
raise InvalidFilters(msg, pointer=schema_field.name)
155+
156+
return result_values, errors
157+
158+
159+
def cast_value_with_scheme(field_types: List[Type], value: Any) -> Tuple[Any, List[str]]:
160+
errors: List[str] = []
161+
casted_value = cast_failed
162+
163+
for field_type in field_types:
164+
try:
165+
if isinstance(value, list): # noqa: SIM108
166+
casted_value = [field_type(item) for item in value]
167+
else:
168+
casted_value = field_type(value)
169+
except (TypeError, ValueError) as ex:
170+
errors.append(str(ex))
171+
172+
return casted_value, errors
173+
174+
39175
def build_filter_expression(
40176
schema_field: ModelField,
41177
model_column: InstrumentedAttribute,
@@ -61,26 +197,51 @@ def build_filter_expression(
61197
if schema_field.sub_fields:
62198
fields = list(schema_field.sub_fields)
63199

200+
can_be_none = check_can_be_none(fields)
201+
202+
if value is None:
203+
if can_be_none:
204+
return getattr(model_column, operator)(value)
205+
206+
raise InvalidFilters(detail=f"The field `{schema_field.name}` can't be null")
207+
208+
types = [i.type_ for i in fields]
64209
casted_value = None
65210
errors: List[str] = []
66211

67-
for cast_type in [field.type_ for field in fields]:
68-
try:
69-
casted_value = [cast_type(item) for item in value] if isinstance(value, list) else cast_type(value)
70-
except (TypeError, ValueError) as ex:
71-
errors.append(str(ex))
212+
pydantic_types, userspace_types = separate_types(types)
213+
214+
if pydantic_types:
215+
func = cast_value_with_pydantic
216+
if isinstance(value, list):
217+
func = cast_iterable_with_pydantic
218+
casted_value, errors = func(pydantic_types, value, schema_field)
72219

73-
all_fields_required = all(field.required for field in fields)
220+
if casted_value is None and userspace_types:
221+
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
74222

75-
if casted_value is None and all_fields_required:
76-
raise InvalidType(detail=", ".join(errors))
223+
casted_value, errors = cast_value_with_scheme(types, value)
224+
225+
if casted_value is cast_failed:
226+
raise InvalidType(
227+
detail=f"Can't cast filter value `{value}` to arbitrary type.",
228+
errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors],
229+
)
230+
231+
# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
232+
if casted_value is None and not can_be_none:
233+
raise InvalidType(
234+
detail=", ".join(errors),
235+
pointer=schema_field.name,
236+
)
77237

78238
return getattr(model_column, operator)(casted_value)
79239

80240

81241
def is_terminal_node(filter_item: dict) -> bool:
82242
"""
83243
If node shape is:
244+
84245
{
85246
"name: ...,
86247
"op: ...,
@@ -166,7 +327,7 @@ def gather_relationships_info(
166327
relationship_path: List[str],
167328
collected_info: dict,
168329
target_relationship_idx: int = 0,
169-
) -> dict[RelationshipPath, RelationshipInfo]:
330+
) -> dict[RelationshipPath, RelationshipFilteringInfo]:
170331
is_last_relationship = target_relationship_idx == len(relationship_path) - 1
171332
target_relationship_path = RELATIONSHIP_SPLITTER.join(
172333
relationship_path[: target_relationship_idx + 1],
@@ -184,7 +345,7 @@ def gather_relationships_info(
184345
schema,
185346
target_relationship_name,
186347
)
187-
collected_info[target_relationship_path] = RelationshipInfo(
348+
collected_info[target_relationship_path] = RelationshipFilteringInfo(
188349
target_schema=target_schema,
189350
model=target_model,
190351
aliased_model=aliased(target_model),
@@ -207,7 +368,7 @@ def gather_relationships(
207368
entrypoint_model: Type[TypeModel],
208369
schema: Type[TypeSchema],
209370
relationship_paths: Set[str],
210-
) -> dict[RelationshipPath, RelationshipInfo]:
371+
) -> dict[RelationshipPath, RelationshipFilteringInfo]:
211372
collected_info = {}
212373
for relationship_path in sorted(relationship_paths):
213374
gather_relationships_info(
@@ -238,19 +399,22 @@ def build_filter_expressions(
238399
filter_item: Union[dict, list],
239400
target_schema: Type[TypeSchema],
240401
target_model: Type[TypeModel],
241-
relationships_info: dict[RelationshipPath, RelationshipInfo],
402+
relationships_info: dict[RelationshipPath, RelationshipFilteringInfo],
242403
) -> Union[BinaryExpression, BooleanClauseList]:
243404
"""
405+
Return sqla expressions.
406+
244407
Builds sqlalchemy expression which can be use
245408
in where condition: query(Model).where(build_filter_expressions(...))
246409
"""
247410
if is_terminal_node(filter_item):
248411
name = filter_item["name"]
249-
target_schema = target_schema
250412

251413
if is_relationship_filter(name):
252414
*relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER)
253-
relationship_info: RelationshipInfo = relationships_info[RELATIONSHIP_SPLITTER.join(relationship_path)]
415+
relationship_info: RelationshipFilteringInfo = relationships_info[
416+
RELATIONSHIP_SPLITTER.join(relationship_path)
417+
]
254418
model_column = get_model_column(
255419
model=relationship_info.aliased_model,
256420
schema=relationship_info.target_schema,

tests/test_data_layers/test_filtering/test_sqlalchemy.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,41 @@
11
from typing import Any
2-
from unittest.mock import Mock
2+
from unittest.mock import MagicMock, Mock
33

44
from fastapi import status
55
from pydantic import BaseModel
66
from pytest import raises # noqa PT013
77

8-
from fastapi_jsonapi.data_layers.filtering.sqlalchemy import Node
9-
from fastapi_jsonapi.exceptions.json_api import InvalidType
8+
from fastapi_jsonapi.data_layers.filtering.sqlalchemy import (
9+
build_filter_expression,
10+
)
11+
from fastapi_jsonapi.exceptions import InvalidType
1012

1113

12-
class TestNode:
14+
class TestFilteringFuncs:
1315
def test_user_type_cast_success(self):
1416
class UserType:
1517
def __init__(self, *args, **kwargs):
16-
self.value = "success"
18+
pass
1719

1820
class ModelSchema(BaseModel):
19-
user_type: UserType
21+
value: UserType
2022

2123
class Config:
2224
arbitrary_types_allowed = True
2325

24-
node = Node(
25-
model=Mock(),
26-
filter_={
27-
"name": "user_type",
28-
"op": "eq",
29-
"val": Any,
30-
},
31-
schema=ModelSchema,
32-
)
33-
34-
model_column_mock = Mock()
35-
model_column_mock.eq = lambda clear_value: clear_value
26+
model_column_mock = MagicMock()
3627

37-
clear_value = node.create_filter(
38-
schema_field=ModelSchema.__fields__["user_type"],
28+
build_filter_expression(
29+
schema_field=ModelSchema.__fields__["value"],
3930
model_column=model_column_mock,
40-
operator=Mock(),
31+
operator="__eq__",
4132
value=Any,
4233
)
43-
assert isinstance(clear_value, UserType)
44-
assert clear_value.value == "success"
34+
35+
model_column_mock.__eq__.assert_called_once()
36+
37+
call_arg = model_column_mock.__eq__.call_args[0]
38+
isinstance(call_arg, UserType)
4539

4640
def test_user_type_cast_fail(self):
4741
class UserType:
@@ -55,14 +49,8 @@ class ModelSchema(BaseModel):
5549
class Config:
5650
arbitrary_types_allowed = True
5751

58-
node = Node(
59-
model=Mock(),
60-
filter_=Mock(),
61-
schema=ModelSchema,
62-
)
63-
6452
with raises(InvalidType) as exc_info:
65-
node.create_filter(
53+
build_filter_expression(
6654
schema_field=ModelSchema.__fields__["user_type"],
6755
model_column=Mock(),
6856
operator=Mock(),

0 commit comments

Comments
 (0)