|
3 | 3 | # @Time : 2022/4/1 16:54 |
4 | 4 | import json |
5 | 5 | from json import JSONDecodeError |
6 | | -from typing import Any, Type, Optional, Dict, get_origin, get_args |
| 6 | +from typing import Any, Type, Optional, Dict |
7 | 7 |
|
8 | 8 | from flask import request, current_app, abort |
9 | 9 | from pydantic import ValidationError, BaseModel |
| 10 | +from pydantic.fields import FieldInfo |
| 11 | +from werkzeug.datastructures.structures import MultiDict |
10 | 12 |
|
11 | | -from .models import FileStorage |
12 | 13 |
|
| 14 | +def _get_list_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo): |
| 15 | + if model_field_value.alias and model.model_config.get("populate_by_name"): |
| 16 | + key = model_field_value.alias |
| 17 | + value = args.getlist(model_field_value.alias) or args.getlist(model_field_key) |
| 18 | + elif model_field_value.alias: |
| 19 | + key = model_field_value.alias |
| 20 | + value = args.getlist(model_field_value.alias) |
| 21 | + else: |
| 22 | + key = model_field_key |
| 23 | + value = args.getlist(model_field_key) |
| 24 | + |
| 25 | + return key, value |
| 26 | + |
| 27 | + |
| 28 | +def _get_value(model: Type[BaseModel], args: MultiDict, model_field_key: str, model_field_value: FieldInfo): |
| 29 | + if model_field_value.alias and model.model_config.get("populate_by_name"): |
| 30 | + key = model_field_value.alias |
| 31 | + value = args.get(model_field_value.alias) or args.get(model_field_key) |
| 32 | + elif model_field_value.alias: |
| 33 | + key = model_field_value.alias |
| 34 | + value = args.get(model_field_value.alias) |
| 35 | + else: |
| 36 | + key = model_field_key |
| 37 | + value = args.get(model_field_key) |
13 | 38 |
|
14 | | -def _validate_header(header: Type[BaseModel], func_kwargs): |
| 39 | + return key, value |
| 40 | + |
| 41 | + |
| 42 | +def _validate_header(header: Type[BaseModel], func_kwargs: dict): |
15 | 43 | request_headers = dict(request.headers) |
16 | | - for key, value in header.model_fields.items(): |
17 | | - key_title = key.replace("_", "-").title() |
18 | | - # Add original key |
19 | | - if key_title in request_headers.keys(): |
20 | | - if value.alias: |
21 | | - request_headers[value.alias] = request_headers[key] = request_headers[key_title] |
22 | | - else: |
23 | | - request_headers[key] = request_headers[key_title] |
24 | | - func_kwargs["header"] = header.model_validate(obj=request_headers) |
| 44 | + header_dict = {} |
| 45 | + for model_field_key, model_field_value in header.model_fields.items(): |
| 46 | + key_title = model_field_key.replace("_", "-").title() |
| 47 | + if model_field_value.alias and header.model_config.get("populate_by_name"): |
| 48 | + key = model_field_value.alias |
| 49 | + key_alias_title = model_field_value.alias.replace("_", "-").title() |
| 50 | + value = request_headers.get(key_alias_title) or request_headers.get(key_title) |
| 51 | + elif model_field_value.alias: |
| 52 | + key = model_field_value.alias |
| 53 | + key_alias_title = model_field_value.alias.replace("_", "-").title() |
| 54 | + value = request_headers.get(key_alias_title) |
| 55 | + else: |
| 56 | + key = model_field_key |
| 57 | + value = request_headers[key_title] |
| 58 | + if value is not None: |
| 59 | + header_dict[key] = value |
| 60 | + func_kwargs["header"] = header.model_validate(obj=header_dict) |
25 | 61 |
|
26 | 62 |
|
27 | | -def _validate_cookie(cookie: Type[BaseModel], func_kwargs): |
| 63 | +def _validate_cookie(cookie: Type[BaseModel], func_kwargs: dict): |
28 | 64 | request_cookies = dict(request.cookies) |
29 | 65 | func_kwargs["cookie"] = cookie.model_validate(obj=request_cookies) |
30 | 66 |
|
31 | 67 |
|
32 | | -def _validate_path(path: Type[BaseModel], path_kwargs, func_kwargs): |
| 68 | +def _validate_path(path: Type[BaseModel], path_kwargs: dict, func_kwargs: dict): |
33 | 69 | func_kwargs["path"] = path.model_validate(obj=path_kwargs) |
34 | 70 |
|
35 | 71 |
|
36 | | -def _validate_query(query: Type[BaseModel], func_kwargs): |
| 72 | +def _validate_query(query: Type[BaseModel], func_kwargs: dict): |
37 | 73 | request_args = request.args |
38 | 74 | query_dict = {} |
39 | | - for k, v in query.model_fields.items(): |
40 | | - if get_origin(v.annotation) is list: |
41 | | - value = request_args.getlist(v.alias or k) or request_args.getlist(k) |
| 75 | + model_properties = query.model_json_schema().get("properties", {}) |
| 76 | + for model_field_key, model_field_value in query.model_fields.items(): |
| 77 | + model_field_schema = model_properties.get(model_field_value.alias or model_field_key) |
| 78 | + if model_field_schema.get("type") == "array": |
| 79 | + key, value = _get_list_value(query, request_args, model_field_key, model_field_value) |
42 | 80 | else: |
43 | | - value = request_args.get(v.alias or k) or request_args.get(k) # type:ignore |
| 81 | + key, value = _get_value(query, request_args, model_field_key, model_field_value) |
44 | 82 | if value is not None and value != []: |
45 | | - query_dict[k] = value |
| 83 | + query_dict[key] = value |
46 | 84 | func_kwargs["query"] = query.model_validate(obj=query_dict) |
47 | 85 |
|
48 | 86 |
|
49 | | -def _validate_form(form: Type[BaseModel], func_kwargs): |
| 87 | +def _validate_form(form: Type[BaseModel], func_kwargs: dict): |
50 | 88 | request_form = request.form |
51 | 89 | request_files = request.files |
52 | 90 | form_dict = {} |
53 | | - for k, v in form.model_fields.items(): |
54 | | - if get_origin(v.annotation) is list: |
55 | | - if get_args(v.annotation)[0] is FileStorage: |
56 | | - value = request_files.getlist(v.alias or k) or request_files.getlist(k) |
| 91 | + model_properties = form.model_json_schema().get("properties", {}) |
| 92 | + for model_field_key, model_field_value in form.model_fields.items(): |
| 93 | + model_field_schema = model_properties.get(model_field_value.alias or model_field_key) |
| 94 | + if model_field_schema.get("type") == "array": |
| 95 | + if model_field_schema.get("items") == {"format": "binary", "type": "string"}: |
| 96 | + # list[FileStorage] |
| 97 | + key, value = _get_list_value(form, request_files, model_field_key, model_field_value) |
57 | 98 | else: |
58 | 99 | value = [] |
59 | | - for i in request_form.getlist(v.alias or k) or request_form.getlist(k): |
| 100 | + key, value_list = _get_list_value(form, request_form, model_field_key, model_field_value) |
| 101 | + for _value in value_list: |
60 | 102 | try: |
61 | | - value.append(json.loads(i)) |
| 103 | + value.append(json.loads(_value)) |
62 | 104 | except (JSONDecodeError, TypeError): |
63 | | - value.append(i) # type:ignore |
64 | | - elif v.annotation is FileStorage: |
65 | | - value = request_files.get(v.alias or k) or request_files.get(k) # type:ignore |
| 105 | + value.append(_value) |
| 106 | + elif model_field_schema.get("type") == "string" and model_field_schema.get("format") == "binary": |
| 107 | + # FileStorage |
| 108 | + key, value = _get_value(form, request_files, model_field_key, model_field_value) |
66 | 109 | else: |
67 | | - _value = request_form.get(v.alias or k) or request_form.get(k) |
| 110 | + key, _value = _get_value(form, request_form, model_field_key, model_field_value) |
68 | 111 | try: |
69 | | - value = json.loads(_value) # type:ignore |
| 112 | + value = json.loads(_value) |
70 | 113 | except (JSONDecodeError, TypeError): |
71 | | - value = _value # type:ignore |
| 114 | + value = _value |
72 | 115 | if value is not None and value != []: |
73 | | - form_dict[k] = value |
| 116 | + form_dict[key] = value |
74 | 117 | func_kwargs["form"] = form.model_validate(obj=form_dict) |
75 | 118 |
|
76 | 119 |
|
77 | | -def _validate_body(body: Type[BaseModel], func_kwargs): |
| 120 | +def _validate_body(body: Type[BaseModel], func_kwargs: dict): |
78 | 121 | obj = request.get_json(silent=True) |
79 | 122 | if isinstance(obj, str): |
80 | 123 | body_model = body.model_validate_json(json_data=obj) |
@@ -122,7 +165,7 @@ def _validate_request( |
122 | 165 | if cookie: |
123 | 166 | _validate_cookie(cookie, func_kwargs) |
124 | 167 | if path: |
125 | | - _validate_path(path, path_kwargs, func_kwargs) |
| 168 | + _validate_path(path, path_kwargs or {}, func_kwargs) |
126 | 169 | if query: |
127 | 170 | _validate_query(query, func_kwargs) |
128 | 171 | if form: |
|
0 commit comments