|
| 1 | +import inspect |
| 2 | +from collections import defaultdict |
| 3 | +from logging import getLogger |
| 4 | +from typing import Any, Awaitable, Callable, Dict, Optional, Union |
| 5 | + |
| 6 | +from aiohttp import web |
| 7 | +from pydantic import schema_of |
| 8 | +from pydantic.utils import deep_update |
| 9 | +from taskiq_dependencies import DependencyGraph |
| 10 | + |
| 11 | +from aiohttp_deps.initializer import InjectableFuncHandler, InjectableViewHandler |
| 12 | +from aiohttp_deps.utils import Form, Header, Json, Path, Query |
| 13 | + |
| 14 | +REF_TEMPLATE = "#/components/schemas/{model}" |
| 15 | +SCHEMA_KEY = "openapi_schema" |
| 16 | +SWAGGER_HTML_TEMPALTE = """ |
| 17 | +<html lang="en"> |
| 18 | +
|
| 19 | +<head> |
| 20 | + <meta charset="utf-8" /> |
| 21 | + <meta name="viewport" content="width=device-width, initial-scale=1" /> |
| 22 | + <meta name="description" content="SwaggerUI" /> |
| 23 | + <title>SwaggerUI</title> |
| 24 | + <link rel="stylesheet" |
| 25 | + href="https://unpkg.com/swagger-ui-dist/swagger-ui.css" |
| 26 | + /> |
| 27 | +</head> |
| 28 | +
|
| 29 | +<body> |
| 30 | + <div id="swagger-ui"></div> |
| 31 | + <script src="https://unpkg.com/swagger-ui-dist/swagger-ui-bundle.js" |
| 32 | + crossorigin></script> |
| 33 | + <script> |
| 34 | + window.onload = () => { |
| 35 | + window.ui = SwaggerUIBundle({ |
| 36 | + url: '{schema_url}', |
| 37 | + dom_id: '#swagger-ui', |
| 38 | + }); |
| 39 | + }; |
| 40 | + </script> |
| 41 | +</body> |
| 42 | +</html> |
| 43 | +""" |
| 44 | +METHODS_WITH_BODY = {"POST", "PUT", "PATCH"} # noqa: WPS407 |
| 45 | + |
| 46 | +logger = getLogger() |
| 47 | + |
| 48 | + |
| 49 | +async def _schema_handler( |
| 50 | + request: web.Request, |
| 51 | +) -> web.Response: |
| 52 | + return web.json_response(request.app[SCHEMA_KEY]) |
| 53 | + |
| 54 | + |
| 55 | +def _get_swagger_handler( |
| 56 | + swagger_html: str, |
| 57 | +) -> Callable[[web.Request], Awaitable[web.Response]]: |
| 58 | + async def swagger_handler(_: web.Request) -> web.Response: |
| 59 | + return web.Response(text=swagger_html, content_type="text/html") |
| 60 | + |
| 61 | + return swagger_handler |
| 62 | + |
| 63 | + |
| 64 | +def _is_optional(annotation: Optional[inspect.Parameter]) -> bool: |
| 65 | + # If it's an empty annotation, |
| 66 | + # we guess that the value can be optional. |
| 67 | + if annotation is None or annotation.annotation == annotation.empty: |
| 68 | + return True |
| 69 | + |
| 70 | + origin = getattr(annotation.annotation, "__origin__", None) |
| 71 | + if origin is None: |
| 72 | + return False |
| 73 | + |
| 74 | + if origin == Union: |
| 75 | + args = getattr(annotation.annotation, "__args__", ()) |
| 76 | + for arg in args: |
| 77 | + if arg is type(None): # noqa: E721, WPS516 |
| 78 | + return True |
| 79 | + return False |
| 80 | + |
| 81 | + |
| 82 | +def _add_route_def( # noqa: C901 |
| 83 | + openapi_schema: Dict[str, Any], |
| 84 | + route: web.ResourceRoute, |
| 85 | + method: str, |
| 86 | + graph: DependencyGraph, |
| 87 | + extra_openapi: Dict[str, Any], |
| 88 | +) -> None: |
| 89 | + route_info: Dict[str, Any] = { |
| 90 | + "description": inspect.getdoc(graph.target), |
| 91 | + "responses": {}, |
| 92 | + "parameters": [], |
| 93 | + } |
| 94 | + if route.resource is None: # pragma: no cover |
| 95 | + return |
| 96 | + |
| 97 | + for dependency in graph.ordered_deps: |
| 98 | + if isinstance(dependency.dependency, (Json, Form)): |
| 99 | + content_type = "application/json" |
| 100 | + if isinstance(dependency.dependency, Form): |
| 101 | + content_type = "application/x-www-form-urlencoded" |
| 102 | + if ( |
| 103 | + dependency.signature |
| 104 | + and dependency.signature.annotation != inspect.Parameter.empty |
| 105 | + ): |
| 106 | + input_schema = schema_of( |
| 107 | + dependency.signature.annotation, |
| 108 | + ref_template=REF_TEMPLATE, |
| 109 | + ) |
| 110 | + openapi_schema["components"]["schemas"].update( |
| 111 | + input_schema.pop("definitions", {}), |
| 112 | + ) |
| 113 | + route_info["requestBody"] = { |
| 114 | + "content": {content_type: {"schema": input_schema}}, |
| 115 | + } |
| 116 | + else: |
| 117 | + route_info["requestBody"] = { |
| 118 | + "content": {content_type: {}}, |
| 119 | + } |
| 120 | + elif isinstance(dependency.dependency, Query): |
| 121 | + route_info["parameters"].append( |
| 122 | + { |
| 123 | + "name": dependency.dependency.alias or dependency.param_name, |
| 124 | + "in": "query", |
| 125 | + "description": dependency.dependency.description, |
| 126 | + "required": not _is_optional(dependency.signature), |
| 127 | + }, |
| 128 | + ) |
| 129 | + elif isinstance(dependency.dependency, Header): |
| 130 | + route_info["parameters"].append( |
| 131 | + { |
| 132 | + "name": dependency.dependency.alias or dependency.param_name, |
| 133 | + "in": "header", |
| 134 | + "description": dependency.dependency.description, |
| 135 | + "required": not _is_optional(dependency.signature), |
| 136 | + }, |
| 137 | + ) |
| 138 | + elif isinstance(dependency.dependency, Path): |
| 139 | + route_info["parameters"].append( |
| 140 | + { |
| 141 | + "name": dependency.dependency.alias or dependency.param_name, |
| 142 | + "in": "path", |
| 143 | + "description": dependency.dependency.description, |
| 144 | + "required": not _is_optional(dependency.signature), |
| 145 | + "allowEmptyValue": _is_optional(dependency.signature), |
| 146 | + }, |
| 147 | + ) |
| 148 | + |
| 149 | + openapi_schema["paths"][route.resource.canonical].update( |
| 150 | + {method.lower(): deep_update(route_info, extra_openapi)}, |
| 151 | + ) |
| 152 | + |
| 153 | + |
| 154 | +def setup_swagger( # noqa: C901, WPS211 |
| 155 | + schema_url: str = "/openapi.json", |
| 156 | + swagger_ui_url: str = "/docs", |
| 157 | + enable_ui: bool = True, |
| 158 | + hide_heads: bool = True, |
| 159 | + title: str = "AioHTTP", |
| 160 | + description: Optional[str] = None, |
| 161 | + version: str = "1.0.0", |
| 162 | +) -> Callable[[web.Application], Awaitable[None]]: |
| 163 | + """ |
| 164 | + Add swagger documentation. |
| 165 | +
|
| 166 | + This function creates new function, |
| 167 | + that can be used in on_startup. |
| 168 | +
|
| 169 | + Add outputs of this function in on_startup array |
| 170 | + to enable swagger. |
| 171 | +
|
| 172 | + >>> app.on_startup.append(setup_swagger()) |
| 173 | +
|
| 174 | + This function will generate swagger schema based |
| 175 | + on dependencies that were used. |
| 176 | +
|
| 177 | + :param schema_url: URL where schema will be served. |
| 178 | + :param swagger_ui_url: URL where swagger ui will be served. |
| 179 | + :param enable_ui: whether you want to enable bundled swagger ui. |
| 180 | + :param hide_heads: hide HEAD requests. |
| 181 | + :param title: Title of an application. |
| 182 | + :param description: description of an application. |
| 183 | + :param version: version of an application. |
| 184 | + :return: startup event handler. |
| 185 | + """ |
| 186 | + |
| 187 | + async def event_handler(app: web.Application) -> None: |
| 188 | + openapi_schema = { |
| 189 | + "openapi": "3.0.0", |
| 190 | + "info": { |
| 191 | + "title": title, |
| 192 | + "description": description, |
| 193 | + "version": version, |
| 194 | + }, |
| 195 | + "components": {"schemas": {}}, |
| 196 | + "paths": defaultdict(dict), |
| 197 | + } |
| 198 | + for route in app.router.routes(): |
| 199 | + if route.resource is None: # pragma: no cover |
| 200 | + continue |
| 201 | + if hide_heads and route.method == "HEAD": |
| 202 | + continue |
| 203 | + if isinstance(route._handler, InjectableFuncHandler): |
| 204 | + extra_openapi = getattr( |
| 205 | + route._handler.original_handler, |
| 206 | + "__extra_openapi__", |
| 207 | + {}, |
| 208 | + ) |
| 209 | + try: |
| 210 | + _add_route_def( |
| 211 | + openapi_schema, |
| 212 | + route, # type: ignore |
| 213 | + route.method, |
| 214 | + route._handler.graph, |
| 215 | + extra_openapi=extra_openapi, |
| 216 | + ) |
| 217 | + except Exception as exc: # pragma: no cover |
| 218 | + logger.warn( |
| 219 | + "Cannot add route info: %s", |
| 220 | + exc, |
| 221 | + exc_info=True, |
| 222 | + ) |
| 223 | + |
| 224 | + elif isinstance(route._handler, InjectableViewHandler): |
| 225 | + for key, graph in route._handler.graph_map.items(): |
| 226 | + extra_openapi = getattr( |
| 227 | + getattr( |
| 228 | + route._handler.original_handler, |
| 229 | + key, |
| 230 | + ), |
| 231 | + "__extra_openapi__", |
| 232 | + {}, |
| 233 | + ) |
| 234 | + try: |
| 235 | + _add_route_def( |
| 236 | + openapi_schema, |
| 237 | + route, # type: ignore |
| 238 | + key, |
| 239 | + graph, |
| 240 | + extra_openapi=extra_openapi, |
| 241 | + ) |
| 242 | + except Exception as exc: # pragma: no cover |
| 243 | + logger.warn( |
| 244 | + "Cannot add route info: %s", |
| 245 | + exc, |
| 246 | + exc_info=True, |
| 247 | + ) |
| 248 | + |
| 249 | + app[SCHEMA_KEY] = openapi_schema |
| 250 | + |
| 251 | + app.router.add_get( |
| 252 | + schema_url, |
| 253 | + _schema_handler, |
| 254 | + ) |
| 255 | + |
| 256 | + if enable_ui: |
| 257 | + app.router.add_get( |
| 258 | + swagger_ui_url, |
| 259 | + _get_swagger_handler( |
| 260 | + SWAGGER_HTML_TEMPALTE.replace("{schema_url}", schema_url), |
| 261 | + ), |
| 262 | + ) |
| 263 | + |
| 264 | + return event_handler |
| 265 | + |
| 266 | + |
| 267 | +def extra_openapi(additional_schema: Dict[str, Any]) -> Callable[..., Any]: |
| 268 | + """ |
| 269 | + Add extra openapi schema. |
| 270 | +
|
| 271 | + This function just adds a parameter for later use |
| 272 | + by openapi schema generator. |
| 273 | +
|
| 274 | + :param additional_schema: dict with updates. |
| 275 | + :return: same function with new attributes. |
| 276 | + """ |
| 277 | + |
| 278 | + def decorator(func: Any) -> Any: |
| 279 | + func.__extra_openapi__ = additional_schema |
| 280 | + |
| 281 | + return func |
| 282 | + |
| 283 | + return decorator |
0 commit comments