Skip to content

Commit 2210cab

Browse files
committed
Support multi content type in request body and responses
1 parent a64533d commit 2210cab

File tree

10 files changed

+404
-193
lines changed

10 files changed

+404
-193
lines changed

examples/multi_content_type.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# @Author : llc
3+
# @Time : 2024/12/27 15:30
4+
from pydantic import BaseModel
5+
6+
from flask_openapi3 import OpenAPI
7+
8+
app = OpenAPI(__name__)
9+
10+
11+
class DogBody(BaseModel):
12+
a: int = None
13+
b: str = None
14+
15+
model_config = {
16+
"openapi_extra": {
17+
"content_type": "application/vnd.dog+json"
18+
}
19+
}
20+
21+
22+
class CatBody(BaseModel):
23+
c: int = None
24+
d: str = None
25+
26+
model_config = {
27+
"openapi_extra": {
28+
"content_type": "application/vnd.cat+json"
29+
}
30+
}
31+
32+
33+
class ContentTypeModel(BaseModel):
34+
model_config = {
35+
"openapi_extra": {
36+
"content_type": "text/csv"
37+
}
38+
}
39+
40+
41+
@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel})
42+
def index_a(body: DogBody | CatBody | ContentTypeModel):
43+
print(body)
44+
return {"hello": "world"}
45+
46+
47+
if __name__ == '__main__':
48+
app.run(debug=True)

