1+ from typing import Dict , Union
2+
13from ...error import GraphQLError
2- from ...language import DirectiveNode , FieldNode
3- from ...type import is_required_argument
4- from . import ValidationRule
4+ from ...language import (
5+ DirectiveDefinitionNode , DirectiveNode , FieldNode ,
6+ InputValueDefinitionNode , NonNullTypeNode , print_ast )
7+ from ...type import (
8+ GraphQLArgument , is_required_argument , is_type , specified_directives )
9+ from . import ASTValidationRule , SDLValidationContext , ValidationContext
510
611__all__ = [
712 'ProvidedRequiredArgumentsRule' ,
13+ 'ProvidedRequiredArgumentsOnDirectivesRule' ,
814 'missing_field_arg_message' , 'missing_directive_arg_message' ]
915
1016
@@ -20,37 +26,83 @@ def missing_directive_arg_message(
2026 f" of type '{ type_ } ' is required but not provided." )
2127
2228
23- class ProvidedRequiredArgumentsRule (ValidationRule ):
29+ class ProvidedRequiredArgumentsOnDirectivesRule (ASTValidationRule ):
30+ """Provided required arguments on directives
31+
32+ A directive is only valid if all required (non-null without a
33+ default value) arguments have been provided.
34+ """
35+
36+ context : Union [ValidationContext , SDLValidationContext ]
37+
38+ def __init__ (self , context : Union [
39+ ValidationContext , SDLValidationContext ]) -> None :
40+ super ().__init__ (context )
41+ required_args_map : Dict [str , Dict [str , GraphQLArgument ]] = {}
42+
43+ schema = context .schema
44+ defined_directives = (
45+ schema .directives if schema else specified_directives )
46+ for directive in defined_directives :
47+ required_args_map [directive .name ] = {
48+ name : arg for name , arg in directive .args .items ()
49+ if is_required_argument (arg )}
50+
51+ ast_definitions = context .document .definitions
52+ for def_ in ast_definitions :
53+ if isinstance (def_ , DirectiveDefinitionNode ):
54+ required_args_map [def_ .name .value ] = {
55+ arg .name .value : arg for arg in filter (
56+ is_required_argument_node , def_ .arguments )
57+ } if def_ .arguments else {}
58+
59+ self .required_args_map = required_args_map
60+
61+ def leave_directive (self , directive_node : DirectiveNode , * _args ):
62+ # Validate on leave to allow for deeper errors to appear first.
63+ directive_name = directive_node .name .value
64+ required_args = self .required_args_map .get (directive_name )
65+ if required_args :
66+
67+ arg_nodes = directive_node .arguments or []
68+ arg_node_set = {arg .name .value for arg in arg_nodes }
69+ for arg_name in required_args :
70+ if arg_name not in arg_node_set :
71+ arg_type = required_args [arg_name ].type
72+ self .report_error (GraphQLError (
73+ missing_directive_arg_message (
74+ directive_name , arg_name , str (arg_type )
75+ if is_type (arg_type ) else print_ast (arg_type )),
76+ [directive_node ]))
77+
78+
79+ class ProvidedRequiredArgumentsRule (ProvidedRequiredArgumentsOnDirectivesRule ):
2480 """Provided required arguments
2581
2682 A field or directive is only valid if all required (non-null without a
2783 default value) field arguments have been provided.
2884 """
2985
30- def leave_field (self , node : FieldNode , * _args ):
86+ context : ValidationContext
87+
88+ def __init__ (self , context : ValidationContext ) -> None :
89+ super ().__init__ (context )
90+
91+ def leave_field (self , field_node : FieldNode , * _args ):
3192 # Validate on leave to allow for deeper errors to appear first.
3293 field_def = self .context .get_field_def ()
3394 if not field_def :
3495 return self .SKIP
35- arg_nodes = node .arguments or []
96+ arg_nodes = field_node .arguments or []
3697
3798 arg_node_map = {arg .name .value : arg for arg in arg_nodes }
3899 for arg_name , arg_def in field_def .args .items ():
39100 arg_node = arg_node_map .get (arg_name )
40101 if not arg_node and is_required_argument (arg_def ):
41102 self .report_error (GraphQLError (missing_field_arg_message (
42- node .name .value , arg_name , str (arg_def .type )), [node ]))
103+ field_node .name .value , arg_name , str (arg_def .type )),
104+ [field_node ]))
43105
44- def leave_directive (self , node : DirectiveNode , * _args ):
45- # Validate on leave to allow for deeper errors to appear first.
46- directive_def = self .context .get_directive ()
47- if not directive_def :
48- return False
49- arg_nodes = node .arguments or []
50106
51- arg_node_map = {arg .name .value : arg for arg in arg_nodes }
52- for arg_name , arg_def in directive_def .args .items ():
53- arg_node = arg_node_map .get (arg_name )
54- if not arg_node and is_required_argument (arg_def ):
55- self .report_error (GraphQLError (missing_directive_arg_message (
56- node .name .value , arg_name , str (arg_def .type )), [node ]))
107+ def is_required_argument_node (arg : InputValueDefinitionNode ) -> bool :
108+ return isinstance (arg .type , NonNullTypeNode ) and arg .default_value is None
0 commit comments