diff --git a/docs/Usage/Request.md b/docs/Usage/Request.md index c5ff9839..c41dee0b 100644 --- a/docs/Usage/Request.md +++ b/docs/Usage/Request.md @@ -167,6 +167,88 @@ def get_book(query: BookQuery, client_id:str = None): ... ``` +## Multiple content types in the request body + +```python +from typing import Union + +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_10-44-00.png) + + ## Request model First, you need to define a [pydantic](https://github.com/pydantic/pydantic) model: @@ -191,7 +273,7 @@ class BookQuery(BaseModel): author: str = Field(None, description='Author', json_schema_extra={"deprecated": True}) ``` -Magic: +The effect in swagger: ![](../assets/Snipaste_2022-09-04_10-10-03.png) diff --git a/docs/Usage/Response.md b/docs/Usage/Response.md index fae8bf93..b229a385 100644 --- a/docs/Usage/Response.md +++ b/docs/Usage/Response.md @@ -56,6 +56,122 @@ def hello(path: HelloPath): ![image-20210526104627124](../assets/image-20210526104627124.png) +*Sometimes you may need more description fields about the response, such as description, headers and links. + +You can use the following form: + +```python +@app.get( + "/test", + responses={ + "201": { + "model": BaseResponse, + "description": "Custom description", + "headers": { + "location": { + "description": "URL of the new resource", + "schema": {"type": "string"} + } + }, + "links": { + "dummy": { + "description": "dummy link" + } + } + } + } + ) + def endpoint_test(): + ... +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_11-08-40.png) + + +## Multiple content types in the responses + +```python +from typing import Union + +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.dog+json" + } + } + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/vnd.cat+json" + } + } + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = { + "openapi_extra": { + "content_type": "application/bson" + } + } + + +class ContentTypeModel(BaseModel): + model_config = { + "openapi_extra": { + "content_type": "text/csv" + } + } + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} +``` + +The effect in swagger: + +![](../assets/Snipaste_2025-01-14_10-49-19.png) + + ## More information about OpenAPI responses - [OpenAPI Responses Object](https://spec.openapis.org/oas/v3.1.0#responses-object), it includes the Response Object. diff --git a/docs/Usage/Route_Operation.md b/docs/Usage/Route_Operation.md index 20aadf25..077d4b6c 100644 --- a/docs/Usage/Route_Operation.md +++ b/docs/Usage/Route_Operation.md @@ -289,6 +289,29 @@ class BookListAPIView: app.register_api_view(api_view) ``` +## request_body_description + +A brief description of the request body. + +```python +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + +@app.post( + "/", + request_body_description="A brief description of the request body." +) +def create_book(body: Bookbody): + ... +``` + +![](../assets/Snipaste_2025-01-14_10-56-40.png) + +## request_body_required + +Determines if the request body is required in the request. + ## doc_ui You can pass `doc_ui=False` to disable the `OpenAPI spec` when init `OpenAPI `. diff --git a/docs/assets/Snipaste_2025-01-14_10-44-00.png b/docs/assets/Snipaste_2025-01-14_10-44-00.png new file mode 100644 index 00000000..d716d979 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-44-00.png differ diff --git a/docs/assets/Snipaste_2025-01-14_10-49-19.png b/docs/assets/Snipaste_2025-01-14_10-49-19.png new file mode 100644 index 00000000..f1bb52e7 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-49-19.png differ diff --git a/docs/assets/Snipaste_2025-01-14_10-56-40.png b/docs/assets/Snipaste_2025-01-14_10-56-40.png new file mode 100644 index 00000000..e66488dc Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_10-56-40.png differ diff --git a/docs/assets/Snipaste_2025-01-14_11-08-40.png b/docs/assets/Snipaste_2025-01-14_11-08-40.png new file mode 100644 index 00000000..dd537d87 Binary files /dev/null and b/docs/assets/Snipaste_2025-01-14_11-08-40.png differ diff --git a/examples/multi_content_type.py b/examples/multi_content_type.py new file mode 100644 index 00000000..857528f1 --- /dev/null +++ b/examples/multi_content_type.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# @Author : llc +# @Time : 2024/12/27 15:30 +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = {"openapi_extra": {"content_type": "application/vnd.dog+json"}} + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = {"openapi_extra": {"content_type": "application/vnd.cat+json"}} + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = {"openapi_extra": {"content_type": "application/bson"}} + + +class ContentTypeModel(BaseModel): + model_config = {"openapi_extra": {"content_type": "text/csv"}} + + +@app.post("/a", responses={200: DogBody | CatBody | ContentTypeModel | BsonModel}) +def index_a(body: DogBody | CatBody | ContentTypeModel | BsonModel): + """ + multiple content types examples. + + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + from bson import BSON + + obj = BSON(body.data).decode() + new_body = body.model_validate(obj=obj) + print(new_body) + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +if __name__ == "__main__": + app.run(debug=True) diff --git a/flask_openapi3/blueprint.py b/flask_openapi3/blueprint.py index ba7e95c7..361acfb6 100644 --- a/flask_openapi3/blueprint.py +++ b/flask_openapi3/blueprint.py @@ -121,6 +121,8 @@ def _collect_openapi_info( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, method: str = HTTPMethod.GET, ) -> ParametersTuple: @@ -140,6 +142,8 @@ def _collect_openapi_info( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ if self.doc_ui is True and doc_ui is True: @@ -191,6 +195,12 @@ def _collect_openapi_info( parse_method(uri, method, self.paths, operation) # Parse parameters - return parse_parameters(func, components_schemas=self.components_schemas, operation=operation) + return parse_parameters( + func, + components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required, + ) else: return parse_parameters(func, doc_ui=False) diff --git a/flask_openapi3/models/path_item.py b/flask_openapi3/models/path_item.py index 5622531b..2131b4ae 100644 --- a/flask_openapi3/models/path_item.py +++ b/flask_openapi3/models/path_item.py @@ -1,18 +1,14 @@ # -*- coding: utf-8 -*- # @Author : llc # @Time : 2023/7/4 9:50 -import typing -from typing import Optional from pydantic import BaseModel, Field +from .operation import Operation from .parameter import Parameter from .reference import Reference from .server import Server -if typing.TYPE_CHECKING: # pragma: no cover - from .operation import Operation - class PathItem(BaseModel): """ @@ -22,14 +18,14 @@ class PathItem(BaseModel): ref: str | None = Field(default=None, alias="$ref") summary: str | None = None description: str | None = None - get: Optional["Operation"] = None - put: Optional["Operation"] = None - post: Optional["Operation"] = None - delete: Optional["Operation"] = None - options: Optional["Operation"] = None - head: Optional["Operation"] = None - patch: Optional["Operation"] = None - trace: Optional["Operation"] = None + get: Operation | None = None + put: Operation | None = None + post: Operation | None = None + delete: Operation | None = None + options: Operation | None = None + head: Operation | None = None + patch: Operation | None = None + trace: Operation | None = None servers: list[Server] | None = None parameters: list[Parameter | Reference] | None = None diff --git a/flask_openapi3/openapi.py b/flask_openapi3/openapi.py index c9c04e2b..dbc1187b 100644 --- a/flask_openapi3/openapi.py +++ b/flask_openapi3/openapi.py @@ -370,6 +370,8 @@ def _collect_openapi_info( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, method: str = HTTPMethod.GET, ) -> ParametersTuple: @@ -389,6 +391,8 @@ def _collect_openapi_info( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. method: HTTP method for the operation. Defaults to GET. """ @@ -437,6 +441,12 @@ def _collect_openapi_info( parse_method(uri, method, self.paths, operation) # Parse parameters - return parse_parameters(func, components_schemas=self.components_schemas, operation=operation) + return parse_parameters( + func, + components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required, + ) else: return parse_parameters(func, doc_ui=False) diff --git a/flask_openapi3/request.py b/flask_openapi3/request.py index d5936dbd..5aef305b 100644 --- a/flask_openapi3/request.py +++ b/flask_openapi3/request.py @@ -5,14 +5,15 @@ import json from functools import wraps from json import JSONDecodeError -from typing import Any, Type +from types import UnionType +from typing import Any, Type, Union, get_args, get_origin from flask import abort, current_app, request -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, RootModel, ValidationError from pydantic.fields import FieldInfo from werkzeug.datastructures.structures import MultiDict -from .utils import parse_parameters +from flask_openapi3.utils import is_application_json, parse_parameters def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo): @@ -60,7 +61,7 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict): value = request_headers.get(key_alias_title) else: key = model_field_key - value = request_headers[key_title] + value = request_headers.get(key_title) if value is not None: header_dict[key] = value if model_field_schema.get("type") == "null": @@ -149,12 +150,20 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict): def _validate_body(body: Type[BaseModel], func_kwargs: dict): - obj = request.get_json(silent=True) - if isinstance(obj, str): - body_model = body.model_validate_json(json_data=obj) + if is_application_json(request.mimetype): + if get_origin(body) in (Union, UnionType): + root_model_list = [model for model in get_args(body)] + Body = RootModel[Union[tuple(root_model_list)]] # type: ignore + else: + Body = body # type: ignore + obj = request.get_json(silent=True) + if isinstance(obj, str): + body_model = Body.model_validate_json(json_data=obj) + else: + body_model = Body.model_validate(obj=obj) + func_kwargs["body"] = body_model else: - body_model = body.model_validate(obj=obj) - func_kwargs["body"] = body_model + func_kwargs["body"] = request def _validate_request( diff --git a/flask_openapi3/scaffold.py b/flask_openapi3/scaffold.py index 02210445..795b2727 100644 --- a/flask_openapi3/scaffold.py +++ b/flask_openapi3/scaffold.py @@ -29,6 +29,8 @@ def _collect_openapi_info( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, method: str = HTTPMethod.GET, ) -> ParametersTuple: @@ -192,6 +194,8 @@ def post( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, **options: Any, ) -> Callable: @@ -211,6 +215,8 @@ def post( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -228,6 +234,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.POST, ) @@ -254,6 +262,8 @@ def put( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, **options: Any, ) -> Callable: @@ -273,6 +283,8 @@ def put( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -290,6 +302,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.PUT, ) @@ -316,6 +330,8 @@ def delete( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, **options: Any, ) -> Callable: @@ -335,6 +351,8 @@ def delete( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -352,6 +370,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.DELETE, ) @@ -378,6 +398,8 @@ def patch( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, **options: Any, ) -> Callable: @@ -397,6 +419,8 @@ def patch( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -414,6 +438,8 @@ def decorator(func) -> Callable: security=security, servers=servers, openapi_extensions=openapi_extensions, + request_body_description=request_body_description, + request_body_required=request_body_required, doc_ui=doc_ui, method=HTTPMethod.PATCH, ) diff --git a/flask_openapi3/types.py b/flask_openapi3/types.py index 22154a23..0984c6af 100644 --- a/flask_openapi3/types.py +++ b/flask_openapi3/types.py @@ -2,13 +2,15 @@ # @Author : llc # @Time : 2023/7/9 15:25 from http import HTTPStatus -from typing import Any, Type +from typing import Any, Type, TypeVar, Union from pydantic import BaseModel from .models import RawModel, SecurityScheme -_ResponseDictValue = Type[BaseModel] | dict[Any, Any] | None +_MultiBaseModel = TypeVar("_MultiBaseModel", bound=Type[BaseModel]) + +_ResponseDictValue = Union[Type[BaseModel], _MultiBaseModel, dict[Any, Any], None] ResponseDict = dict[str | int | HTTPStatus, _ResponseDictValue] diff --git a/flask_openapi3/utils.py b/flask_openapi3/utils.py index e03f65ac..330f37cb 100644 --- a/flask_openapi3/utils.py +++ b/flask_openapi3/utils.py @@ -7,7 +7,17 @@ import sys from enum import Enum from http import HTTPStatus -from typing import Any, Callable, DefaultDict, Type, get_type_hints +from types import UnionType +from typing import ( + Any, + Callable, + DefaultDict, + Type, + Union, + get_args, + get_origin, + get_type_hints, +) from flask import current_app, make_response from flask.wrappers import Response as FlaskResponse @@ -265,6 +275,11 @@ def parse_form( ) -> tuple[dict[str, MediaType], dict]: """Parses a form model and returns a list of parameters and component schemas.""" schema = get_model_schema(form) + + model_config: DefaultDict[str, Any] = form.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "multipart/form-data") + components_schemas = dict() properties = schema.get("properties", {}) @@ -277,14 +292,22 @@ def parse_form( for k, v in properties.items(): if v.get("type") == "array": encoding[k] = Encoding(style="form", explode=True) - content = { - "multipart/form-data": MediaType( - schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}), - ) - } + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + if encoding: - content["multipart/form-data"].encoding = encoding + media_type.encoding = encoding + content = {content_type: media_type} # Parse definitions definitions = schema.get("$defs", {}) for name, value in definitions.items(): @@ -297,69 +320,128 @@ def parse_body( body: Type[BaseModel], ) -> tuple[dict[str, MediaType], dict]: """Parses a body model and returns a list of parameters and component schemas.""" - schema = get_model_schema(body) - components_schemas = dict() - original_title = schema.get("title") or body.__name__ - title = normalize_name(original_title) - components_schemas[title] = Schema(**schema) - content = {"application/json": MediaType(schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}))} + content = {} + components_schemas = {} - # Parse definitions - definitions = schema.get("$defs", {}) - for name, value in definitions.items(): - components_schemas[name] = Schema(**value) + def _parse_body(_model): + model_config: DefaultDict[str, Any] = _model.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "application/json") + + if not is_application_json(content_type): + content_schema = openapi_extra.get("content_schema", {"type": DataType.STRING}) + content[content_type] = MediaType(**{"schema": content_schema}) + return + + schema = get_model_schema(_model) + + original_title = schema.get("title") or _model.__name__ + title = normalize_name(original_title) + components_schemas[title] = Schema(**schema) + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + + content[content_type] = media_type + + # Parse definitions + definitions = schema.get("$defs", {}) + for name, value in definitions.items(): + components_schemas[name] = Schema(**value) + + if get_origin(body) in (Union, UnionType): + for model in get_args(body): + _parse_body(model) + else: + _parse_body(body) return content, components_schemas def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, operation: Operation) -> None: - _responses = {} - _schemas = {} + _responses: dict = {} + _schemas: dict = {} + + def _parse_response(_key, _model): + model_config: DefaultDict[str, Any] = _model.model_config # type: ignore + openapi_extra = model_config.get("openapi_extra", {}) + content_type = openapi_extra.get("content_type", "application/json") + + if not is_application_json(content_type): + content_schema = openapi_extra.get("content_schema", {"type": DataType.STRING}) + media_type = MediaType(**{"schema": content_schema}) + if _responses.get(_key): + _responses[_key].content[content_type] = media_type + else: + _responses[_key] = Response(description=HTTP_STATUS.get(_key, ""), content={content_type: media_type}) + return + + schema = get_model_schema(_model, mode="serialization") + # OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__ + original_title = schema.get("title") or _model.__name__ + name = normalize_name(original_title) + + media_type = MediaType(**{"schema": Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"})}) + + if openapi_extra: + openapi_extra_keys = openapi_extra.keys() + if "example" in openapi_extra_keys: + media_type.example = openapi_extra.get("example") + if "examples" in openapi_extra_keys: + media_type.examples = openapi_extra.get("examples") + if "encoding" in openapi_extra_keys: + media_type.encoding = openapi_extra.get("encoding") + if _responses.get(_key): + _responses[_key].content[content_type] = media_type + else: + _responses[_key] = Response(description=HTTP_STATUS.get(_key, ""), content={content_type: media_type}) + + _schemas[name] = Schema(**schema) + definitions = schema.get("$defs") + if definitions: + # Add schema definitions to _schemas + for name, value in definitions.items(): + _schemas[normalize_name(name)] = Schema(**value) for key, response in responses.items(): - if response is None: + if isinstance(response, dict) and "model" in response: + response_model = response.get("model") + response_description = response.get("description") + response_headers = response.get("headers") + response_links = response.get("links") + else: + response_model = response + response_description = None + response_headers = None + response_links = None + + if response_model is None: # If the response is None, it means HTTP status code "204" (No Content) _responses[key] = Response(description=HTTP_STATUS.get(key, "")) - elif isinstance(response, dict): - response["description"] = response.get("description", HTTP_STATUS.get(key, "")) - _responses[key] = Response(**response) + elif isinstance(response_model, dict): + response_model["description"] = response_model.get("description", HTTP_STATUS.get(key, "")) + _responses[key] = Response(**response_model) + elif get_origin(response_model) in [UnionType, Union]: + for model in get_args(response_model): + _parse_response(key, model) else: - # OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__ - schema = get_model_schema(response, mode="serialization") - original_title = schema.get("title") or response.__name__ - name = normalize_name(original_title) - _responses[key] = Response( - description=HTTP_STATUS.get(key, ""), - content={"application/json": MediaType(schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"}))}, - ) - - model_config: DefaultDict[str, Any] = response.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - # Add additional information from model_config to the response - if "description" in openapi_extra_keys: - _responses[key].description = openapi_extra.get("description") - if "headers" in openapi_extra_keys: - _responses[key].headers = openapi_extra.get("headers") - if "links" in openapi_extra_keys: - _responses[key].links = openapi_extra.get("links") - _content = _responses[key].content - if "example" in openapi_extra_keys: - _content["application/json"].example = openapi_extra.get("example") # type: ignore - if "examples" in openapi_extra_keys: - _content["application/json"].examples = openapi_extra.get("examples") # type: ignore - if "encoding" in openapi_extra_keys: - _content["application/json"].encoding = openapi_extra.get("encoding") # type: ignore - _content.update(openapi_extra.get("content", {})) # type: ignore - - _schemas[name] = Schema(**schema) - definitions = schema.get("$defs") - if definitions: - # Add schema definitions to _schemas - for name, value in definitions.items(): - _schemas[normalize_name(name)] = Schema(**value) + _parse_response(key, response_model) + + if response_description is not None: + _responses[key].description = response_description + if response_headers is not None: + _responses[key].headers = response_headers + if response_links is not None: + _responses[key].links = response_links components_schemas.update(**_schemas) operation.responses = _responses @@ -397,6 +479,8 @@ def parse_parameters( *, components_schemas: dict | None = None, operation: Operation | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, ) -> ParametersTuple: """ @@ -407,6 +491,8 @@ def parse_parameters( func: The function to parse the parameters from. components_schemas: Dictionary to store the parsed components schemas (default: None). operation: Operation object to populate with parsed parameters (default: None). + request_body_description: A brief description of the request body (default: None). + request_body_required: Determines if the request body is required in the request (default: True). doc_ui: Flag indicating whether to return types for documentation UI (default: True). Returns: @@ -465,47 +551,31 @@ def parse_parameters( _content, _components_schemas = parse_form(form) components_schemas.update(**_components_schemas) request_body = RequestBody(content=_content, required=True) - model_config: DefaultDict[str, Any] = form.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - if "description" in openapi_extra_keys: - request_body.description = openapi_extra.get("description") - if "example" in openapi_extra_keys: - request_body.content["multipart/form-data"].example = openapi_extra.get("example") - if "examples" in openapi_extra_keys: - request_body.content["multipart/form-data"].examples = openapi_extra.get("examples") - if "encoding" in openapi_extra_keys: - request_body.content["multipart/form-data"].encoding = openapi_extra.get("encoding") + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if body: _content, _components_schemas = parse_body(body) components_schemas.update(**_components_schemas) request_body = RequestBody(content=_content, required=True) - model_config: DefaultDict[str, Any] = body.model_config # type: ignore - openapi_extra = model_config.get("openapi_extra", {}) - if openapi_extra: - openapi_extra_keys = openapi_extra.keys() - if "description" in openapi_extra_keys: - request_body.description = openapi_extra.get("description") - request_body.required = openapi_extra.get("required", True) - if "example" in openapi_extra_keys: - request_body.content["application/json"].example = openapi_extra.get("example") - if "examples" in openapi_extra_keys: - request_body.content["application/json"].examples = openapi_extra.get("examples") - if "encoding" in openapi_extra_keys: - request_body.content["application/json"].encoding = openapi_extra.get("encoding") + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if raw: _content = {} for mimetype in raw.mimetypes: - if mimetype.startswith("application/json"): - _content[mimetype] = MediaType(schema=Schema(type=DataType.OBJECT)) + if is_application_json(mimetype): + _content[mimetype] = MediaType(**{"schema": Schema(type=DataType.OBJECT)}) else: - _content[mimetype] = MediaType(schema=Schema(type=DataType.STRING)) + _content[mimetype] = MediaType(**{"schema": Schema(type=DataType.STRING)}) request_body = RequestBody(content=_content) + if request_body_description: + request_body.description = request_body_description + request_body.required = request_body_required operation.requestBody = request_body if parameters: @@ -595,3 +665,7 @@ def convert_responses_key_to_string(responses: ResponseDict) -> ResponseStrKeyDi def normalize_name(name: str) -> str: return re.sub(r"[^\w.\-]", "_", name) + + +def is_application_json(content_type: str) -> bool: + return "application" in content_type and "json" in content_type diff --git a/flask_openapi3/view.py b/flask_openapi3/view.py index c469a161..44cfc0b6 100644 --- a/flask_openapi3/view.py +++ b/flask_openapi3/view.py @@ -110,6 +110,8 @@ def doc( security: list[dict[str, list[Any]]] | None = None, servers: list[Server] | None = None, openapi_extensions: dict[str, Any] | None = None, + request_body_description: str | None = None, + request_body_required: bool | None = True, doc_ui: bool = True, ) -> Callable: """ @@ -127,6 +129,8 @@ def doc( security: A declaration of which security mechanisms can be used for this operation. servers: An alternative server array to service this operation. openapi_extensions: Allows extensions to the OpenAPI Schema. + request_body_description: A brief description of the request body. + request_body_required: Determines if the request body is required in the request. doc_ui: Declares this operation to be shown. Default to True. """ @@ -171,7 +175,13 @@ def decorator(func): parse_and_store_tags(tags, self.tags, self.tag_names, operation) # Parse parameters - parse_parameters(func, components_schemas=self.components_schemas, operation=operation) + parse_parameters( + func, + components_schemas=self.components_schemas, + operation=operation, + request_body_description=request_body_description, + request_body_required=request_body_required, + ) # Parse response get_responses(combine_responses, self.components_schemas, operation) diff --git a/tests/test_api_blueprint.py b/tests/test_api_blueprint.py index 6723cb30..00203cb1 100644 --- a/tests/test_api_blueprint.py +++ b/tests/test_api_blueprint.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel, Field -from flask_openapi3 import APIBlueprint, Info, OpenAPI, Tag +from flask_openapi3 import APIBlueprint, ExternalDocumentation, Info, OpenAPI, Server, Tag info = Info(title="book API", version="1.0.0") @@ -83,7 +83,14 @@ def update_book1(path: BookPath, body: BookBody): return {"code": 0, "message": "ok"} -@api.patch("/v2/book/") +@api.patch( + "/v2/book/", + servers=[Server(url="http://127.0.0.1:5000", variables=None)], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", description="Something great got better, get excited!" + ), + deprecated=True, +) def update_book1_v2(path: BookPath, body: BookBody): assert path.bid == 1 assert body.age == 3 diff --git a/tests/test_api_view.py b/tests/test_api_view.py index f2ef4245..b80d3235 100644 --- a/tests/test_api_view.py +++ b/tests/test_api_view.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel, Field -from flask_openapi3 import APIView, Info, OpenAPI, Tag +from flask_openapi3 import APIView, ExternalDocumentation, Info, OpenAPI, Server, Tag info = Info(title="book API", version="1.0.0") jwt = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"} @@ -62,7 +62,14 @@ def put(self, path: BookPath): print(path) return "put" - @api_view.doc(summary="delete book", deprecated=True) + @api_view.doc( + summary="delete book", + servers=[Server(url="http://127.0.0.1:5000", variables=None)], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", description="Something great got better, get excited!" + ), + deprecated=True, + ) def delete(self, path: BookPath): print(path) return "delete" diff --git a/tests/test_model_config.py b/tests/test_model_config.py index 8db1a64b..593654b3 100644 --- a/tests/test_model_config.py +++ b/tests/test_model_config.py @@ -34,7 +34,6 @@ class BookBody(BaseModel): model_config = dict( openapi_extra={ - "description": "This is post RequestBody", "example": {"age": 12, "author": "author1"}, "examples": { "example1": { @@ -73,7 +72,7 @@ def api_form(form: UploadFilesForm): print(form) # pragma: no cover -@app.post("/body", responses={"200": MessageResponse}) +@app.post("/body", request_body_description="This is post RequestBody", responses={"200": MessageResponse}) def api_error_json(body: BookBody): print(body) # pragma: no cover diff --git a/tests/test_multi_content_type.py b/tests/test_multi_content_type.py new file mode 100644 index 00000000..115f0da0 --- /dev/null +++ b/tests/test_multi_content_type.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# @Author : llc +# @Time : 2025/1/6 16:37 +from typing import Union + +import pytest +from flask import Request +from pydantic import BaseModel + +from flask_openapi3 import OpenAPI + +app = OpenAPI(__name__) +app.config["TESTING"] = True + + +class DogBody(BaseModel): + a: int = None + b: str = None + + model_config = {"openapi_extra": {"content_type": "application/vnd.dog+json"}} + + +class CatBody(BaseModel): + c: int = None + d: str = None + + model_config = {"openapi_extra": {"content_type": "application/vnd.cat+json"}} + + +class BsonModel(BaseModel): + e: int = None + f: str = None + + model_config = {"openapi_extra": {"content_type": "application/bson"}} + + +class ContentTypeModel(BaseModel): + model_config = {"openapi_extra": {"content_type": "text/csv"}} + + +@app.post("/a", responses={200: Union[DogBody, CatBody, ContentTypeModel, BsonModel]}) +def index_a(body: Union[DogBody, CatBody, ContentTypeModel, BsonModel]): + """ + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +@app.post("/b", responses={200: Union[ContentTypeModel, BsonModel]}) +def index_b(body: Union[ContentTypeModel, BsonModel]): + """ + This may be confusing, if the content-type is application/json, the type of body will be auto parsed to + DogBody or CatBody, otherwise it cannot be parsed to ContentTypeModel or BsonModel. + The body is equivalent to the request variable in Flask, and you can use body.data, body.text, etc ... + """ + print(body) + if isinstance(body, Request): + if body.mimetype == "text/csv": + # processing csv data + ... + elif body.mimetype == "application/bson": + # processing bson data + ... + else: + # DogBody or CatBody + ... + return {"hello": "world"} + + +@pytest.fixture +def client(): + client = app.test_client() + + return client + + +def test_openapi(client): + resp = client.get("/openapi/openapi.json") + assert resp.status_code == 200 + + resp = client.post("/a", json={"a": 1, "b": "2"}) + assert resp.status_code == 200 + + resp = client.post("/a", data="a,b,c\n1,2,3", headers={"Content-Type": "text/csv"}) + assert resp.status_code == 200 diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 46e01d39..3bf5baab 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -18,16 +18,17 @@ class BaseResponse(BaseModel): test: int - model_config = dict( - openapi_extra={ + @test_app.get( + "/test", + responses={ + "201": { + "model": BaseResponse, "description": "Custom description", "headers": {"location": {"description": "URL of the new resource", "schema": {"type": "string"}}}, - "content": {"text/plain": {"schema": {"type": "string"}}}, "links": {"dummy": {"description": "dummy link"}}, } - ) - - @test_app.get("/test", responses={"201": BaseResponse}) + }, + ) def endpoint_test(): return b"", 201 # pragma: no cover @@ -39,9 +40,7 @@ def endpoint_test(): "headers": {"location": {"description": "URL of the new resource", "schema": {"type": "string"}}}, "content": { # This content is coming from responses - "application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}, - # While this one comes from responses - "text/plain": {"schema": {"type": "string"}}, + "application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}} }, "links": {"dummy": {"description": "dummy link"}}, } diff --git a/tests/test_restapi.py b/tests/test_restapi.py index 24e02ece..55088a2f 100644 --- a/tests/test_restapi.py +++ b/tests/test_restapi.py @@ -10,7 +10,7 @@ from flask import Response from pydantic import BaseModel, Field, RootModel -from flask_openapi3 import ExternalDocumentation, Info, OpenAPI, Tag +from flask_openapi3 import ExternalDocumentation, Info, OpenAPI, Server, Tag info = Info(title="book API", version="1.0.0") @@ -43,6 +43,8 @@ def get_operation_id_for_path_callback(*, name: str, path: str, method: str) -> class BookQuery(BaseModel): age: int | None = Field(None, description="Age") + author: str + none: None = None class BookBody(BaseModel): @@ -97,8 +99,10 @@ def client(): external_docs=ExternalDocumentation( url="https://www.openapis.org/", description="Something great got better, get excited!" ), + servers=[Server(url="http://127.0.0.1:5000", variables=None)], responses={"200": BookResponse}, security=security, + deprecated=True, ) def get_book(path: BookPath): """Get a book @@ -110,7 +114,7 @@ def get_book(path: BookPath): @app.get("/book", tags=[book_tag], responses={"200": BookListResponseV1}) -def get_books(query: BookBody): +def get_books(query: BookQuery): """get books to get all books """ diff --git a/tests/test_server.py b/tests/test_server.py index 731de4ce..c1b4dacd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,29 +3,33 @@ # @Time : 2024/11/10 12:17 from pydantic import ValidationError -from flask_openapi3 import Server, ServerVariable +from flask_openapi3 import ExternalDocumentation, OpenAPI, Server, ServerVariable def test_server_variable(): Server(url="http://127.0.0.1:5000", variables=None) + error = 0 try: variables = {"one": ServerVariable(default="one", enum=[])} - Server(url="http://127.0.0.1:5000", variables=variables) - error = 0 except ValidationError: error = 1 assert error == 1 - try: - variables = {"one": ServerVariable(default="one")} - Server(url="http://127.0.0.1:5000", variables=variables) - error = 0 - except ValidationError: - error = 1 + variables = {"one": ServerVariable(default="one")} + Server(url="http://127.0.0.1:5000", variables=variables) + error = 0 assert error == 0 - try: - variables = {"one": ServerVariable(default="one", enum=["one", "two"])} - Server(url="http://127.0.0.1:5000", variables=variables) - error = 0 - except ValidationError: - error = 1 + variables = {"one": ServerVariable(default="one", enum=["one", "two"])} + Server(url="http://127.0.0.1:5000", variables=variables) + error = 0 assert error == 0 + + app = OpenAPI( + __name__, + servers=[Server(url="http://127.0.0.1:5000", variables=None)], + external_docs=ExternalDocumentation( + url="https://www.openapis.org/", description="Something great got better, get excited!" + ), + ) + + assert "servers" in app.api_doc + assert "externalDocs" in app.api_doc