1+ import itertools
12from ..utils import type_from_ast , is_valid_literal_value
3+ from .utils import PairSet , DefaultOrderedDict
24from ..error import GraphQLError
35from ..type .definition import (
46 is_composite_type ,
57 is_input_type ,
68 is_leaf_type ,
9+ get_named_type ,
710 GraphQLNonNull ,
811 GraphQLList ,
912 GraphQLObjectType ,
@@ -243,7 +246,7 @@ def reduce_spread_fragments(spreads):
243246 )
244247 for fragment_definition in self .fragment_definitions
245248 if fragment_definition .name .value not in fragment_names_used
246- ]
249+ ]
247250
248251 if errors :
249252 return errors
@@ -295,11 +298,14 @@ def do_types_overlap(t1, t2):
295298
296299 @staticmethod
297300 def type_incompatible_spread_message (frag_name , parent_type , frag_type ):
298- return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}' .format (frag_name , parent_type , frag_type )
301+ return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}' .format (frag_name ,
302+ parent_type ,
303+ frag_type )
299304
300305 @staticmethod
301306 def type_incompatible_anon_spread_message (parent_type , frag_type ):
302- return 'Fragment cannot be spread here as objects of type {} can never be of type {}' .format (parent_type , frag_type )
307+ return 'Fragment cannot be spread here as objects of type {} can never be of type {}' .format (parent_type ,
308+ frag_type )
303309
304310
305311class NoFragmentCycles (ValidationRule ):
@@ -309,7 +315,7 @@ def __init__(self, context):
309315 node .name .value : self .gather_spreads (node )
310316 for node in context .get_ast ().definitions
311317 if isinstance (node , ast .FragmentDefinition )
312- }
318+ }
313319 self .known_to_lead_to_cycle = set ()
314320
315321 def enter_FragmentDefinition (self , node , * args ):
@@ -444,7 +450,7 @@ def leave_OperationDefinition(self, *args):
444450 )
445451 for variable_definition in self .variable_definitions
446452 if variable_definition .variable .name .value not in self .variable_name_used
447- ]
453+ ]
448454
449455 if errors :
450456 return errors
@@ -731,8 +737,233 @@ def var_type_allowed_for_type(cls, var_type, expected_type):
731737
732738 @staticmethod
733739 def bad_var_pos_message (var_name , var_type , expected_type ):
734- return 'Variable "{}" of type "{}" used in position expecting type "{}".' .format (var_name , var_type , expected_type )
740+ return 'Variable "{}" of type "{}" used in position expecting type "{}".' .format (var_name , var_type ,
741+ expected_type )
735742
736743
737744class OverlappingFieldsCanBeMerged (ValidationRule ):
738- pass
745+ def __init__ (self , context ):
746+ super (OverlappingFieldsCanBeMerged , self ).__init__ (context )
747+ self .compared_set = PairSet ()
748+
749+ def find_conflicts (self , field_map ):
750+ conflicts = []
751+ for response_name , fields in field_map .items ():
752+ field_len = len (fields )
753+ if field_len <= 1 :
754+ continue
755+
756+ for field_a in fields :
757+ for field_b in fields :
758+ conflict = self .find_conflict (response_name , field_a , field_b )
759+ if conflict :
760+ conflicts .append (conflict )
761+
762+ return conflicts
763+
764+ @staticmethod
765+ def ast_to_hashable (ast ):
766+ """
767+ This function will take an AST, and return a portion of it that is unique enough to identify the AST,
768+ but without the unhashable bits.
769+ """
770+ if not ast :
771+ return None
772+
773+ return ast .__class__ , ast .loc ['start' ], ast .loc ['end' ]
774+
775+ def find_conflict (self , response_name , pair1 , pair2 ):
776+ ast1 , def1 = pair1
777+ ast2 , def2 = pair2
778+
779+ ast1_hashable = self .ast_to_hashable (ast1 )
780+ ast2_hashable = self .ast_to_hashable (ast2 )
781+
782+ if ast1 is ast2 or self .compared_set .has (ast1_hashable , ast2_hashable ):
783+ return
784+
785+ self .compared_set .add (ast1_hashable , ast2_hashable )
786+
787+ name1 = ast1 .name .value
788+ name2 = ast2 .name .value
789+
790+ if name1 != name2 :
791+ return (
792+ (response_name , '{} and {} are different fields' .format (name1 , name2 )),
793+ (ast1 , ast2 )
794+ )
795+
796+ type1 = def1 and def1 .type
797+ type2 = def2 and def2 .type
798+
799+ if type1 and type2 and not self .same_type (type1 , type2 ):
800+ return (
801+ (response_name , 'they return differing types {} and {}' .format (type1 , type2 )),
802+ (ast1 , ast2 )
803+ )
804+
805+ if not self .same_arguments (ast1 .arguments , ast2 .arguments ):
806+ return (
807+ (response_name , 'they have differing arguments' ),
808+ (ast1 , ast2 )
809+ )
810+
811+ if not self .same_directives (ast1 .directives , ast2 .directives ):
812+ return (
813+ (response_name , 'they have differing directives' ),
814+ (ast1 , ast2 )
815+ )
816+
817+ selection_set1 = ast1 .selection_set
818+ selection_set2 = ast2 .selection_set
819+
820+ if selection_set1 and selection_set2 :
821+ visited_fragment_names = set ()
822+
823+ subfield_map = self .collect_field_asts_and_defs (
824+ get_named_type (type1 ),
825+ selection_set1 ,
826+ visited_fragment_names
827+ )
828+
829+ subfield_map = self .collect_field_asts_and_defs (
830+ get_named_type (type2 ),
831+ selection_set2 ,
832+ visited_fragment_names ,
833+ subfield_map
834+ )
835+
836+ conflicts = self .find_conflicts (subfield_map )
837+ if conflicts :
838+ return (
839+ (response_name , [conflict [0 ] for conflict in conflicts ]),
840+ tuple (itertools .chain ((ast1 , ast2 ), * [conflict [1 ] for conflict in conflicts ]))
841+ )
842+
843+ def leave_SelectionSet (self , node , * args ):
844+ field_map = self .collect_field_asts_and_defs (
845+ self .context .get_parent_type (),
846+ node
847+ )
848+
849+ conflicts = self .find_conflicts (field_map )
850+ if conflicts :
851+ return [
852+ GraphQLError (self .fields_conflict_message (reason_name , reason ), list (fields )) for
853+ (reason_name , reason ), fields in conflicts
854+ ]
855+
856+ @staticmethod
857+ def same_type (type1 , type2 ):
858+ return type1 .is_same_type (type2 )
859+
860+ @staticmethod
861+ def same_value (value1 , value2 ):
862+ return (not value1 and not value2 ) or print_ast (value1 ) == print_ast (value2 )
863+
864+ @classmethod
865+ def same_arguments (cls , arguments1 , arguments2 ):
866+ # Check to see if they are empty arguments or nones. If they are, we can
867+ # bail out early.
868+ if not (arguments1 or arguments2 ):
869+ return True
870+
871+ if len (arguments1 ) != len (arguments2 ):
872+ return False
873+
874+ arguments2_values_to_arg = {a .name .value : a for a in arguments2 }
875+
876+ for argument1 in arguments1 :
877+ argument2 = arguments2_values_to_arg .get (argument1 .name .value )
878+ if not argument2 :
879+ return False
880+
881+ if not cls .same_value (argument1 .value , argument2 .value ):
882+ return False
883+
884+ return True
885+
886+ @classmethod
887+ def same_directives (cls , directives1 , directives2 ):
888+ # Check to see if they are empty directives or nones. If they are, we can
889+ # bail out early.
890+ if not (directives1 or directives2 ):
891+ return True
892+
893+ if len (directives1 ) != len (directives2 ):
894+ return False
895+
896+ directives2_values_to_arg = {a .name .value : a for a in directives2 }
897+
898+ for directive1 in directives1 :
899+ directive2 = directives2_values_to_arg .get (directive1 .name .value )
900+ if not directive2 :
901+ return False
902+
903+ if not cls .same_arguments (directive1 .arguments , directive2 .arguments ):
904+ return False
905+
906+ return True
907+
908+ def collect_field_asts_and_defs (self , parent_type , selection_set , visited_fragment_names = None , ast_and_defs = None ):
909+ if visited_fragment_names is None :
910+ visited_fragment_names = set ()
911+
912+ if ast_and_defs is None :
913+ # An ordered dictionary is required, otherwise the error message will be out of order.
914+ # We need to preserve the order that the item was inserted into the dict, as that will dictate
915+ # in which order the reasons in the error message should show.
916+ # Otherwise, the error messages will be inconsistently ordered for the same AST.
917+ # And this can make it so that tests fail half the time, and fool a user into thinking that
918+ # the errors are different, when in-fact they are the same, just that the ordering of the reasons differ.
919+ ast_and_defs = DefaultOrderedDict (list )
920+
921+ for selection in selection_set .selections :
922+ if isinstance (selection , ast .Field ):
923+ field_name = selection .name .value
924+ field_def = None
925+ if isinstance (parent_type , (GraphQLObjectType , GraphQLInterfaceType )):
926+ field_def = parent_type .get_fields ().get (field_name )
927+
928+ response_name = selection .alias .value if selection .alias else field_name
929+ ast_and_defs [response_name ].append ((selection , field_def ))
930+
931+ elif isinstance (selection , ast .InlineFragment ):
932+ self .collect_field_asts_and_defs (
933+ type_from_ast (self .context .get_schema (), selection .type_condition ),
934+ selection .selection_set ,
935+ visited_fragment_names ,
936+ ast_and_defs
937+ )
938+
939+ elif isinstance (selection , ast .FragmentSpread ):
940+ fragment_name = selection .name .value
941+ if fragment_name in visited_fragment_names :
942+ continue
943+
944+ visited_fragment_names .add (fragment_name )
945+ fragment = self .context .get_fragment (fragment_name )
946+
947+ if not fragment :
948+ continue
949+
950+ self .collect_field_asts_and_defs (
951+ type_from_ast (self .context .get_schema (), fragment .type_condition ),
952+ fragment .selection_set ,
953+ visited_fragment_names ,
954+ ast_and_defs
955+ )
956+
957+ return ast_and_defs
958+
959+ @classmethod
960+ def fields_conflict_message (cls , reason_name , reason ):
961+ return 'Fields "{}" conflict because {}' .format (reason_name , cls .reason_message (reason ))
962+
963+ @classmethod
964+ def reason_message (cls , reason ):
965+ if isinstance (reason , list ):
966+ return ' and ' .join ('subfields "{}" conflict because {}' .format (reason_name , cls .reason_message (sub_reason ))
967+ for reason_name , sub_reason in reason )
968+
969+ return reason
0 commit comments