@@ -634,43 +634,53 @@ def enter_OperationDefinition(self, *args):
634634 self .var_def_map = {}
635635 self .visited_fragment_names = set ()
636636
637- def enter_VariableDefinition (self , var_def_ast , * args ):
638- self .var_def_map [var_def_ast .variable .name .value ] = var_def_ast
637+ def enter_VariableDefinition (self , node , * args ):
638+ self .var_def_map [node .variable .name .value ] = node
639+
640+ def enter_Variable (self , node , * args ):
641+ var_name = node .name .value
642+ var_def = self .var_def_map .get (var_name )
639643
640- def enter_Variable (self , variable_ast , * args ):
641- var_name = variable_ast .name .value
642- var_def = self .var_def_map [var_name ]
643644 var_type = var_def and type_from_ast (self .context .get_schema (), var_def .type )
644645 input_type = self .context .get_input_type ()
645- if var_type and input_type and not self .var_type_allowed_for_type (self .effective_type (var_type , var_def ), input_type ):
646- return GraphQlError (self .bad_var_pos_message (var_name , var_type , input_type , [variable_ast ]))
647646
648- def enter_FragmentSpread (self , spread_ast , * args ):
649- if spread_ast .name .value in self .visited_fragment_names :
647+ if var_type and input_type and not self .var_type_allowed_for_type (self .effective_type (var_type , var_def ),
648+ input_type ):
649+ return GraphQLError (self .bad_var_pos_message (var_name , var_type , input_type ),
650+ [node ])
651+
652+ def enter_FragmentSpread (self , node , * args ):
653+ if node .name .value in self .visited_fragment_names :
650654 return False
651- self .visited_fragment_names .add (spread_ast .name .value );
655+
656+ self .visited_fragment_names .add (node .name .value )
652657
653658 @staticmethod
654659 def effective_type (var_type , var_def ):
655660 if not var_def .default_value or isinstance (var_def , GraphQLNonNull ):
656661 return var_type
662+
657663 return GraphQLNonNull (var_type )
658664
659665 @staticmethod
660666 def var_type_allowed_for_type (var_type , expected_type ):
661667 if isinstance (expected_type , GraphQLNonNull ):
662668 if isinstance (var_type , GraphQLNonNull ):
663669 return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type .of_type )
670+
664671 return False
672+
665673 if isinstance (var_type , GraphQLNonNull ):
666674 return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type )
675+
667676 if isinstance (var_type , GraphQLList ) and isinstance (expected_type , GraphQLList ):
668677 return VariablesInAllowedPosition .var_type_allowed_for_type (var_type .of_type , expected_type .of_type )
678+
669679 return var_type == expected_type
670680
671681 @staticmethod
672682 def bad_var_pos_message (var_name , var_type , expected_type ):
673- return 'Variable {} of type {} used in position expecting type {} ' .format (var_name , var_type , expected_type )
683+ return 'Variable "${}" of type "{}" used in position expecting type "{}". ' .format (var_name , var_type , expected_type )
674684
675685
676686class OverlappingFieldsCanBeMerged (ValidationRule ):
0 commit comments