Skip to content

Commit 1b0ab87

Browse files
committed
Add GraphQLSchema.get_root_type and deprecate get_operation_root_type
Replicates graphql/graphql-js@96b146d
1 parent 3a80e38 commit 1b0ab87

File tree

9 files changed

+285
-220
lines changed

9 files changed

+285
-220
lines changed

src/graphql/__init__.py

Lines changed: 201 additions & 202 deletions
Large diffs are not rendered by default.

src/graphql/execution/execute.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
Path,
3838
Undefined,
3939
)
40-
from ..utilities.get_operation_root_type import get_operation_root_type
4140
from ..type import (
4241
GraphQLAbstractType,
4342
GraphQLField,
@@ -340,12 +339,19 @@ def execute_operation(
340339
341340
Implements the "Executing operations" section of the spec.
342341
"""
343-
type_ = get_operation_root_type(self.schema, operation)
344-
fields = collect_fields(
342+
root_type = self.schema.get_root_type(operation.operation)
343+
if root_type is None:
344+
raise GraphQLError(
345+
"Schema is not configured to execute"
346+
f" {operation.operation.value} operation.",
347+
operation,
348+
)
349+
350+
root_fields = collect_fields(
345351
self.schema,
346352
self.fragments,
347353
self.variable_values,
348-
type_,
354+
root_type,
349355
operation.selection_set,
350356
)
351357

@@ -360,7 +366,7 @@ def execute_operation(
360366
self.execute_fields_serially
361367
if operation.operation == OperationType.MUTATION
362368
else self.execute_fields
363-
)(type_, root_value, path, fields)
369+
)(root_type, root_value, path, root_fields)
364370
except GraphQLError as error:
365371
self.errors.append(error)
366372
return None

src/graphql/execution/subscribe.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from ..language import DocumentNode
2222
from ..pyutils import Path, inspect
2323
from ..type import GraphQLFieldResolver, GraphQLSchema
24-
from ..utilities import get_operation_root_type
2524
from .map_async_iterator import MapAsyncIterator
2625

2726
__all__ = ["subscribe", "create_source_event_stream"]
@@ -163,25 +162,32 @@ async def create_source_event_stream(
163162

164163
async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
165164
schema = context.schema
166-
type_ = get_operation_root_type(schema, context.operation)
167-
fields = collect_fields(
165+
166+
root_type = schema.subscription_type
167+
if root_type is None:
168+
raise GraphQLError(
169+
"Schema is not configured to execute subscription operation.",
170+
context.operation,
171+
)
172+
173+
root_fields = collect_fields(
168174
schema,
169175
context.fragments,
170176
context.variable_values,
171-
type_,
177+
root_type,
172178
context.operation.selection_set,
173179
)
174-
response_name, field_nodes = next(iter(fields.items()))
175-
field_def = get_field_def(schema, type_, field_nodes[0])
180+
response_name, field_nodes = next(iter(root_fields.items()))
181+
field_def = get_field_def(schema, root_type, field_nodes[0])
176182

177183
if not field_def:
178184
field_name = field_nodes[0].name.value
179185
raise GraphQLError(
180186
f"The subscription field '{field_name}' is not defined.", field_nodes
181187
)
182188

183-
path = Path(None, response_name, type_.name)
184-
info = context.build_resolve_info(field_def, field_nodes, type_, path)
189+
path = Path(None, response_name, root_type.name)
190+
info = context.build_resolve_info(field_def, field_nodes, root_type, path)
185191

186192
# Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
187193
# It differs from "ResolveFieldValue" due to providing a different `resolveFn`.

src/graphql/type/schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414
from ..error import GraphQLError
15-
from ..language import ast
15+
from ..language import ast, OperationType
1616
from ..pyutils import inspect, is_collection, is_description, FrozenList
1717
from .definition import (
1818
GraphQLAbstractType,
@@ -326,6 +326,9 @@ def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
326326
assume_valid=True,
327327
)
328328

329+
def get_root_type(self, operation: OperationType) -> Optional[GraphQLObjectType]:
330+
return getattr(self, f"{operation.value}_type")
331+
329332
def get_type(self, name: str) -> Optional[GraphQLNamedType]:
330333
return self.type_map.get(name)
331334

src/graphql/utilities/get_operation_root_type.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ def get_operation_root_type(
1515
schema: GraphQLSchema,
1616
operation: Union[OperationDefinitionNode, OperationTypeDefinitionNode],
1717
) -> GraphQLObjectType:
18-
"""Extract the root type of the operation from the schema."""
18+
"""Extract the root type of the operation from the schema.
19+
20+
.. deprecated:: 3.2
21+
Please use `GraphQLSchema.getRootType` instead. Will be removed in v3.3.
22+
"""
1923
operation_type = operation.operation
2024
if operation_type == OperationType.QUERY:
2125
query_type = schema.query_type

src/graphql/utilities/type_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def enter_directive(self, node: DirectiveNode) -> None:
166166
self._directive = self._schema.get_directive(node.name.value)
167167

168168
def enter_operation_definition(self, node: OperationDefinitionNode) -> None:
169-
type_ = getattr(self._schema, f"{node.operation.value}_type")
170-
self._type_stack.append(type_ if is_object_type(type_) else None)
169+
root_type = self._schema.get_root_type(node.operation)
170+
self._type_stack.append(root_type if is_object_type(root_type) else None)
171171

172172
def enter_inline_fragment(self, node: InlineFragmentNode) -> None:
173173
type_condition_ast = node.type_condition

tests/execution/test_executor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,35 @@ class Data:
789789
result = execute_sync(schema, document, Data(), operation_name="S")
790790
assert result == ({"a": "b"}, None)
791791

792+
def resolves_to_an_error_if_schema_does_not_support_operation():
793+
schema = GraphQLSchema(assume_valid=True)
794+
795+
document = parse(
796+
"""
797+
query Q { __typename }
798+
mutation M { __typename }
799+
subscription S { __typename }
800+
"""
801+
)
802+
803+
with raises(
804+
GraphQLError,
805+
match=r"^Schema is not configured to execute query operation\.",
806+
):
807+
execute_sync(schema, document, operation_name="Q")
808+
809+
with raises(
810+
GraphQLError,
811+
match=r"^Schema is not configured to execute mutation operation\.",
812+
):
813+
execute_sync(schema, document, operation_name="M")
814+
815+
with raises(
816+
GraphQLError,
817+
match=r"^Schema is not configured to execute subscription operation\.",
818+
):
819+
execute_sync(schema, document, operation_name="S")
820+
792821
@mark.asyncio
793822
async def correct_field_ordering_despite_execution_order():
794823
schema = GraphQLSchema(

tests/execution/test_subscribe.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,24 @@ async def throws_an_error_if_some_of_required_arguments_are_missing():
296296
with raises(TypeError, match="missing .* positional argument: 'document'"):
297297
await subscribe(schema=schema) # type: ignore
298298

299+
@mark.asyncio
300+
async def resolves_to_an_error_if_schema_does_not_support_subscriptions():
301+
schema = GraphQLSchema(query=DummyQueryType)
302+
document = parse("subscription { unknownField }")
303+
304+
result = await subscribe(schema, document)
305+
306+
assert result == (
307+
None,
308+
[
309+
{
310+
"message": "Schema is not configured to execute"
311+
" subscription operation.",
312+
"locations": [(1, 1)],
313+
}
314+
],
315+
)
316+
299317
@mark.asyncio
300318
async def resolves_to_an_error_for_unknown_subscription_field():
301319
schema = GraphQLSchema(

tests/utilities/test_get_operation_root_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_operation_node(doc: DocumentNode) -> OperationDefinitionNode:
2727
return operation_node
2828

2929

30-
def describe_get_operation_root_type():
30+
def describe_deprecated_get_operation_root_type():
3131
def gets_a_query_type_for_an_unnamed_operation_definition_node():
3232
test_schema = GraphQLSchema(query_type)
3333
doc = parse("{ field }")

0 commit comments

Comments
 (0)