Skip to content

Commit 0b072eb

Browse files
committed
Optimize performance and unit testing
1 parent 3021cfa commit 0b072eb

32 files changed

+309
-322
lines changed

flask_openapi3/blueprint.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# @Author : llc
33
# @Time : 2022/4/1 16:54
4-
from copy import deepcopy
54
from typing import Optional, List, Dict, Any, Callable
65

76
from flask import Blueprint
@@ -64,9 +63,8 @@ def __init__(
6463
self.abp_tags = abp_tags or []
6564
self.abp_security = abp_security or []
6665

67-
abp_responses = abp_responses or {}
6866
# Convert key to string
69-
self.abp_responses = convert_responses_key_to_string(abp_responses)
67+
self.abp_responses = convert_responses_key_to_string(abp_responses or {})
7068

7169
self.doc_ui = doc_ui
7270

@@ -93,7 +91,7 @@ def register_api(self, api: "APIBlueprint") -> None:
9391
self.paths[uri] = path_item
9492

9593
# Merge component schemas from the nested APIBlueprint
96-
self.components_schemas.update(**api.components_schemas)
94+
self.components_schemas.update(api.components_schemas)
9795

9896
# Register the nested APIBlueprint as a blueprint
9997
self.register_blueprint(api)
@@ -145,44 +143,41 @@ def _collect_openapi_info(
145143
doc_ui: Declares this operation to be shown. Default to True.
146144
"""
147145
if self.doc_ui is True and doc_ui is True:
148-
if responses is None:
149-
new_responses = {}
150-
else:
151-
# Convert key to string
152-
new_responses = convert_responses_key_to_string(responses)
146+
# Convert key to string
147+
new_responses = convert_responses_key_to_string(responses or {})
148+
153149
# Global response: combine API responses
154-
combine_responses = deepcopy(self.abp_responses)
155-
combine_responses.update(**new_responses)
150+
combine_responses = {**self.abp_responses, **new_responses}
151+
156152
# Create operation
157153
operation = get_operation(
158154
func,
159155
summary=summary,
160156
description=description,
161157
openapi_extensions=openapi_extensions
162158
)
159+
163160
# Set external docs
164161
operation.externalDocs = external_docs
162+
165163
# Unique string used to identify the operation.
166164
operation.operationId = operation_id or self.operation_id_callback(
167165
name=self.name, path=rule, method=method
168166
)
167+
169168
# Only set `deprecated` if True, otherwise leave it as None
170169
operation.deprecated = deprecated
170+
171171
# Add security
172-
if security is None:
173-
security = []
174-
operation.security = security + self.abp_security or None
172+
operation.security = (security or []) + self.abp_security or None
173+
175174
# Add servers
176175
operation.servers = servers
176+
177177
# Store tags
178-
tags = tags + self.abp_tags if tags else self.abp_tags
178+
tags = (tags or []) + self.abp_tags
179179
parse_and_store_tags(tags, self.tags, self.tag_names, operation)
180-
# Parse parameters
181-
header, cookie, path, query, form, body, raw = parse_parameters(
182-
func,
183-
components_schemas=self.components_schemas,
184-
operation=operation
185-
)
180+
186181
# Parse response
187182
get_responses(combine_responses, self.components_schemas, operation)
188183

@@ -191,8 +186,8 @@ def _collect_openapi_info(
191186

192187
# Parse method
193188
parse_method(uri, method, self.paths, operation)
194-
return header, cookie, path, query, form, body, raw
195-
else:
189+
196190
# Parse parameters
197-
header, cookie, path, query, form, body, raw = parse_parameters(func, doc_ui=False)
198-
return header, cookie, path, query, form, body, raw
191+
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
192+
else:
193+
return parse_parameters(func, doc_ui=False)

flask_openapi3/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def openapi_command(output, _format, indent):
2424
if _format == "yaml":
2525
try:
2626
import yaml # type: ignore
27-
except ImportError:
27+
except ImportError: # pragma: no cover
2828
raise ImportError("pyyaml must be installed.")
2929
openapi = yaml.safe_dump(obj, allow_unicode=True)
3030
else:

flask_openapi3/models/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# @Time : 2023/7/4 9:35
44
from typing import TYPE_CHECKING, Dict
55

6-
if TYPE_CHECKING:
6+
if TYPE_CHECKING: # pragma: no cover
77
from .path_item import PathItem
88
else:
99
PathItem = "PathItem"

flask_openapi3/models/encoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .reference import Reference
99

10-
if TYPE_CHECKING:
10+
if TYPE_CHECKING: # pragma: no cover
1111
from .header import Header
1212
else:
1313
Header = "Header"

flask_openapi3/models/path_item.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .reference import Reference
1111
from .server import Server
1212

13-
if typing.TYPE_CHECKING:
13+
if typing.TYPE_CHECKING: # pragma: no cover
1414
from .operation import Operation
1515

1616

flask_openapi3/openapi.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import re
77
import sys
8-
from copy import deepcopy
98
from importlib import import_module
109
from typing import Optional, List, Dict, Union, Any, Type, Callable
1110

@@ -14,7 +13,7 @@
1413

1514
if sys.version_info >= (3, 10):
1615
from importlib.metadata import entry_points
17-
else:
16+
else: # pragma: no cover
1817
from importlib_metadata import entry_points # type: ignore
1918

2019
from .blueprint import APIBlueprint
@@ -108,10 +107,10 @@ def __init__(
108107
# Set security schemes, responses, paths and components
109108
self.security_schemes = security_schemes
110109

111-
responses = responses or {}
112110
# Convert key to string
113-
self.responses = convert_responses_key_to_string(responses)
111+
self.responses = convert_responses_key_to_string(responses or {})
114112

113+
# Initialize instance variables
115114
self.paths: Dict = dict()
116115
self.components_schemas: Dict = dict()
117116
self.components = Components()
@@ -132,7 +131,7 @@ def __init__(
132131
self.operation_id_callback: Callable = operation_id_callback
133132

134133
# Set OpenAPI extensions
135-
self.openapi_extensions = openapi_extensions or dict()
134+
self.openapi_extensions = openapi_extensions or {}
136135

137136
# Set HTTP Response of validation errors within OpenAPI
138137
self.validation_error_status = str(validation_error_status)
@@ -147,7 +146,7 @@ def __init__(
147146
self.cli.add_command(openapi_command) # type: ignore
148147

149148
# Initialize specification JSON
150-
self.spec_json: Dict = dict()
149+
self.spec_json: Dict = {}
151150

152151
def _init_doc(self) -> None:
153152
"""
@@ -174,6 +173,7 @@ def _init_doc(self) -> None:
174173
)
175174

176175
ui_templates = []
176+
# Iterate over all entry points in the "flask_openapi3.plugins" group
177177
for entry_point in entry_points(group="flask_openapi3.plugins"):
178178
try:
179179
module_path = entry_point.value
@@ -186,7 +186,7 @@ def _init_doc(self) -> None:
186186
bp = plugin_register(doc_url=self.doc_url.lstrip("/"))
187187
self.register_blueprint(bp, url_prefix=self.doc_prefix)
188188
ui_templates.append({"name": plugin_name, "display_name": plugin_display_name})
189-
except (ModuleNotFoundError, AttributeError):
189+
except (ModuleNotFoundError, AttributeError): # pragma: no cover
190190
import traceback
191191
print(f"Warning: plugin '{entry_point.value}' registration failed.")
192192
traceback.print_exc()
@@ -281,6 +281,7 @@ def register_api(self, api: APIBlueprint) -> None:
281281
if tag.name not in self.tag_names:
282282
# Append tag to the list of tags
283283
self.tags.append(tag)
284+
284285
# Append tag name to the list of tag names
285286
self.tag_names.append(tag.name)
286287

@@ -309,6 +310,7 @@ def register_api_view(self, api_view: APIView, view_kwargs: Optional[Dict[Any, A
309310
if tag.name not in self.tag_names:
310311
# Append tag to the list of tags
311312
self.tags.append(tag)
313+
312314
# Append tag name to the list of tag names
313315
self.tag_names.append(tag.name)
314316

@@ -369,14 +371,12 @@ def _collect_openapi_info(
369371
method: HTTP method for the operation. Defaults to GET.
370372
"""
371373
if doc_ui is True:
372-
if responses is None:
373-
new_responses = {}
374-
else:
375-
# Convert key to string
376-
new_responses = convert_responses_key_to_string(responses)
374+
# Convert key to string
375+
new_responses = convert_responses_key_to_string(responses or {})
376+
377377
# Global response: combine API responses
378-
combine_responses = deepcopy(self.responses)
379-
combine_responses.update(**new_responses)
378+
combine_responses = {**self.responses, **new_responses}
379+
380380
# Create operation
381381
operation = get_operation(
382382
func,
@@ -386,34 +386,34 @@ def _collect_openapi_info(
386386
)
387387
# Set external docs
388388
operation.externalDocs = external_docs
389+
389390
# Unique string used to identify the operation.
390391
operation.operationId = operation_id or self.operation_id_callback(
391392
name=func.__name__, path=rule, method=method
392393
)
394+
393395
# Only set `deprecated` if True, otherwise leave it as None
394396
operation.deprecated = deprecated
397+
395398
# Add security
396399
operation.security = security
400+
397401
# Add servers
398402
operation.servers = servers
403+
399404
# Store tags
400-
if tags is None:
401-
tags = []
402-
parse_and_store_tags(tags, self.tags, self.tag_names, operation)
403-
# Parse parameters
404-
header, cookie, path, query, form, body, raw = parse_parameters(
405-
func,
406-
components_schemas=self.components_schemas,
407-
operation=operation
408-
)
405+
parse_and_store_tags(tags or [], self.tags, self.tag_names, operation)
406+
409407
# Parse response
410408
get_responses(combine_responses, self.components_schemas, operation)
409+
411410
# Convert a route parameter format from /pet/<petId> to /pet/{petId}
412411
uri = re.sub(r"<([^<:]+:)?", "{", rule).replace(">", "}")
412+
413413
# Parse method
414414
parse_method(uri, method, self.paths, operation)
415-
return header, cookie, path, query, form, body, raw
416-
else:
415+
417416
# Parse parameters
418-
header, cookie, path, query, form, body, raw = parse_parameters(func, doc_ui=False)
419-
return header, cookie, path, query, form, body, raw
417+
return parse_parameters(func, components_schemas=self.components_schemas, operation=operation)
418+
else:
419+
return parse_parameters(func, doc_ui=False)

flask_openapi3/request.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,33 @@ def _validate_header(header: Type[BaseModel], func_kwargs):
1717
key_title = key.replace("_", "-").title()
1818
# Add original key
1919
if key_title in request_headers.keys():
20-
request_headers[value.alias or key] = request_headers[key_title]
21-
func_kwargs.update({"header": header.model_validate(obj=request_headers)})
20+
if value.alias:
21+
request_headers[value.alias] = request_headers[key] = request_headers[key_title]
22+
else:
23+
request_headers[key] = request_headers[key_title]
24+
func_kwargs["header"] = header.model_validate(obj=request_headers)
2225

2326

2427
def _validate_cookie(cookie: Type[BaseModel], func_kwargs):
2528
request_cookies = dict(request.cookies)
26-
func_kwargs.update({"cookie": cookie.model_validate(obj=request_cookies)})
29+
func_kwargs["cookie"] = cookie.model_validate(obj=request_cookies)
2730

2831

2932
def _validate_path(path: Type[BaseModel], path_kwargs, func_kwargs):
30-
func_kwargs.update({"path": path.model_validate(obj=path_kwargs)})
33+
func_kwargs["path"] = path.model_validate(obj=path_kwargs)
3134

3235

3336
def _validate_query(query: Type[BaseModel], func_kwargs):
3437
request_args = request.args
3538
query_dict = {}
3639
for k, v in query.model_fields.items():
3740
if get_origin(v.annotation) is list:
38-
value = request_args.getlist(v.alias or k)
41+
value = request_args.getlist(v.alias or k) or request_args.getlist(k)
3942
else:
40-
value = request_args.get(v.alias or k) # type:ignore
43+
value = request_args.get(v.alias or k) or request_args.get(k) # type:ignore
4144
if value is not None:
4245
query_dict[k] = value
43-
func_kwargs.update({"query": query.model_validate(obj=query_dict)})
46+
func_kwargs["query"] = query.model_validate(obj=query_dict)
4447

4548

4649
def _validate_form(form: Type[BaseModel], func_kwargs):
@@ -50,25 +53,25 @@ def _validate_form(form: Type[BaseModel], func_kwargs):
5053
for k, v in form.model_fields.items():
5154
if get_origin(v.annotation) is list:
5255
if get_args(v.annotation)[0] is FileStorage:
53-
value = request_files.getlist(v.alias or k)
56+
value = request_files.getlist(v.alias or k) or request_files.getlist(k)
5457
else:
5558
value = []
56-
for i in request_form.getlist(v.alias or k):
59+
for i in request_form.getlist(v.alias or k) or request_form.getlist(k):
5760
try:
5861
value.append(json.loads(i))
5962
except (JSONDecodeError, TypeError):
6063
value.append(i) # type:ignore
6164
elif v.annotation is FileStorage:
62-
value = request_files.get(v.alias or k) # type:ignore
65+
value = request_files.get(v.alias or k) or request_files.get(k) # type:ignore
6366
else:
64-
_value = request_form.get(v.alias or k)
67+
_value = request_form.get(v.alias or k) or request_form.get(k)
6568
try:
6669
value = json.loads(_value) # type:ignore
6770
except (JSONDecodeError, TypeError):
6871
value = _value # type:ignore
6972
if value is not None:
7073
form_dict[k] = value
71-
func_kwargs.update({"form": form.model_validate(obj=form_dict)})
74+
func_kwargs["form"] = form.model_validate(obj=form_dict)
7275

7376

7477
def _validate_body(body: Type[BaseModel], func_kwargs):
@@ -77,7 +80,7 @@ def _validate_body(body: Type[BaseModel], func_kwargs):
7780
body_model = body.model_validate_json(json_data=obj)
7881
else:
7982
body_model = body.model_validate(obj=obj)
80-
func_kwargs.update({"body": body_model})
83+
func_kwargs["body"] = body_model
8184

8285

8386
def _validate_request(
@@ -103,14 +106,14 @@ def _validate_request(
103106
path_kwargs: Path parameters.
104107
105108
Returns:
106-
Union[Response, Dict]: Request kwargs.
109+
Dict: Request kwargs.
107110
108111
Raises:
109112
ValidationError: If validation fails.
110113
"""
111114

112115
# Dictionary to store func kwargs
113-
func_kwargs: Dict = dict()
116+
func_kwargs: Dict = {}
114117

115118
try:
116119
# Validate header, cookie, path, and query parameters
@@ -127,7 +130,7 @@ def _validate_request(
127130
if body:
128131
_validate_body(body, func_kwargs)
129132
if raw:
130-
func_kwargs.update({"raw": request})
133+
func_kwargs["raw"] = request
131134
except ValidationError as e:
132135
# Create a response with validation error details
133136
validation_error_callback = getattr(current_app, "validation_error_callback")

0 commit comments

Comments
 (0)