flask_openapi3/blueprint.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def _collect_openapi_info(
121121
security: list[dict[str, list[Any]]] | None = None,
122122
servers: list[Server] | None = None,
123123
openapi_extensions: dict[str, Any] | None = None,
124+
request_body_description: str | None = None,
125+
request_body_required: bool | None = True,
124126
doc_ui: bool = True,
125127
method: str = HTTPMethod.GET,
126128
) -> ParametersTuple:
@@ -140,6 +142,8 @@ def _collect_openapi_info(
140142
security: A declaration of which security mechanisms can be used for this operation.
141143
servers: An alternative server array to service this operation.
142144
openapi_extensions: Allows extensions to the OpenAPI Schema.
145+
request_body_description: A brief description of the request body.
146+
request_body_required: Determines if the request body is required in the request.
143147
doc_ui: Declares this operation to be shown. Default to True.
144148
"""
145149
if self.doc_ui is True and doc_ui is True:
@@ -191,6 +195,12 @@ def _collect_openapi_info(
191195
parse_method(uri, method, self.paths, operation)
192196

193197
# Parse parameters
194-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
198+
return parse_parameters(
199+
func,
200+
components_schemas=self.components_schemas,
201+
operation=operation,
202+
request_body_description=request_body_description,
203+
request_body_required=request_body_required
204+
)
195205
else:
196206
return parse_parameters(func, doc_ui=False)

flask_openapi3/openapi.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ def _collect_openapi_info(
370370
security: list[dict[str, list[Any]]] | None = None,
371371
servers: list[Server] | None = None,
372372
openapi_extensions: dict[str, Any] | None = None,
373+
request_body_description: str | None = None,
374+
request_body_required: bool | None = True,
373375
doc_ui: bool = True,
374376
method: str = HTTPMethod.GET,
375377
) -> ParametersTuple:
@@ -389,6 +391,8 @@ def _collect_openapi_info(
389391
security: A declaration of which security mechanisms can be used for this operation.
390392
servers: An alternative server array to service this operation.
391393
openapi_extensions: Allows extensions to the OpenAPI Schema.
394+
request_body_description: A brief description of the request body.
395+
request_body_required: Determines if the request body is required in the request.
392396
doc_ui: Declares this operation to be shown. Default to True.
393397
method: HTTP method for the operation. Defaults to GET.
394398
"""
@@ -437,6 +441,12 @@ def _collect_openapi_info(
437441
parse_method(uri, method, self.paths, operation)
438442

439443
# Parse parameters
440-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
444+
return parse_parameters(
445+
func,
446+
components_schemas=self.components_schemas,
447+
operation=operation,
448+
request_body_description=request_body_description,
449+
request_body_required=request_body_required
450+
)
441451
else:
442452
return parse_parameters(func, doc_ui=False)

flask_openapi3/request.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import json
66
from functools import wraps
77
from json import JSONDecodeError
8-
from typing import Any, Type
8+
from types import UnionType
9+
from typing import Any, Type, Union, get_args, get_origin
910

1011
from flask import abort, current_app, request
11-
from pydantic import BaseModel, ValidationError
12+
from pydantic import BaseModel, RootModel, ValidationError
1213
from pydantic.fields import FieldInfo
1314
from werkzeug.datastructures.structures import MultiDict
1415

15-
from .utils import parse_parameters
16+
from flask_openapi3.utils import is_application_json, parse_parameters
1617

1718

1819
def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo):
@@ -146,12 +147,20 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
146147

147148

148149
def _validate_body(body: Type[BaseModel], func_kwargs: dict):
149-
obj = request.get_json(silent=True)
150-
if isinstance(obj, str):
151-
body_model = body.model_validate_json(json_data=obj)
150+
if is_application_json(request.mimetype):
151+
if get_origin(body) == UnionType:
152+
root_model_list = [model for model in get_args(body)]
153+
Body = RootModel[Union[tuple(root_model_list)]] # type: ignore
154+
else:
155+
Body = body # type: ignore
156+
obj = request.get_json(silent=True)
157+
if isinstance(obj, str):
158+
body_model = Body.model_validate_json(json_data=obj)
159+
else:
160+
body_model = Body.model_validate(obj=obj)
161+
func_kwargs["body"] = body_model
152162
else:
153-
body_model = body.model_validate(obj=obj)
154-
func_kwargs["body"] = body_model
163+
func_kwargs["body"] = request
155164

156165

157166
def _validate_request(

flask_openapi3/scaffold.py

Lines changed: 94 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def _collect_openapi_info(
2929
security: list[dict[str, list[Any]]] | None = None,
3030
servers: list[Server] | None = None,
3131
openapi_extensions: dict[str, Any] | None = None,
32+
request_body_description: str | None = None,
33+
request_body_required: bool | None = True,
3234
doc_ui: bool = True,
3335
method: str = HTTPMethod.GET,
3436
) -> ParametersTuple:
@@ -192,6 +194,8 @@ def post(
192194
security: list[dict[str, list[Any]]] | None = None,
193195
servers: list[Server] | None = None,
194196
openapi_extensions: dict[str, Any] | None = None,
197+
request_body_description: str | None = None,
198+
request_body_required: bool | None = True,
195199
doc_ui: bool = True,
196200
**options: Any,
197201
) -> Callable:
@@ -211,26 +215,31 @@ def post(
211215
security: A declaration of which security mechanisms can be used for this operation.
212216
servers: An alternative server array to service this operation.
213217
openapi_extensions: Allows extensions to the OpenAPI Schema.
218+
request_body_description: A brief description of the request body.
219+
request_body_required: Determines if the request body is required in the request.
214220
doc_ui: Declares this operation to be shown. Default to True.
215221
"""
216222

217223
def decorator(func) -> Callable:
218-
header, cookie, path, query, form, body, raw = self._collect_openapi_info(
219-
rule,
220-
func,
221-
tags=tags,
222-
summary=summary,
223-
description=description,
224-
external_docs=external_docs,
225-
operation_id=operation_id,
226-
responses=responses,
227-
deprecated=deprecated,
228-
security=security,
229-
servers=servers,
230-
openapi_extensions=openapi_extensions,
231-
doc_ui=doc_ui,
232-
method=HTTPMethod.POST,
233-
)
224+
header, cookie, path, query, form, body, raw = \
225+
self._collect_openapi_info(
226+
rule,
227+
func,
228+
tags=tags,
229+
summary=summary,
230+
description=description,
231+
external_docs=external_docs,
232+
operation_id=operation_id,
233+
responses=responses,
234+
deprecated=deprecated,
235+
security=security,
236+
servers=servers,
237+
openapi_extensions=openapi_extensions,
238+
request_body_description=request_body_description,
239+
request_body_required=request_body_required,
240+
doc_ui=doc_ui,
241+
method=HTTPMethod.POST,
242+
)
234243

235244
view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw)
236245
options.update({"methods": [HTTPMethod.POST]})
@@ -254,6 +263,8 @@ def put(
254263
security: list[dict[str, list[Any]]] | None = None,
255264
servers: list[Server] | None = None,
256265
openapi_extensions: dict[str, Any] | None = None,
266+
request_body_description: str | None = None,
267+
request_body_required: bool | None = True,
257268
doc_ui: bool = True,
258269
**options: Any,
259270
) -> Callable:
@@ -273,26 +284,31 @@ def put(
273284
security: A declaration of which security mechanisms can be used for this operation.
274285
servers: An alternative server array to service this operation.
275286
openapi_extensions: Allows extensions to the OpenAPI Schema.
287+
request_body_description: A brief description of the request body.
288+
request_body_required: Determines if the request body is required in the request.
276289
doc_ui: Declares this operation to be shown. Default to True.
277290
"""
278291

279292
def decorator(func) -> Callable:
280-
header, cookie, path, query, form, body, raw = self._collect_openapi_info(
281-
rule,
282-
func,
283-
tags=tags,
284-
summary=summary,
285-
description=description,
286-
external_docs=external_docs,
287-
operation_id=operation_id,
288-
responses=responses,
289-
deprecated=deprecated,
290-
security=security,
291-
servers=servers,
292-
openapi_extensions=openapi_extensions,
293-
doc_ui=doc_ui,
294-
method=HTTPMethod.PUT,
295-
)
293+
header, cookie, path, query, form, body, raw = \
294+
self._collect_openapi_info(
295+
rule,
296+
func,
297+
tags=tags,
298+
summary=summary,
299+
description=description,
300+
external_docs=external_docs,
301+
operation_id=operation_id,
302+
responses=responses,
303+
deprecated=deprecated,
304+
security=security,
305+
servers=servers,
306+
openapi_extensions=openapi_extensions,
307+
request_body_description=request_body_description,
308+
request_body_required=request_body_required,
309+
doc_ui=doc_ui,
310+
method=HTTPMethod.PUT,
311+
)
296312

297313
view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw)
298314
options.update({"methods": [HTTPMethod.PUT]})
@@ -316,6 +332,8 @@ def delete(
316332
security: list[dict[str, list[Any]]] | None = None,
317333
servers: list[Server] | None = None,
318334
openapi_extensions: dict[str, Any] | None = None,
335+
request_body_description: str | None = None,
336+
request_body_required: bool | None = True,
319337
doc_ui: bool = True,
320338
**options: Any,
321339
) -> Callable:
@@ -335,26 +353,31 @@ def delete(
335353
security: A declaration of which security mechanisms can be used for this operation.
336354
servers: An alternative server array to service this operation.
337355
openapi_extensions: Allows extensions to the OpenAPI Schema.
356+
request_body_description: A brief description of the request body.
357+
request_body_required: Determines if the request body is required in the request.
338358
doc_ui: Declares this operation to be shown. Default to True.
339359
"""
340360

341361
def decorator(func) -> Callable:
342-
header, cookie, path, query, form, body, raw = self._collect_openapi_info(
343-
rule,
344-
func,
345-
tags=tags,
346-
summary=summary,
347-
description=description,
348-
external_docs=external_docs,
349-
operation_id=operation_id,
350-
responses=responses,
351-
deprecated=deprecated,
352-
security=security,
353-
servers=servers,
354-
openapi_extensions=openapi_extensions,
355-
doc_ui=doc_ui,
356-
method=HTTPMethod.DELETE,
357-
)
362+
header, cookie, path, query, form, body, raw = \
363+
self._collect_openapi_info(
364+
rule,
365+
func,
366+
tags=tags,
367+
summary=summary,
368+
description=description,
369+
external_docs=external_docs,
370+
operation_id=operation_id,
371+
responses=responses,
372+
deprecated=deprecated,
373+
security=security,
374+
servers=servers,
375+
openapi_extensions=openapi_extensions,
376+
request_body_description=request_body_description,
377+
request_body_required=request_body_required,
378+
doc_ui=doc_ui,
379+
method=HTTPMethod.DELETE,
380+
)
358381

359382
view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw)
360383
options.update({"methods": [HTTPMethod.DELETE]})
@@ -378,6 +401,8 @@ def patch(
378401
security: list[dict[str, list[Any]]] | None = None,
379402
servers: list[Server] | None = None,
380403
openapi_extensions: dict[str, Any] | None = None,
404+
request_body_description: str | None = None,
405+
request_body_required: bool | None = True,
381406
doc_ui: bool = True,
382407
**options: Any,
383408
) -> Callable:
@@ -397,26 +422,31 @@ def patch(
397422
security: A declaration of which security mechanisms can be used for this operation.
398423
servers: An alternative server array to service this operation.
399424
openapi_extensions: Allows extensions to the OpenAPI Schema.
425+
request_body_description: A brief description of the request body.
426+
request_body_required: Determines if the request body is required in the request.
400427
doc_ui: Declares this operation to be shown. Default to True.
401428
"""
402429

403430
def decorator(func) -> Callable:
404-
header, cookie, path, query, form, body, raw = self._collect_openapi_info(
405-
rule,
406-
func,
407-
tags=tags,
408-
summary=summary,
409-
description=description,
410-
external_docs=external_docs,
411-
operation_id=operation_id,
412-
responses=responses,
413-
deprecated=deprecated,
414-
security=security,
415-
servers=servers,
416-
openapi_extensions=openapi_extensions,
417-
doc_ui=doc_ui,
418-
method=HTTPMethod.PATCH,
419-
)
431+
header, cookie, path, query, form, body, raw = \
432+
self._collect_openapi_info(
433+
rule,
434+
func,
435+
tags=tags,
436+
summary=summary,
437+
description=description,
438+
external_docs=external_docs,
439+
operation_id=operation_id,
440+
responses=responses,
441+
deprecated=deprecated,
442+
security=security,
443+
servers=servers,
444+
openapi_extensions=openapi_extensions,
445+
request_body_description=request_body_description,
446+
request_body_required=request_body_required,
447+
doc_ui=doc_ui,
448+
method=HTTPMethod.PATCH,
449+
)
420450

421451
view_func = self.create_view_func(func, header, cookie, path, query, form, body, raw)
422452
options.update({"methods": [HTTPMethod.PATCH]})

0 commit comments

Comments
 (0)