11import inspect
22from collections import defaultdict
33from logging import getLogger
4- from typing import Any , Awaitable , Callable , Dict , Optional , Union
4+ from typing import Any , Awaitable , Callable , Dict , Optional , TypeVar , get_type_hints
55
66import pydantic
77from aiohttp import web
8- from pydantic . utils import deep_update
8+ from deepmerge import always_merger
99from taskiq_dependencies import DependencyGraph
1010
1111from aiohttp_deps .initializer import InjectableFuncHandler , InjectableViewHandler
1212from aiohttp_deps .utils import Form , Header , Json , Path , Query
1313
14- REF_TEMPLATE = "#/components/schemas/{model}"
14+ _T = TypeVar ("_T" ) # noqa: WPS111
15+
1516SCHEMA_KEY = "openapi_schema"
1617SWAGGER_HTML_TEMPALTE = """
1718<html lang="en">
@@ -67,19 +68,14 @@ def _is_optional(annotation: Optional[inspect.Parameter]) -> bool:
6768 if annotation is None or annotation .annotation == annotation .empty :
6869 return True
6970
70- origin = getattr (annotation .annotation , "__origin__" , None )
71- if origin is None :
72- return False
71+ def dummy (_var : annotation .annotation ) -> None : # type: ignore
72+ """Dummy function to use for type resolution."""
7373
74- if origin == Union :
75- args = getattr (annotation .annotation , "__args__" , ())
76- for arg in args :
77- if arg is type (None ): # noqa: E721, WPS516
78- return True
79- return False
74+ var = get_type_hints (dummy ).get ("_var" )
75+ return var == Optional [var ]
8076
8177
82- def _add_route_def ( # noqa: C901
78+ def _add_route_def ( # noqa: C901, WPS210
8379 openapi_schema : Dict [str , Any ],
8480 route : web .ResourceRoute ,
8581 method : str ,
@@ -94,6 +90,19 @@ def _add_route_def( # noqa: C901
9490 if route .resource is None : # pragma: no cover
9591 return
9692
93+ params : Dict [tuple [str , str ], Any ] = {}
94+
95+ def _insert_in_params (data : Dict [str , Any ]) -> None :
96+ element = params .get ((data ["name" ], data ["in" ]))
97+ if element is None :
98+ params [(data ["name" ], data ["in" ])] = data
99+ return
100+ element ["required" ] = element .get ("required" ) or data .get ("required" )
101+ element ["allowEmptyValue" ] = bool (element .get ("allowEmptyValue" )) and bool (
102+ data .get ("allowEmptyValue" ),
103+ )
104+ params [(data ["name" ], data ["in" ])] = element
105+
97106 for dependency in graph .ordered_deps :
98107 if isinstance (dependency .dependency , (Json , Form )):
99108 content_type = "application/json"
@@ -105,9 +114,7 @@ def _add_route_def( # noqa: C901
105114 ):
106115 input_schema = pydantic .TypeAdapter (
107116 dependency .signature .annotation ,
108- ).json_schema (
109- ref_template = REF_TEMPLATE ,
110- )
117+ ).json_schema ()
111118 openapi_schema ["components" ]["schemas" ].update (
112119 input_schema .pop ("definitions" , {}),
113120 )
@@ -119,7 +126,7 @@ def _add_route_def( # noqa: C901
119126 "content" : {content_type : {}},
120127 }
121128 elif isinstance (dependency .dependency , Query ):
122- route_info [ "parameters" ]. append (
129+ _insert_in_params (
123130 {
124131 "name" : dependency .dependency .alias or dependency .param_name ,
125132 "in" : "query" ,
@@ -128,16 +135,17 @@ def _add_route_def( # noqa: C901
128135 },
129136 )
130137 elif isinstance (dependency .dependency , Header ):
131- route_info ["parameters" ].append (
138+ name = dependency .dependency .alias or dependency .param_name
139+ _insert_in_params (
132140 {
133- "name" : dependency . dependency . alias or dependency . param_name ,
141+ "name" : name . capitalize () ,
134142 "in" : "header" ,
135143 "description" : dependency .dependency .description ,
136144 "required" : not _is_optional (dependency .signature ),
137145 },
138146 )
139147 elif isinstance (dependency .dependency , Path ):
140- route_info [ "parameters" ]. append (
148+ _insert_in_params (
141149 {
142150 "name" : dependency .dependency .alias or dependency .param_name ,
143151 "in" : "path" ,
@@ -147,8 +155,9 @@ def _add_route_def( # noqa: C901
147155 },
148156 )
149157
158+ route_info ["parameters" ] = list (params .values ())
150159 openapi_schema ["paths" ][route .resource .canonical ].update (
151- {method .lower (): deep_update (route_info , extra_openapi )},
160+ {method .lower (): always_merger . merge (route_info , extra_openapi )},
152161 )
153162
154163
@@ -265,7 +274,7 @@ async def event_handler(app: web.Application) -> None:
265274 return event_handler
266275
267276
268- def extra_openapi (additional_schema : Dict [str , Any ]) -> Callable [..., Any ]:
277+ def extra_openapi (additional_schema : Dict [str , Any ]) -> Callable [[ _T ], _T ]:
269278 """
270279 Add extra openapi schema.
271280
@@ -276,8 +285,46 @@ def extra_openapi(additional_schema: Dict[str, Any]) -> Callable[..., Any]:
276285 :return: same function with new attributes.
277286 """
278287
279- def decorator (func : Any ) -> Any :
280- func .__extra_openapi__ = additional_schema
288+ def decorator (func : _T ) -> _T :
289+ func .__extra_openapi__ = additional_schema # type: ignore
290+ return func
291+
292+ return decorator
293+
294+
295+ def openapi_response (
296+ status : int ,
297+ model : Any ,
298+ * ,
299+ content_type : str = "application/json" ,
300+ description : Optional [str ] = None ,
301+ ) -> Callable [[_T ], _T ]:
302+ """
303+ Add response schema to the endpoint.
304+
305+ This function takes a status and model,
306+ which is going to represent the response.
307+
308+ :param status: Status of a response.
309+ :param model: Response model.
310+ :param content_type: Content-type of a response.
311+ :param description: Response's description.
312+
313+ :returns: decorator that modifies your function.
314+ """
315+
316+ def decorator (func : _T ) -> _T :
317+ openapi = getattr (func , "__extra_openapi__" , {})
318+ adapter : "pydantic.TypeAdapter[Any]" = pydantic .TypeAdapter (model )
319+ responses = openapi .get ("responses" , {})
320+ status_response = responses .get (status , {})
321+ if not status_response :
322+ status_response ["description" ] = description
323+ status_response ["content" ] = status_response .get ("content" , {})
324+ status_response ["content" ][content_type ] = {"schema" : adapter .json_schema ()}
325+ responses [status ] = status_response
326+ openapi ["responses" ] = responses
327+ func .__extra_openapi__ = openapi # type: ignore
281328 return func
282329
283330 return decorator
0 commit comments