1- from collections import OrderedDict
1+ from collections import Iterable , OrderedDict , defaultdict
22from functools import reduce
33
44from ..utils .type_comparators import is_equal_type , is_type_sub_type_of
@@ -23,9 +23,9 @@ class GraphQLSchema(object):
2323 mutation=MyAppMutationRootType
2424 )
2525 """
26- __slots__ = '_query' , '_mutation' , '_subscription' , '_type_map' , '_directives' ,
26+ __slots__ = '_query' , '_mutation' , '_subscription' , '_type_map' , '_directives' , '_implementations' , '_possible_type_map'
2727
28- def __init__ (self , query , mutation = None , subscription = None , directives = None ):
28+ def __init__ (self , query , mutation = None , subscription = None , directives = None , types = None ):
2929 assert isinstance (query , GraphQLObjectType ), 'Schema query must be Object Type but got: {}.' .format (query )
3030 if mutation :
3131 assert isinstance (mutation , GraphQLObjectType ), \
@@ -35,6 +35,10 @@ def __init__(self, query, mutation=None, subscription=None, directives=None):
3535 assert isinstance (subscription , GraphQLObjectType ), \
3636 'Schema subscription must be Object Type but got: {}.' .format (subscription )
3737
38+ if types :
39+ assert isinstance (types , Iterable ), \
40+ 'Schema types must be iterable if provided but got: {}.' .format (types )
41+
3842 self ._query = query
3943 self ._mutation = mutation
4044 self ._subscription = subscription
@@ -50,13 +54,20 @@ def __init__(self, query, mutation=None, subscription=None, directives=None):
5054 )
5155
5256 self ._directives = directives
53- self ._type_map = self ._build_type_map ()
57+ self ._possible_type_map = defaultdict (set )
58+ self ._type_map = self ._build_type_map (types )
59+ # Keep track of all implementations by interface name.
60+ self ._implementations = defaultdict (list )
61+ for type in self ._type_map .values ():
62+ if isinstance (type , GraphQLObjectType ):
63+ for interface in type .get_interfaces ():
64+ self ._implementations [interface .name ].append (type )
5465
5566 # Enforce correct interface implementations.
5667 for type in self ._type_map .values ():
5768 if isinstance (type , GraphQLObjectType ):
5869 for interface in type .get_interfaces ():
59- assert_object_implements_interface (type , interface )
70+ assert_object_implements_interface (self , type , interface )
6071
6172 def get_query_type (self ):
6273 return self ._query
@@ -83,11 +94,32 @@ def get_directive(self, name):
8394
8495 return None
8596
86- def _build_type_map (self ):
87- types = [self .get_query_type (), self .get_mutation_type (), self .get_subscription_type (), IntrospectionSchema ]
97+ def _build_type_map (self , _types ):
98+ types = [
99+ self .get_query_type (),
100+ self .get_mutation_type (),
101+ self .get_subscription_type (),
102+ IntrospectionSchema
103+ ]
104+ if _types :
105+ types += _types
106+
88107 type_map = reduce (type_map_reducer , types , OrderedDict ())
89108 return type_map
90109
110+ def get_possible_types (self , abstract_type ):
111+ if isinstance (abstract_type , GraphQLUnionType ):
112+ return abstract_type .get_types ()
113+ assert isinstance (abstract_type , GraphQLInterfaceType )
114+ return self ._implementations [abstract_type .name ]
115+
116+ def is_possible_type (self , abstract_type , possible_type ):
117+ if not self ._possible_type_map [abstract_type .name ]:
118+ possible_types = self .get_possible_types (abstract_type )
119+ self ._possible_type_map [abstract_type .name ].update ([p .name for p in possible_types ])
120+
121+ return possible_type .name in self ._possible_type_map [abstract_type .name ]
122+
91123
92124def type_map_reducer (map , type ):
93125 if not type :
@@ -107,8 +139,8 @@ def type_map_reducer(map, type):
107139
108140 reduced_map = map
109141
110- if isinstance (type , (GraphQLUnionType , GraphQLInterfaceType )):
111- for t in type .get_possible_types ():
142+ if isinstance (type , (GraphQLUnionType )):
143+ for t in type .get_types ():
112144 reduced_map = type_map_reducer (reduced_map , t )
113145
114146 if isinstance (type , GraphQLObjectType ):
@@ -129,7 +161,7 @@ def type_map_reducer(map, type):
129161 return reduced_map
130162
131163
132- def assert_object_implements_interface (object , interface ):
164+ def assert_object_implements_interface (schema , object , interface ):
133165 object_field_map = object .get_fields ()
134166 interface_field_map = interface .get_fields ()
135167
@@ -140,7 +172,7 @@ def assert_object_implements_interface(object, interface):
140172 interface , field_name , object
141173 )
142174
143- assert is_type_sub_type_of (object_field .type , interface_field .type ), (
175+ assert is_type_sub_type_of (schema , object_field .type , interface_field .type ), (
144176 '{}.{} expects type "{}" but {}.{} provides type "{}".'
145177 ).format (interface , field_name , interface_field .type , object , field_name , object_field .type )
146178
0 commit comments