11"""Helper to create sqlalchemy filters according to filter querystring parameter"""
2+ import inspect
3+ import logging
24from typing import (
35 Any ,
46 Callable ,
57 Dict ,
68 List ,
9+ Optional ,
710 Set ,
11+ Tuple ,
812 Type ,
913 Union ,
1014)
1115
12- from pydantic import BaseModel
16+ from pydantic import BaseConfig , BaseModel
1317from pydantic .fields import ModelField
18+ from pydantic .validators import _VALIDATORS , find_validators
1419from sqlalchemy import and_ , not_ , or_
1520from sqlalchemy .orm import aliased
1621from sqlalchemy .orm .attributes import InstrumentedAttribute
1924
2025from fastapi_jsonapi .data_typing import TypeModel , TypeSchema
2126from fastapi_jsonapi .exceptions import InvalidFilters , InvalidType
27+ from fastapi_jsonapi .exceptions .json_api import HTTPException
2228from fastapi_jsonapi .schema import get_model_field , get_relationships
2329
30+ log = logging .getLogger (__name__ )
31+
2432RELATIONSHIP_SPLITTER = "."
2533
34+ # The mapping with validators using by to cast raw value to instance of target type
35+ REGISTERED_PYDANTIC_TYPES : Dict [Type , List [Callable ]] = dict (_VALIDATORS )
36+
37+ cast_failed = object ()
38+
2639RelationshipPath = str
2740
2841
29- class RelationshipInfo (BaseModel ):
42+ class RelationshipFilteringInfo (BaseModel ):
3043 target_schema : Type [TypeSchema ]
3144 model : Type [TypeModel ]
3245 aliased_model : AliasedClass
@@ -36,6 +49,129 @@ class Config:
3649 arbitrary_types_allowed = True
3750
3851
52+ def check_can_be_none (fields : list [ModelField ]) -> bool :
53+ """
54+ Return True if None is possible value for target field
55+ """
56+ return any (field_item .allow_none for field_item in fields )
57+
58+
59+ def separate_types (types : List [Type ]) -> Tuple [List [Type ], List [Type ]]:
60+ """
61+ Separates the types into two kinds.
62+
63+ The first are those for which there are already validators
64+ defined by pydantic - str, int, datetime and some other built-in types.
65+ The second are all other types for which the `arbitrary_types_allowed`
66+ config is applied when defining the pydantic model
67+ """
68+ pydantic_types = [
69+ # skip format
70+ type_
71+ for type_ in types
72+ if type_ in REGISTERED_PYDANTIC_TYPES
73+ ]
74+ userspace_types = [
75+ # skip format
76+ type_
77+ for type_ in types
78+ if type_ not in REGISTERED_PYDANTIC_TYPES
79+ ]
80+ return pydantic_types , userspace_types
81+
82+
83+ def validator_requires_model_field (validator : Callable ) -> bool :
84+ """
85+ Check if validator accepts the `field` param
86+
87+ :param validator:
88+ :return:
89+ """
90+ signature = inspect .signature (validator )
91+ parameters = signature .parameters
92+
93+ if "field" not in parameters :
94+ return False
95+
96+ field_param = parameters ["field" ]
97+ field_type = field_param .annotation
98+
99+ return field_type == "ModelField" or field_type is ModelField
100+
101+
102+ def cast_value_with_pydantic (
103+ types : List [Type ],
104+ value : Any ,
105+ schema_field : ModelField ,
106+ ) -> Tuple [Optional [Any ], List [str ]]:
107+ result_value , errors = None , []
108+
109+ for type_to_cast in types :
110+ for validator in find_validators (type_to_cast , BaseConfig ):
111+ args = [value ]
112+ # TODO: some other way to get all the validator's dependencies?
113+ if validator_requires_model_field (validator ):
114+ args .append (schema_field )
115+ try :
116+ result_value = validator (* args )
117+ except Exception as ex :
118+ errors .append (str (ex ))
119+ else :
120+ return result_value , errors
121+
122+ return None , errors
123+
124+
125+ def cast_iterable_with_pydantic (
126+ types : List [Type ],
127+ values : List ,
128+ schema_field : ModelField ,
129+ ) -> Tuple [List , List [str ]]:
130+ type_cast_failed = False
131+ failed_values = []
132+
133+ result_values : List [Any ] = []
134+ errors : List [str ] = []
135+
136+ for value in values :
137+ casted_value , cast_errors = cast_value_with_pydantic (
138+ types ,
139+ value ,
140+ schema_field ,
141+ )
142+ errors .extend (cast_errors )
143+
144+ if casted_value is None :
145+ type_cast_failed = True
146+ failed_values .append (value )
147+
148+ continue
149+
150+ result_values .append (casted_value )
151+
152+ if type_cast_failed :
153+ msg = f"Can't parse items { failed_values } of value { values } "
154+ raise InvalidFilters (msg , pointer = schema_field .name )
155+
156+ return result_values , errors
157+
158+
159+ def cast_value_with_scheme (field_types : List [Type ], value : Any ) -> Tuple [Any , List [str ]]:
160+ errors : List [str ] = []
161+ casted_value = cast_failed
162+
163+ for field_type in field_types :
164+ try :
165+ if isinstance (value , list ): # noqa: SIM108
166+ casted_value = [field_type (item ) for item in value ]
167+ else :
168+ casted_value = field_type (value )
169+ except (TypeError , ValueError ) as ex :
170+ errors .append (str (ex ))
171+
172+ return casted_value , errors
173+
174+
39175def build_filter_expression (
40176 schema_field : ModelField ,
41177 model_column : InstrumentedAttribute ,
@@ -61,26 +197,51 @@ def build_filter_expression(
61197 if schema_field .sub_fields :
62198 fields = list (schema_field .sub_fields )
63199
200+ can_be_none = check_can_be_none (fields )
201+
202+ if value is None :
203+ if can_be_none :
204+ return getattr (model_column , operator )(value )
205+
206+ raise InvalidFilters (detail = f"The field `{ schema_field .name } ` can't be null" )
207+
208+ types = [i .type_ for i in fields ]
64209 casted_value = None
65210 errors : List [str ] = []
66211
67- for cast_type in [field .type_ for field in fields ]:
68- try :
69- casted_value = [cast_type (item ) for item in value ] if isinstance (value , list ) else cast_type (value )
70- except (TypeError , ValueError ) as ex :
71- errors .append (str (ex ))
212+ pydantic_types , userspace_types = separate_types (types )
213+
214+ if pydantic_types :
215+ func = cast_value_with_pydantic
216+ if isinstance (value , list ):
217+ func = cast_iterable_with_pydantic
218+ casted_value , errors = func (pydantic_types , value , schema_field )
72219
73- all_fields_required = all (field .required for field in fields )
220+ if casted_value is None and userspace_types :
221+ log .warning ("Filtering by user type values is not properly tested yet. Use this on your own risk." )
74222
75- if casted_value is None and all_fields_required :
76- raise InvalidType (detail = ", " .join (errors ))
223+ casted_value , errors = cast_value_with_scheme (types , value )
224+
225+ if casted_value is cast_failed :
226+ raise InvalidType (
227+ detail = f"Can't cast filter value `{ value } ` to arbitrary type." ,
228+ errors = [HTTPException (status_code = InvalidType .status_code , detail = str (err )) for err in errors ],
229+ )
230+
231+ # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
232+ if casted_value is None and not can_be_none :
233+ raise InvalidType (
234+ detail = ", " .join (errors ),
235+ pointer = schema_field .name ,
236+ )
77237
78238 return getattr (model_column , operator )(casted_value )
79239
80240
81241def is_terminal_node (filter_item : dict ) -> bool :
82242 """
83243 If node shape is:
244+
84245 {
85246 "name: ...,
86247 "op: ...,
@@ -166,7 +327,7 @@ def gather_relationships_info(
166327 relationship_path : List [str ],
167328 collected_info : dict ,
168329 target_relationship_idx : int = 0 ,
169- ) -> dict [RelationshipPath , RelationshipInfo ]:
330+ ) -> dict [RelationshipPath , RelationshipFilteringInfo ]:
170331 is_last_relationship = target_relationship_idx == len (relationship_path ) - 1
171332 target_relationship_path = RELATIONSHIP_SPLITTER .join (
172333 relationship_path [: target_relationship_idx + 1 ],
@@ -184,7 +345,7 @@ def gather_relationships_info(
184345 schema ,
185346 target_relationship_name ,
186347 )
187- collected_info [target_relationship_path ] = RelationshipInfo (
348+ collected_info [target_relationship_path ] = RelationshipFilteringInfo (
188349 target_schema = target_schema ,
189350 model = target_model ,
190351 aliased_model = aliased (target_model ),
@@ -207,7 +368,7 @@ def gather_relationships(
207368 entrypoint_model : Type [TypeModel ],
208369 schema : Type [TypeSchema ],
209370 relationship_paths : Set [str ],
210- ) -> dict [RelationshipPath , RelationshipInfo ]:
371+ ) -> dict [RelationshipPath , RelationshipFilteringInfo ]:
211372 collected_info = {}
212373 for relationship_path in sorted (relationship_paths ):
213374 gather_relationships_info (
@@ -238,19 +399,22 @@ def build_filter_expressions(
238399 filter_item : Union [dict , list ],
239400 target_schema : Type [TypeSchema ],
240401 target_model : Type [TypeModel ],
241- relationships_info : dict [RelationshipPath , RelationshipInfo ],
402+ relationships_info : dict [RelationshipPath , RelationshipFilteringInfo ],
242403) -> Union [BinaryExpression , BooleanClauseList ]:
243404 """
405+ Return sqla expressions.
406+
244407 Builds sqlalchemy expression which can be use
245408 in where condition: query(Model).where(build_filter_expressions(...))
246409 """
247410 if is_terminal_node (filter_item ):
248411 name = filter_item ["name" ]
249- target_schema = target_schema
250412
251413 if is_relationship_filter (name ):
252414 * relationship_path , field_name = name .split (RELATIONSHIP_SPLITTER )
253- relationship_info : RelationshipInfo = relationships_info [RELATIONSHIP_SPLITTER .join (relationship_path )]
415+ relationship_info : RelationshipFilteringInfo = relationships_info [
416+ RELATIONSHIP_SPLITTER .join (relationship_path )
417+ ]
254418 model_column = get_model_column (
255419 model = relationship_info .aliased_model ,
256420 schema = relationship_info .target_schema ,
0 commit comments