Skip to content

Commit 7961231

Browse files
luolingchunluolingchun
andauthored
Fix alias in query and form (#184)
* Fix alias in query and form * Fix list * Fix mypy --------- Co-authored-by: luolingchun <luolingchun@outloook.com>
1 parent 8f86f0b commit 7961231

File tree

2 files changed

+174
-45
lines changed

2 files changed

+174
-45
lines changed

flask_openapi3/request.py

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,78 +3,121 @@
33
# @Time : 2022/4/1 16:54
44
import json
55
from json import JSONDecodeError
6-
from typing import Any, Type, Optional, Dict, get_origin, get_args
6+
from typing import Any, Type, Optional, Dict
77

88
from flask import request, current_app, abort
99
from pydantic import ValidationError, BaseModel
10+
from pydantic.fields import FieldInfo
11+
from werkzeug.datastructures.structures import MultiDict
1012

11-
from .models import FileStorage
1213

14+
def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo):
15+
if model_field_value.alias and model.model_config.get("populate_by_name"):
16+
key = model_field_value.alias
17+
value = args.getlist(model_field_value.alias) or args.getlist(model_field_key)
18+
elif model_field_value.alias:
19+
key = model_field_value.alias
20+
value = args.getlist(model_field_value.alias)
21+
else:
22+
key = model_field_key
23+
value = args.getlist(model_field_key)
24+
25+
return key, value
26+
27+
28+
def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo):
29+
if model_field_value.alias and model.model_config.get("populate_by_name"):
30+
key = model_field_value.alias
31+
value = args.get(model_field_value.alias) or args.get(model_field_key)
32+
elif model_field_value.alias:
33+
key = model_field_value.alias
34+
value = args.get(model_field_value.alias)
35+
else:
36+
key = model_field_key
37+
value = args.get(model_field_key)
1338

14-
def _validate_header(header: Type[BaseModel], func_kwargs):
39+
return key, value
40+
41+
42+
def _validate_header(header: Type[BaseModel], func_kwargs: dict):
1543
request_headers = dict(request.headers)
16-
for key, value in header.model_fields.items():
17-
key_title = key.replace("_", "-").title()
18-
# Add original key
19-
if key_title in request_headers.keys():
20-
if value.alias:
21-
request_headers[value.alias] = request_headers[key] = request_headers[key_title]
22-
else:
23-
request_headers[key] = request_headers[key_title]
24-
func_kwargs["header"] = header.model_validate(obj=request_headers)
44+
header_dict = {}
45+
for model_field_key, model_field_value in header.model_fields.items():
46+
key_title = model_field_key.replace("_", "-").title()
47+
if model_field_value.alias and header.model_config.get("populate_by_name"):
48+
key = model_field_value.alias
49+
key_alias_title = model_field_value.alias.replace("_", "-").title()
50+
value = request_headers.get(key_alias_title) or request_headers.get(key_title)
51+
elif model_field_value.alias:
52+
key = model_field_value.alias
53+
key_alias_title = model_field_value.alias.replace("_", "-").title()
54+
value = request_headers.get(key_alias_title)
55+
else:
56+
key = model_field_key
57+
value = request_headers[key_title]
58+
if value is not None:
59+
header_dict[key] = value
60+
func_kwargs["header"] = header.model_validate(obj=header_dict)
2561

2662

27-
def _validate_cookie(cookie: Type[BaseModel], func_kwargs):
63+
def _validate_cookie(cookie: Type[BaseModel], func_kwargs: dict):
2864
request_cookies = dict(request.cookies)
2965
func_kwargs["cookie"] = cookie.model_validate(obj=request_cookies)
3066

3167

32-
def _validate_path(path: Type[BaseModel], path_kwargs, func_kwargs):
68+
def _validate_path(path: Type[BaseModel], path_kwargs: dict, func_kwargs: dict):
3369
func_kwargs["path"] = path.model_validate(obj=path_kwargs)
3470

3571

36-
def _validate_query(query: Type[BaseModel], func_kwargs):
72+
def _validate_query(query: Type[BaseModel], func_kwargs: dict):
3773
request_args = request.args
3874
query_dict = {}
39-
for k, v in query.model_fields.items():
40-
if get_origin(v.annotation) is list:
41-
value = request_args.getlist(v.alias or k) or request_args.getlist(k)
75+
model_properties = query.model_json_schema().get("properties", {})
76+
for model_field_key, model_field_value in query.model_fields.items():
77+
model_field_schema = model_properties.get(model_field_value.alias or model_field_key)
78+
if model_field_schema.get("type") == "array":
79+
key, value = _get_list_value(query, request_args, model_field_key, model_field_value)
4280
else:
43-
value = request_args.get(v.alias or k) or request_args.get(k) # type:ignore
81+
key, value = _get_value(query, request_args, model_field_key, model_field_value)
4482
if value is not None and value != []:
45-
query_dict[k] = value
83+
query_dict[key] = value
4684
func_kwargs["query"] = query.model_validate(obj=query_dict)
4785

4886

