Skip to content

Commit 8959aa8

Browse files
authored
Merge pull request #11 from taskiq-python/bugfix/generics
2 parents 39c4766 + 31bce84 commit 8959aa8

File tree

3 files changed

+91
-15
lines changed

3 files changed

+91
-15
lines changed

.flake8

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ per-file-ignores =
122122
; Found wrong metadata variable
123123
WPS410,
124124

125+
swagger.py:
126+
; Too many local variables
127+
WPS210,
128+
125129
exclude =
126130
./.git,
127131
./venv,

aiohttp_deps/swagger.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
_T = TypeVar("_T") # noqa: WPS111
1515

16+
REF_TEMPLATE = "#/components/schemas/{model}"
1617
SCHEMA_KEY = "openapi_schema"
1718
SWAGGER_HTML_TEMPALTE = """
1819
<html lang="en">
@@ -75,12 +76,13 @@ def dummy(_var: annotation.annotation) -> None: # type: ignore
7576
return var == Optional[var]
7677

7778

78-
def _add_route_def( # noqa: C901, WPS210
79+
def _add_route_def( # noqa: C901, WPS210, WPS211
7980
openapi_schema: Dict[str, Any],
8081
route: web.ResourceRoute,
8182
method: str,
8283
graph: DependencyGraph,
8384
extra_openapi: Dict[str, Any],
85+
extra_openapi_schemas: Dict[str, Any],
8486
) -> None:
8587
route_info: Dict[str, Any] = {
8688
"description": inspect.getdoc(graph.target),
@@ -90,6 +92,9 @@ def _add_route_def( # noqa: C901, WPS210
9092
if route.resource is None: # pragma: no cover
9193
return
9294

95+
if extra_openapi_schemas:
96+
openapi_schema["components"]["schemas"].update(extra_openapi_schemas)
97+
9398
params: Dict[tuple[str, str], Any] = {}
9499

95100
def _insert_in_params(data: Dict[str, Any]) -> None:
@@ -114,9 +119,9 @@ def _insert_in_params(data: Dict[str, Any]) -> None:
114119
):
115120
input_schema = pydantic.TypeAdapter(
116121
dependency.signature.annotation,
117-
).json_schema()
122+
).json_schema(ref_template=REF_TEMPLATE)
118123
openapi_schema["components"]["schemas"].update(
119-
input_schema.pop("definitions", {}),
124+
input_schema.pop("$defs", {}),
120125
)
121126
route_info["requestBody"] = {
122127
"content": {content_type: {"schema": input_schema}},
@@ -216,13 +221,19 @@ async def event_handler(app: web.Application) -> None:
216221
"__extra_openapi__",
217222
{},
218223
)
224+
extra_schemas = getattr(
225+
route._handler.original_handler,
226+
"__extra_openapi_schemas__",
227+
{},
228+
)
219229
try:
220230
_add_route_def(
221231
openapi_schema,
222232
route, # type: ignore
223233
route.method,
224234
route._handler.graph,
225235
extra_openapi=extra_openapi,
236+
extra_openapi_schemas=extra_schemas,
226237
)
227238
except Exception as exc: # pragma: no cover
228239
logger.warn(
@@ -234,20 +245,23 @@ async def event_handler(app: web.Application) -> None:
234245
elif isinstance(route._handler, InjectableViewHandler):
235246
for key, graph in route._handler.graph_map.items():
236247
extra_openapi = getattr(
237-
getattr(
238-
route._handler.original_handler,
239-
key,
240-
),
248+
getattr(route._handler.original_handler, key),
241249
"__extra_openapi__",
242250
{},
243251
)
252+
extra_schemas = getattr(
253+
getattr(route._handler.original_handler, key),
254+
"__extra_openapi_schemas__",
255+
{},
256+
)
244257
try:
245258
_add_route_def(
246259
openapi_schema,
247260
route, # type: ignore
248261
key,
249262
graph,
250263
extra_openapi=extra_openapi,
264+
extra_openapi_schemas=extra_schemas,
251265
)
252266
except Exception as exc: # pragma: no cover
253267
logger.warn(
@@ -315,16 +329,20 @@ def openapi_response(
315329

316330
def decorator(func: _T) -> _T:
317331
openapi = getattr(func, "__extra_openapi__", {})
332+
openapi_schemas = getattr(func, "__extra_openapi_schemas__", {})
318333
adapter: "pydantic.TypeAdapter[Any]" = pydantic.TypeAdapter(model)
319334
responses = openapi.get("responses", {})
320335
status_response = responses.get(status, {})
321336
if not status_response:
322337
status_response["description"] = description
323338
status_response["content"] = status_response.get("content", {})
324-
status_response["content"][content_type] = {"schema": adapter.json_schema()}
339+
response_schema = adapter.json_schema(ref_template=REF_TEMPLATE)
340+
openapi_schemas.update(response_schema.pop("$defs", {}))
341+
status_response["content"][content_type] = {"schema": response_schema}
325342
responses[status] = status_response
326343
openapi["responses"] = responses
327344
func.__extra_openapi__ = openapi # type: ignore
345+
func.__extra_openapi_schemas__ = openapi_schemas # type: ignore
328346
return func
329347

330348
return decorator

tests/test_swagger.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import deque
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, Generic, Optional, TypeVar
33

44
import pytest
55
from aiohttp import web
@@ -20,6 +20,21 @@
2020
from tests.conftest import ClientGenerator
2121

2222

23+
def follow_ref(ref: str, data: Dict[str, Any]) -> Dict[str, Any]:
24+
"""Function for following openapi references."""
25+
components = deque(ref.split("/"))
26+
current_model = None
27+
while components:
28+
component = components.popleft()
29+
if component.strip() == "#":
30+
current_model = data
31+
continue
32+
current_model = current_model.get(component)
33+
if current_model is None:
34+
return {}
35+
return current_model
36+
37+
2338
def get_schema_by_ref(full_schema: Dict[str, Any], ref: str):
2439
ref_path = deque(ref.split("/"))
2540
current_schema = full_schema
@@ -141,19 +156,26 @@ async def my_handler(body=Depends(Json())):
141156
assert resp.status == 200
142157
resp_json = await resp.json()
143158
handler_info = resp_json["paths"]["/a"]["get"]
144-
print(handler_info)
145159
assert handler_info["requestBody"]["content"]["application/json"] == {}
146160

147161

148162
@pytest.mark.anyio
149-
async def test_json_untyped(
163+
async def test_json_generic(
150164
my_app: web.Application,
151165
aiohttp_client: ClientGenerator,
152166
):
153167
OPENAPI_URL = "/my_api_def.json"
154168
my_app.on_startup.append(setup_swagger(schema_url=OPENAPI_URL))
155169

156-
async def my_handler(body=Depends(Json())):
170+
T = TypeVar("T")
171+
172+
class First(BaseModel):
173+
name: str
174+
175+
class Second(BaseModel, Generic[T]):
176+
data: T
177+
178+
async def my_handler(body: Second[First] = Depends(Json())):
157179
"""Nothing."""
158180

159181
my_app.router.add_get("/a", my_handler)
@@ -163,7 +185,10 @@ async def my_handler(body=Depends(Json())):
163185
assert resp.status == 200
164186
resp_json = await resp.json()
165187
handler_info = resp_json["paths"]["/a"]["get"]
166-
assert {} == handler_info["requestBody"]["content"]["application/json"]
188+
schema = handler_info["requestBody"]["content"]["application/json"]["schema"]
189+
first_ref = schema["properties"]["data"]["$ref"]
190+
first_obj = follow_ref(first_ref, resp_json)
191+
assert "name" in first_obj["properties"]
167192

168193

169194
@pytest.mark.anyio
@@ -438,7 +463,6 @@ async def my_handler():
438463
resp_json = await resp.json()
439464

440465
handler_info = resp_json["paths"]["/a"]["get"]
441-
print(handler_info)
442466
assert handler_info["responses"] == {"200": {}}
443467

444468

@@ -495,7 +519,6 @@ async def my_handler(
495519
assert resp.status == 200
496520
resp_json = await resp.json()
497521
params = resp_json["paths"]["/a"]["get"]["parameters"]
498-
print(params)
499522
assert len(params) == 1
500523
assert params[0]["name"] == "Head"
501524
assert params[0]["required"]
@@ -562,3 +585,34 @@ async def my_handler():
562585
assert "200" in route_info["responses"]
563586
assert "application/json" in route_info["responses"]["200"]["content"]
564587
assert "application/xml" in route_info["responses"]["200"]["content"]
588+
589+
590+
@pytest.mark.anyio
591+
async def test_custom_responses_generics(
592+
my_app: web.Application,
593+
aiohttp_client: ClientGenerator,
594+
) -> None:
595+
OPENAPI_URL = "/my_api_def.json"
596+
my_app.on_startup.append(setup_swagger(schema_url=OPENAPI_URL))
597+
598+
T = TypeVar("T")
599+
600+
class First(BaseModel):
601+
name: str
602+
603+
class Second(BaseModel, Generic[T]):
604+
data: T
605+
606+
@openapi_response(200, Second[First])
607+
async def my_handler():
608+
"""Nothing."""
609+
610+
my_app.router.add_get("/a", my_handler)
611+
client = await aiohttp_client(my_app)
612+
response = await client.get(OPENAPI_URL)
613+
resp_json = await response.json()
614+
first_ref = resp_json["paths"]["/a"]["get"]["responses"]["200"]["content"][
615+
"application/json"
616+
]["schema"]["properties"]["data"]["$ref"]
617+
first_obj = follow_ref(first_ref, resp_json)
618+
assert "name" in first_obj["properties"]

0 commit comments

Comments
 (0)