Skip to content

Commit 6d95399

Browse files
committed
Support multi content type in request body and responses
1 parent cd12527 commit 6d95399

File tree

10 files changed

+467
-257
lines changed

10 files changed

+467
-257
lines changed

examples/multi_content_type.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# @Author : llc
3+
# @Time : 2024/12/27 15:30
4+
from pydantic import BaseModel
5+
6+
from flask_openapi3 import OpenAPI
7+
8+
app = OpenAPI(__name__)
9+
10+
11+
class DogBody(BaseModel):
12+
a: int = None
13+
b: str = None
14+
15+
model_config = {
16+
"openapi_extra": {
17+
"content_type": "application/vnd.dog+json"
18+
}
19+
}
20+
21+
22+
class CatBody(BaseModel):
23+
c: int = None
24+
d: str = None
25+
26+
model_config = {
27+
"openapi_extra": {
28+
"content_type": "application/vnd.cat+json"
29+
}
30+
}
31+
32+
33+
class ContentTypeModel(BaseModel):
34+
model_config = {
35+
"openapi_extra": {
36+
"content_type": "text/csv"
37+
}
38+
}
39+
40+
41+
@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel})
42+
def index_a(body: DogBody | CatBody | ContentTypeModel):
43+
print(body)
44+
return {"hello": "world"}
45+
46+
47+
if __name__ == '__main__':
48+
app.run(debug=True)

flask_openapi3/blueprint.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def _collect_openapi_info(
121121
security: Optional[list[dict[str, list[Any]]]] = None,
122122
servers: Optional[list[Server]] = None,
123123
openapi_extensions: Optional[dict[str, Any]] = None,
124+
request_body_description: Optional[str] = None,
125+
request_body_required: Optional[bool] = True,
124126
doc_ui: bool = True,
125127
method: str = HTTPMethod.GET,
126128
) -> ParametersTuple:
@@ -140,6 +142,8 @@ def _collect_openapi_info(
140142
security: A declaration of which security mechanisms can be used for this operation.
141143
servers: An alternative server array to service this operation.
142144
openapi_extensions: Allows extensions to the OpenAPI Schema.
145+
request_body_description: A brief description of the request body.
146+
request_body_required: Determines if the request body is required in the request.
143147
doc_ui: Declares this operation to be shown. Default to True.
144148
"""
145149
if self.doc_ui is True and doc_ui is True:
@@ -191,6 +195,12 @@ def _collect_openapi_info(
191195
parse_method(uri, method, self.paths, operation)
192196

193197
# Parse parameters
194-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
198+
return parse_parameters(
199+
func,
200+
components_schemas=self.components_schemas,
201+
operation=operation,
202+
request_body_description=request_body_description,
203+
request_body_required=request_body_required
204+
)
195205
else:
196206
return parse_parameters(func, doc_ui=False)

flask_openapi3/openapi.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ def _collect_openapi_info(
375375
security: Optional[list[dict[str, list[Any]]]] = None,
376376
servers: Optional[list[Server]] = None,
377377
openapi_extensions: Optional[dict[str, Any]] = None,
378+
request_body_description: Optional[str] = None,
379+
request_body_required: Optional[bool] = True,
378380
doc_ui: bool = True,
379381
method: str = HTTPMethod.GET,
380382
) -> ParametersTuple:
@@ -394,6 +396,8 @@ def _collect_openapi_info(
394396
security: A declaration of which security mechanisms can be used for this operation.
395397
servers: An alternative server array to service this operation.
396398
openapi_extensions: Allows extensions to the OpenAPI Schema.
399+
request_body_description: A brief description of the request body.
400+
request_body_required: Determines if the request body is required in the request.
397401
doc_ui: Declares this operation to be shown. Default to True.
398402
method: HTTP method for the operation. Defaults to GET.
399403
"""
@@ -442,6 +446,12 @@ def _collect_openapi_info(
442446
parse_method(uri, method, self.paths, operation)
443447

444448
# Parse parameters
445-
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
449+
return parse_parameters(
450+
func,
451+
components_schemas=self.components_schemas,
452+
operation=operation,
453+
request_body_description=request_body_description,
454+
request_body_required=request_body_required
455+
)
446456
else:
447457
return parse_parameters(func, doc_ui=False)

flask_openapi3/request.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
# @Time : 2022/4/1 16:54
44
import inspect
55
import json
6-
from functools import wraps
76
from json import JSONDecodeError
8-
from typing import Any, Optional, Type
7+
from types import UnionType
8+
from typing import Any, Optional, Type, Union, get_args, get_origin
99

1010
from flask import abort, current_app, request
11-
from pydantic import BaseModel, ValidationError
11+
from pydantic import BaseModel, RootModel, ValidationError
1212
from pydantic.fields import FieldInfo
1313
from werkzeug.datastructures.structures import MultiDict
1414

15-
from .utils import parse_parameters
15+
from flask_openapi3.utils import is_application_json, parse_parameters
1616

1717

1818
def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo):
@@ -146,12 +146,20 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
146146

147147

148148
def _validate_body(body: Type[BaseModel], func_kwargs: dict):
149-
obj = request.get_json(silent=True)
150-
if isinstance(obj, str):
151-
body_model = body.model_validate_json(json_data=obj)
149+
if is_application_json(request.mimetype):
150+
if get_origin(body) == UnionType:
151+
root_model_list = [model for model in get_args(body)]
152+
Body = RootModel[Union[tuple(root_model_list)]] # type: ignore
153+
else:
154+
Body = body # type: ignore
155+
obj = request.get_json(silent=True)
156+
if isinstance(obj, str):
157+
body_model = Body.model_validate_json(json_data=obj)
158+
else:
159+
body_model = Body.model_validate(obj=obj)
160+
func_kwargs["body"] = body_model
152161
else:
153-
body_model = body.model_validate(obj=obj)
154-
func_kwargs["body"] = body_model
162+
func_kwargs["body"] = request
155163

156164

157165
def _validate_request(

0 commit comments

Comments
 (0)