49-
def _validate_form(form: Type[BaseModel], func_kwargs):
87+
def _validate_form(form: Type[BaseModel], func_kwargs: dict):
5088
request_form = request.form
5189
request_files = request.files
5290
form_dict = {}
53-
for k, v in form.model_fields.items():
54-
if get_origin(v.annotation) is list:
55-
if get_args(v.annotation)[0] is FileStorage:
56-
value = request_files.getlist(v.alias or k) or request_files.getlist(k)
91+
model_properties = form.model_json_schema().get("properties", {})
92+
for model_field_key, model_field_value in form.model_fields.items():
93+
model_field_schema = model_properties.get(model_field_value.alias or model_field_key)
94+
if model_field_schema.get("type") == "array":
95+
if model_field_schema.get("items") == {"format": "binary", "type": "string"}:
96+
# list[FileStorage]
97+
key, value = _get_list_value(form, request_files, model_field_key, model_field_value)
5798
else:
5899
value = []
59-
for i in request_form.getlist(v.alias or k) or request_form.getlist(k):
100+
key, value_list = _get_list_value(form, request_form, model_field_key, model_field_value)
101+
for _value in value_list:
60102
try:
61-
value.append(json.loads(i))
103+
value.append(json.loads(_value))
62104
except (JSONDecodeError, TypeError):
63-
value.append(i) # type:ignore
64-
elif v.annotation is FileStorage:
65-
value = request_files.get(v.alias or k) or request_files.get(k) # type:ignore
105+
value.append(_value)
106+
elif model_field_schema.get("type") == "string" and model_field_schema.get("format") == "binary":
107+
# FileStorage
108+
key, value = _get_value(form, request_files, model_field_key, model_field_value)
66109
else:
67-
_value = request_form.get(v.alias or k) or request_form.get(k)
110+
key, _value = _get_value(form, request_form, model_field_key, model_field_value)
68111
try:
69-
value = json.loads(_value) # type:ignore
112+
value = json.loads(_value)
70113
except (JSONDecodeError, TypeError):
71-
value = _value # type:ignore
114+
value = _value
72115
if value is not None and value != []:
73-
form_dict[k] = value
116+
form_dict[key] = value
74117
func_kwargs["form"] = form.model_validate(obj=form_dict)
75118

76119

77-
def _validate_body(body: Type[BaseModel], func_kwargs):
120+
def _validate_body(body: Type[BaseModel], func_kwargs: dict):
78121
obj = request.get_json(silent=True)
79122
if isinstance(obj, str):
80123
body_model = body.model_validate_json(json_data=obj)
@@ -122,7 +165,7 @@ def _validate_request(
122165
if cookie:
123166
_validate_cookie(cookie, func_kwargs)
124167
if path:
125-
_validate_path(path, path_kwargs, func_kwargs)
168+
_validate_path(path, path_kwargs or {}, func_kwargs)
126169
if query:
127170
_validate_query(query, func_kwargs)
128171
if form:

tests/test_populate_by_name.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author : llc
33
# @Time : 2024/8/31 15:35
4+
from typing import Sequence, List, Tuple
5+
46
import pytest
57
from pydantic import BaseModel, Field
68

@@ -29,17 +31,101 @@ def get_book(query: BookQuery):
2931
"""
3032
get all books
3133
"""
32-
return {
33-
"code": 0,
34-
"message": "ok",
35-
"data": [
36-
{"bid": 1, "age": query.age, "author": query.author},
37-
{"bid": 2, "age": query.age, "author": query.author}
38-
]
39-
}
34+
print(query)
35+
return "ok"
36+
37+
38+
class QueryModel(BaseModel):
39+
aliased_field: str = Field(alias="aliasedField")
40+
41+
aliased_list_field: List[str] = Field(alias="aliasedListField")
42+
43+
44+
@app.get("/query-alias-test")
45+
def query_alias_test(query: QueryModel):
46+
assert query.aliased_field == "test"
47+
return "ok"
48+
49+
50+
class HeaderModel(BaseModel):
51+
Hello1: str = Field(..., alias="Hello2")
52+
53+
model_config = {"populate_by_name": True}
54+
55+
56+
@app.get("/header")
57+
def get_book_header(header: HeaderModel):
58+
return header.model_dump(by_alias=True)
59+
60+
61+
class TupleModel(BaseModel):
62+
values: Tuple[int, int]
63+
sequence: Sequence[int] = Field(alias="Sequence")
64+
65+
model_config = {"populate_by_name": True}
66+
67+
68+
@app.get("/tuple-test")
69+
def tuple_test(query: TupleModel):
70+
assert query.values == (2, 2)
71+
return b"", 200
72+
73+
74+
class AliasModel(BaseModel):
75+
aliased_field: str = Field(alias="aliasedField")
76+
77+
78+
@app.post("/form-alias-test")
79+
def alias_test(form: AliasModel):
80+
assert form.aliased_field == "test"
81+
return b"", 200
82+
83+
84+
def test_header(client):
85+
headers = {"Hello2": "111"}
86+
resp = client.get("/header", headers=headers)
87+
print(resp.json)
88+
assert resp.status_code == 200
89+
assert resp.json == headers
90+
91+
92+
def test_tuple_query(client):
93+
resp = client.get(
94+
"/tuple-test",
95+
query_string={"values": [2, 2], "sequence": [1, 2, 3]},
96+
)
97+
assert resp.status_code == 200
98+
99+
100+
def test_form_alias(client):
101+
resp = client.post(
102+
"/form-alias-test",
103+
data={"aliasedField": "test"},
104+
)
105+
assert resp.status_code == 200
106+
107+
resp = client.post(
108+
"/form-alias-test",
109+
data={"aliased_field": "test"},
110+
)
111+
assert resp.status_code == 422
112+
113+
114+
def test_query_alias(client):
115+
resp = client.get(
116+
"/query-alias-test",
117+
query_string={"aliasedField": "test", "aliasedListField": ["test"]},
118+
)
119+
assert resp.status_code == 200
120+
121+
resp = client.get(
122+
"/query-alias-test",
123+
data={"aliased_field": "test", "aliased_list_field": ["test"]},
124+
)
125+
assert resp.status_code == 422
40126

41127

42-
def test_openapi(client):
128+
def test_query_populate_by_name(client):
43129
resp = client.get("/book?age=1&author=aa")
44130
assert resp.status_code == 200
45131
resp = client.get("/book?age=1&author_name=aa")

0 commit comments

Comments
 (0)