Skip to content

Commit 79ed3fe

Browse files
luolingchunluolingchun
andauthored
Fix query, form, header model extra not honored (#201)
* Fix query, form, header model extra not honored * update --------- Co-authored-by: luolingchun <luolingchun@outloook.com>
1 parent 012d1f4 commit 79ed3fe

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

flask_openapi3/request.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, mo
4242
def _validate_header(header: Type[BaseModel], func_kwargs: dict):
4343
request_headers = dict(request.headers)
4444
header_dict = {}
45+
model_properties = header.model_json_schema().get("properties", {})
4546
for model_field_key, model_field_value in header.model_fields.items():
4647
key_title = model_field_key.replace("_", "-").title()
48+
model_field_schema = model_properties.get(model_field_value.alias or model_field_key)
4749
if model_field_value.alias and header.model_config.get("populate_by_name"):
4850
key = model_field_value.alias
4951
key_alias_title = model_field_value.alias.replace("_", "-").title()
@@ -57,6 +59,12 @@ def _validate_header(header: Type[BaseModel], func_kwargs: dict):
5759
value = request_headers[key_title]
5860
if value is not None:
5961
header_dict[key] = value
62+
if model_field_schema.get("type") == "null":
63+
header_dict[key] = value # type:ignore
64+
# extra keys
65+
for key, value in request_headers.items():
66+
if key not in header_dict.keys():
67+
header_dict[key] = value
6068
func_kwargs["header"] = header.model_validate(obj=header_dict)
6169

6270

@@ -81,6 +89,12 @@ def _validate_query(query: Type[BaseModel], func_kwargs: dict):
8189
key, value = _get_value(query, request_args, model_field_key, model_field_value)
8290
if value is not None and value != []:
8391
query_dict[key] = value
92+
if model_field_schema.get("type") == "null":
93+
query_dict[key] = value
94+
# extra keys
95+
for key, value in request_args.items():
96+
if key not in query_dict.keys():
97+
query_dict[key] = value
8498
func_kwargs["query"] = query.model_validate(obj=query_dict)
8599

86100

@@ -114,6 +128,12 @@ def _validate_form(form: Type[BaseModel], func_kwargs: dict):
114128
value = _value
115129
if value is not None and value != []:
116130
form_dict[key] = value
131+
if model_field_schema.get("type") == "null":
132+
form_dict[key] = value
133+
# extra keys
134+
for key, value in {**dict(request_form), **dict(request_files)}.items():
135+
if key not in form_dict.keys():
136+
form_dict[key] = value
117137
func_kwargs["form"] = form.model_validate(obj=form_dict)
118138

119139

tests/test_model_extra.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# -*- coding: utf-8 -*-
2+
# @Author : llc
3+
# @Time : 2024/11/20 14:45
4+
from typing import Optional
5+
6+
import pytest
7+
from pydantic import BaseModel, Field, ConfigDict
8+
9+
from flask_openapi3 import OpenAPI
10+
11+
app = OpenAPI(__name__)
12+
app.config["TESTING"] = True
13+
14+
15+
class BookQuery(BaseModel):
16+
age: Optional[int] = Field(None, description="Age")
17+
18+
model_config = ConfigDict(extra="allow")
19+
20+
21+
class BookForm(BaseModel):
22+
string: str
23+
24+
model_config = ConfigDict(extra="forbid")
25+
26+
27+
class BookHeader(BaseModel):
28+
api_key: str = Field(..., description="API Key")
29+
30+
model_config = ConfigDict(extra="forbid")
31+
32+
33+
@pytest.fixture
34+
def client():
35+
client = app.test_client()
36+
37+
return client
38+
39+
40+
@app.get("/book")
41+
def get_books(query: BookQuery):
42+
"""get books
43+
to get all books
44+
"""
45+
assert query.age == 3
46+
assert query.author == "joy"
47+
return {"code": 0, "message": "ok"}
48+
49+
50+
@app.post("/form")
51+
def api_form(form: BookForm):
52+
print(form)
53+
return {"code": 0, "message": "ok"}
54+
55+
56+
def test_query(client):
57+
resp = client.get("/book?age=3&author=joy")
58+
assert resp.status_code == 200
59+
60+
61+
@app.get("/header")
62+
def get_book(header: BookHeader):
63+
return header.model_dump(by_alias=True)
64+
65+
66+
def test_form(client):
67+
data = {
68+
"string": "a",
69+
"string_list": ["a", "b", "c"]
70+
}
71+
r = client.post("/form", data=data, content_type="multipart/form-data")
72+
assert r.status_code == 422
73+
74+
75+
def test_header(client):
76+
headers = {"Hello1": "111", "hello2": "222", "api_key": "333", "api_type": "A", "x-hello": "444"}
77+
resp = client.get("/header", headers=headers)
78+
assert resp.status_code == 422

0 commit comments

Comments
 (0)