|
1 | 1 | from copy import deepcopy |
2 | | -from typing import Dict, List, Optional, Type |
| 2 | +from typing import Dict, List, Optional, Set, Type |
3 | 3 |
|
4 | 4 | import pytest |
5 | 5 | from fastapi import FastAPI, status |
|
12 | 12 | from fastapi_jsonapi import RoutersJSONAPI |
13 | 13 | from fastapi_jsonapi.exceptions import BadRequest |
14 | 14 | from fastapi_jsonapi.schema_builder import SchemaBuilder |
| 15 | +from fastapi_jsonapi.validation_utils import extract_field_validators |
15 | 16 | from tests.fixtures.app import build_app_custom |
16 | 17 | from tests.misc.utils import fake |
17 | 18 | from tests.models import ( |
@@ -645,3 +646,43 @@ class Config: |
645 | 646 | body=create_user_body, |
646 | 647 | expected_detail=expected_detail, |
647 | 648 | ) |
| 649 | + |
| 650 | + |
| 651 | +class TestValidationUtils: |
| 652 | + @mark.parametrize( |
| 653 | + ("include", "exclude", "expected"), |
| 654 | + [ |
| 655 | + param({"item_1"}, None, {"item_1_validator"}, id="include"), |
| 656 | + param(None, {"item_1"}, {"item_2_validator"}, id="exclude"), |
| 657 | + param(None, None, {"item_1_validator", "item_2_validator"}, id="empty_params"), |
| 658 | + param({"item_1", "item_2"}, {"item_2"}, {"item_1_validator"}, id="intersection"), |
| 659 | + ], |
| 660 | + ) |
| 661 | + def test_extract_field_validators_args( |
| 662 | + self, |
| 663 | + include: Set[str], |
| 664 | + exclude: Set[str], |
| 665 | + expected: Set[str], |
| 666 | + ): |
| 667 | + class ValidationSchema(BaseModel): |
| 668 | + item_1: str |
| 669 | + item_2: str |
| 670 | + |
| 671 | + @validator("item_1", allow_reuse=True) |
| 672 | + def item_1_validator(cls, v): |
| 673 | + return v |
| 674 | + |
| 675 | + @validator("item_2", allow_reuse=True) |
| 676 | + def item_2_validator(cls, v): |
| 677 | + return v |
| 678 | + |
| 679 | + validators = extract_field_validators( |
| 680 | + ValidationSchema, |
| 681 | + include_for_field_names=include, |
| 682 | + exclude_for_field_names=exclude, |
| 683 | + ) |
| 684 | + validator_func_names = { |
| 685 | + validator_item.__validator_config__[1].func.__name__ for validator_item in validators.values() |
| 686 | + } |
| 687 | + |
| 688 | + assert expected == validator_func_names |
0 commit comments