99 Mapping ,
1010 Optional ,
1111 Tuple ,
12+ TypeVar ,
1213 Union ,
1314 cast ,
1415)
@@ -133,6 +134,38 @@ def extend_schema(
133134 )
134135
135136
137+ TEN = TypeVar ("TEN" , bound = TypeExtensionNode )
138+
139+
140+ class TypeExtensionsMap :
141+ """Mappings from types to their extensions."""
142+
143+ scalar : DefaultDict [str , List [ScalarTypeExtensionNode ]]
144+ object : DefaultDict [str , List [ObjectTypeExtensionNode ]]
145+ interface : DefaultDict [str , List [InterfaceTypeExtensionNode ]]
146+ union : DefaultDict [str , List [UnionTypeExtensionNode ]]
147+ enum : DefaultDict [str , List [EnumTypeExtensionNode ]]
148+ input_object : DefaultDict [str , List [InputObjectTypeExtensionNode ]]
149+
150+ def __init__ (self ) -> None :
151+ self .scalar = defaultdict (list )
152+ self .object = defaultdict (list )
153+ self .interface = defaultdict (list )
154+ self .union = defaultdict (list )
155+ self .enum = defaultdict (list )
156+ self .input_object = defaultdict (list )
157+
158+ def for_node (self , node : TEN ) -> DefaultDict [str , List [TEN ]]:
159+ """Get type extensions map for the given node kind."""
160+ kind = node .kind
161+ try :
162+ kind = kind .removesuffix ("_type_extension" )
163+ except AttributeError : # pragma: no cover (Python < 3.9)
164+ if kind .endswith ("_type_extension" ):
165+ kind = kind [:- 15 ]
166+ return getattr (self , kind )
167+
168+
136169class ExtendSchemaImpl :
137170 """Helper class implementing the methods to extend a schema.
138171
@@ -143,11 +176,11 @@ class ExtendSchemaImpl:
143176 """
144177
145178 type_map : Dict [str , GraphQLNamedType ]
146- type_extensions_map : Dict [ str , Any ]
179+ type_extensions : TypeExtensionsMap
147180
148- def __init__ (self , type_extensions_map : Dict [ str , Any ] ):
181+ def __init__ (self , type_extensions : TypeExtensionsMap ):
149182 self .type_map = {}
150- self .type_extensions_map = type_extensions_map
183+ self .type_extensions = type_extensions
151184
152185 @classmethod
153186 def extend_schema_args (
@@ -164,7 +197,8 @@ def extend_schema_args(
164197
165198 # Collect the type definitions and extensions found in the document.
166199 type_defs : List [TypeDefinitionNode ] = []
167- type_extensions_map : DefaultDict [str , Any ] = defaultdict (list )
200+
201+ type_extensions = TypeExtensionsMap ()
168202
169203 # New directives and types are separate because a directives and types can have
170204 # the same name. For example, a type named "skip".
@@ -174,31 +208,28 @@ def extend_schema_args(
174208 # Schema extensions are collected which may add additional operation types.
175209 schema_extensions : List [SchemaExtensionNode ] = []
176210
211+ is_schema_changed = False
177212 for def_ in document_ast .definitions :
178213 if isinstance (def_ , SchemaDefinitionNode ):
179214 schema_def = def_
180215 elif isinstance (def_ , SchemaExtensionNode ):
181216 schema_extensions .append (def_ )
217+ elif isinstance (def_ , DirectiveDefinitionNode ):
218+ directive_defs .append (def_ )
182219 elif isinstance (def_ , TypeDefinitionNode ):
183220 type_defs .append (def_ )
184221 elif isinstance (def_ , TypeExtensionNode ):
185- extended_type_name = def_ .name .value
186- type_extensions_map [ extended_type_name ]. append ( def_ )
187- elif isinstance ( def_ , DirectiveDefinitionNode ):
188- directive_defs . append ( def_ )
222+ type_extensions . for_node ( def_ )[ def_ .name .value ]. append ( def_ )
223+ else :
224+ continue
225+ is_schema_changed = True
189226
190227 # If this document contains no new types, extensions, or directives then return
191228 # the same unmodified GraphQLSchema instance.
192- if (
193- not type_extensions_map
194- and not type_defs
195- and not directive_defs
196- and not schema_extensions
197- and not schema_def
198- ):
229+ if not is_schema_changed :
199230 return schema_kwargs
200231
201- self = cls (type_extensions_map )
232+ self = cls (type_extensions )
202233 for existing_type in schema_kwargs ["types" ] or ():
203234 self .type_map [existing_type .name ] = self .extend_named_type (existing_type )
204235 for type_node in type_defs :
@@ -311,7 +342,7 @@ def extend_input_object_type(
311342 type_ : GraphQLInputObjectType ,
312343 ) -> GraphQLInputObjectType :
313344 kwargs = type_ .to_kwargs ()
314- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
345+ extensions = tuple (self .type_extensions . input_object [kwargs ["name" ]])
315346
316347 return GraphQLInputObjectType (
317348 ** merge_kwargs (
@@ -325,7 +356,7 @@ def extend_input_object_type(
325356
326357 def extend_enum_type (self , type_ : GraphQLEnumType ) -> GraphQLEnumType :
327358 kwargs = type_ .to_kwargs ()
328- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
359+ extensions = tuple (self .type_extensions . enum [kwargs ["name" ]])
329360
330361 return GraphQLEnumType (
331362 ** merge_kwargs (
@@ -337,7 +368,7 @@ def extend_enum_type(self, type_: GraphQLEnumType) -> GraphQLEnumType:
337368
338369 def extend_scalar_type (self , type_ : GraphQLScalarType ) -> GraphQLScalarType :
339370 kwargs = type_ .to_kwargs ()
340- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
371+ extensions = tuple (self .type_extensions . scalar [kwargs ["name" ]])
341372
342373 specified_by_url = kwargs ["specified_by_url" ]
343374 for extension_node in extensions :
@@ -373,7 +404,7 @@ def extend_object_type_fields(
373404 # noinspection PyShadowingNames
374405 def extend_object_type (self , type_ : GraphQLObjectType ) -> GraphQLObjectType :
375406 kwargs = type_ .to_kwargs ()
376- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
407+ extensions = tuple (self .type_extensions . object [kwargs ["name" ]])
377408
378409 return GraphQLObjectType (
379410 ** merge_kwargs (
@@ -410,7 +441,7 @@ def extend_interface_type(
410441 self , type_ : GraphQLInterfaceType
411442 ) -> GraphQLInterfaceType :
412443 kwargs = type_ .to_kwargs ()
413- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
444+ extensions = tuple (self .type_extensions . interface [kwargs ["name" ]])
414445
415446 return GraphQLInterfaceType (
416447 ** merge_kwargs (
@@ -433,7 +464,7 @@ def extend_union_type_types(
433464
434465 def extend_union_type (self , type_ : GraphQLUnionType ) -> GraphQLUnionType :
435466 kwargs = type_ .to_kwargs ()
436- extensions = tuple (self .type_extensions_map [kwargs ["name" ]])
467+ extensions = tuple (self .type_extensions . union [kwargs ["name" ]])
437468
438469 return GraphQLUnionType (
439470 ** merge_kwargs (
@@ -626,7 +657,7 @@ def build_union_types(
626657 def build_object_type (
627658 self , ast_node : ObjectTypeDefinitionNode
628659 ) -> GraphQLObjectType :
629- extension_nodes = self .type_extensions_map [ast_node .name .value ]
660+ extension_nodes = self .type_extensions . object [ast_node .name .value ]
630661 all_nodes : List [Union [ObjectTypeDefinitionNode , ObjectTypeExtensionNode ]] = [
631662 ast_node ,
632663 * extension_nodes ,
@@ -644,7 +675,7 @@ def build_interface_type(
644675 self ,
645676 ast_node : InterfaceTypeDefinitionNode ,
646677 ) -> GraphQLInterfaceType :
647- extension_nodes = self .type_extensions_map [ast_node .name .value ]
678+ extension_nodes = self .type_extensions . interface [ast_node .name .value ]
648679 all_nodes : List [
649680 Union [InterfaceTypeDefinitionNode , InterfaceTypeExtensionNode ]
650681 ] = [ast_node , * extension_nodes ]
@@ -658,7 +689,7 @@ def build_interface_type(
658689 )
659690
660691 def build_enum_type (self , ast_node : EnumTypeDefinitionNode ) -> GraphQLEnumType :
661- extension_nodes = self .type_extensions_map [ast_node .name .value ]
692+ extension_nodes = self .type_extensions . enum [ast_node .name .value ]
662693 all_nodes : List [Union [EnumTypeDefinitionNode , EnumTypeExtensionNode ]] = [
663694 ast_node ,
664695 * extension_nodes ,
@@ -672,7 +703,7 @@ def build_enum_type(self, ast_node: EnumTypeDefinitionNode) -> GraphQLEnumType:
672703 )
673704
674705 def build_union_type (self , ast_node : UnionTypeDefinitionNode ) -> GraphQLUnionType :
675- extension_nodes = self .type_extensions_map [ast_node .name .value ]
706+ extension_nodes = self .type_extensions . union [ast_node .name .value ]
676707 all_nodes : List [Union [UnionTypeDefinitionNode , UnionTypeExtensionNode ]] = [
677708 ast_node ,
678709 * extension_nodes ,
@@ -688,7 +719,7 @@ def build_union_type(self, ast_node: UnionTypeDefinitionNode) -> GraphQLUnionTyp
688719 def build_scalar_type (
689720 self , ast_node : ScalarTypeDefinitionNode
690721 ) -> GraphQLScalarType :
691- extension_nodes = self .type_extensions_map [ast_node .name .value ]
722+ extension_nodes = self .type_extensions . scalar [ast_node .name .value ]
692723 return GraphQLScalarType (
693724 name = ast_node .name .value ,
694725 description = ast_node .description .value if ast_node .description else None ,
@@ -701,7 +732,7 @@ def build_input_object_type(
701732 self ,
702733 ast_node : InputObjectTypeDefinitionNode ,
703734 ) -> GraphQLInputObjectType :
704- extension_nodes = self .type_extensions_map [ast_node .name .value ]
735+ extension_nodes = self .type_extensions . input_object [ast_node .name .value ]
705736 all_nodes : List [
706737 Union [InputObjectTypeDefinitionNode , InputObjectTypeExtensionNode ]
707738 ] = [ast_node , * extension_nodes ]
0 commit comments