Skip to content

Commit 1e0f9c8

Browse files
committed
added validation_utils test
1 parent 6a6ef5a commit 1e0f9c8

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

fastapi_jsonapi/validation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def extract_field_validators(
8181
exclude_for_field_names = exclude_for_field_names or set()
8282

8383
if include_for_field_names and exclude_for_field_names:
84-
exclude_for_field_names = include_for_field_names.difference(
84+
include_for_field_names = include_for_field_names.difference(
8585
exclude_for_field_names,
8686
)
8787

tests/test_api/test_validators.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from copy import deepcopy
2-
from typing import Dict, List, Optional, Type
2+
from typing import Dict, List, Optional, Set, Type
33

44
import pytest
55
from fastapi import FastAPI, status
@@ -12,6 +12,7 @@
1212
from fastapi_jsonapi import RoutersJSONAPI
1313
from fastapi_jsonapi.exceptions import BadRequest
1414
from fastapi_jsonapi.schema_builder import SchemaBuilder
15+
from fastapi_jsonapi.validation_utils import extract_field_validators
1516
from tests.fixtures.app import build_app_custom
1617
from tests.misc.utils import fake
1718
from tests.models import (
@@ -645,3 +646,43 @@ class Config:
645646
body=create_user_body,
646647
expected_detail=expected_detail,
647648
)
649+
650+
651+
class TestValidationUtils:
652+
@mark.parametrize(
653+
("include", "exclude", "expected"),
654+
[
655+
param({"item_1"}, None, {"item_1_validator"}, id="include"),
656+
param(None, {"item_1"}, {"item_2_validator"}, id="exclude"),
657+
param(None, None, {"item_1_validator", "item_2_validator"}, id="empty_params"),
658+
param({"item_1", "item_2"}, {"item_2"}, {"item_1_validator"}, id="intersection"),
659+
],
660+
)
661+
def test_extract_field_validators_args(
662+
self,
663+
include: Set[str],
664+
exclude: Set[str],
665+
expected: Set[str],
666+
):
667+
class ValidationSchema(BaseModel):
668+
item_1: str
669+
item_2: str
670+
671+
@validator("item_1", allow_reuse=True)
672+
def item_1_validator(cls, v):
673+
return v
674+
675+
@validator("item_2", allow_reuse=True)
676+
def item_2_validator(cls, v):
677+
return v
678+
679+
validators = extract_field_validators(
680+
ValidationSchema,
681+
include_for_field_names=include,
682+
exclude_for_field_names=exclude,
683+
)
684+
validator_func_names = {
685+
validator_item.__validator_config__[1].func.__name__ for validator_item in validators.values()
686+
}
687+
688+
assert expected == validator_func_names

0 commit comments

Comments
 (0)