1- from typing import Dict , List
1+ from typing import Dict , List , Union , cast
22
33from ...error import GraphQLError
4- from ...language import DirectiveNode , Node
5- from . import ASTValidationRule
4+ from ...language import DirectiveDefinitionNode , DirectiveNode , Node
5+ from ...type import specified_directives
6+ from . import ASTValidationRule , SDLValidationContext , ValidationContext
67
78__all__ = ["UniqueDirectivesPerLocationRule" , "duplicate_directive_message" ]
89
@@ -14,10 +15,28 @@ def duplicate_directive_message(directive_name: str) -> str:
1415class UniqueDirectivesPerLocationRule (ASTValidationRule ):
1516 """Unique directive names per location
1617
17- A GraphQL document is only valid if all directives at a given location are uniquely
18- named.
18+ A GraphQL document is only valid if all non-repeatable directives at a given
19+ location are uniquely named.
1920 """
2021
22+ context : Union [ValidationContext , SDLValidationContext ]
23+
24+ def __init__ (self , context : Union [ValidationContext , SDLValidationContext ]) -> None :
25+ super ().__init__ (context )
26+ unique_directive_map : Dict [str , bool ] = {}
27+
28+ schema = context .schema
29+ defined_directives = (
30+ schema .directives if schema else cast (List , specified_directives )
31+ )
32+ for directive in defined_directives :
33+ unique_directive_map [directive .name ] = not directive .is_repeatable
34+ ast_definitions = context .document .definitions
35+ for def_ in ast_definitions :
36+ if isinstance (def_ , DirectiveDefinitionNode ):
37+ unique_directive_map [def_ .name .value ] = not def_ .repeatable
38+ self .unique_directive_map = unique_directive_map
39+
2140 # Many different AST nodes may contain directives. Rather than listing them all,
2241 # just listen for entering any node, and check to see if it defines any directives.
2342 def enter (self , node : Node , * _args ):
@@ -26,12 +45,14 @@ def enter(self, node: Node, *_args):
2645 known_directives : Dict [str , DirectiveNode ] = {}
2746 for directive in directives :
2847 directive_name = directive .name .value
29- if directive_name in known_directives :
30- self .report_error (
31- GraphQLError (
32- duplicate_directive_message (directive_name ),
33- [known_directives [directive_name ], directive ],
48+
49+ if self .unique_directive_map .get (directive_name ):
50+ if directive_name in known_directives :
51+ self .report_error (
52+ GraphQLError (
53+ duplicate_directive_message (directive_name ),
54+ [known_directives [directive_name ], directive ],
55+ )
3456 )
35- )
36- else :
37- known_directives [directive_name ] = directive
57+ else :
58+ known_directives [directive_name ] = directive
0 commit comments