|
1 | | -from typing import Dict, Union |
| 1 | +from typing import cast, Dict, List, Union |
2 | 2 |
|
3 | 3 | from ...error import GraphQLError |
4 | 4 | from ...language import ( |
5 | 5 | DirectiveDefinitionNode, DirectiveNode, FieldNode, |
6 | | - InputValueDefinitionNode, NonNullTypeNode, print_ast) |
| 6 | + InputValueDefinitionNode, NonNullTypeNode, TypeNode, print_ast) |
7 | 7 | from ...type import ( |
8 | 8 | GraphQLArgument, is_required_argument, is_type, specified_directives) |
9 | 9 | from . import ASTValidationRule, SDLValidationContext, ValidationContext |
@@ -38,12 +38,13 @@ class ProvidedRequiredArgumentsOnDirectivesRule(ASTValidationRule): |
38 | 38 | def __init__(self, context: Union[ |
39 | 39 | ValidationContext, SDLValidationContext]) -> None: |
40 | 40 | super().__init__(context) |
41 | | - required_args_map: Dict[str, Dict[str, GraphQLArgument]] = {} |
| 41 | + required_args_map: Dict[str, Dict[str, Union[ |
| 42 | + GraphQLArgument, InputValueDefinitionNode]]] = {} |
42 | 43 |
|
43 | 44 | schema = context.schema |
44 | 45 | defined_directives = ( |
45 | 46 | schema.directives if schema else specified_directives) |
46 | | - for directive in defined_directives: |
| 47 | + for directive in cast(List, defined_directives): |
47 | 48 | required_args_map[directive.name] = { |
48 | 49 | name: arg for name, arg in directive.args.items() |
49 | 50 | if is_required_argument(arg)} |
@@ -72,7 +73,8 @@ def leave_directive(self, directive_node: DirectiveNode, *_args): |
72 | 73 | self.report_error(GraphQLError( |
73 | 74 | missing_directive_arg_message( |
74 | 75 | directive_name, arg_name, str(arg_type) |
75 | | - if is_type(arg_type) else print_ast(arg_type)), |
| 76 | + if is_type(arg_type) |
| 77 | + else print_ast(cast(TypeNode, arg_type))), |
76 | 78 | [directive_node])) |
77 | 79 |
|
78 | 80 |
|
|
0 commit comments