|
| 1 | +from collections import defaultdict |
1 | 2 | from operator import attrgetter, itemgetter |
2 | 3 | from typing import Any, Collection, Dict, List, Optional, Set, Tuple, Union, cast |
3 | 4 |
|
|
11 | 12 | SchemaDefinitionNode, |
12 | 13 | SchemaExtensionNode, |
13 | 14 | ) |
14 | | -from ..pyutils import inspect |
| 15 | +from ..pyutils import and_list, inspect |
15 | 16 | from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of |
16 | 17 | from .definition import ( |
17 | 18 | GraphQLEnumType, |
@@ -105,19 +106,37 @@ def validate_root_types(self) -> None: |
105 | 106 | schema = self.schema |
106 | 107 | if not schema.query_type: |
107 | 108 | self.report_error("Query root type must be provided.", schema.ast_node) |
| 109 | + root_types_map: Dict[GraphQLObjectType, List[OperationType]] = defaultdict(list) |
| 110 | + |
108 | 111 | for operation_type in OperationType: |
109 | 112 | root_type = schema.get_root_type(operation_type) |
110 | | - if root_type and not is_object_type(root_type): |
111 | | - operation_type_str = operation_type.value.capitalize() |
112 | | - root_type_str = inspect(root_type) |
113 | | - if_provided_str = ( |
114 | | - "" if operation_type == operation_type.QUERY else " if provided" |
| 113 | + if root_type: |
| 114 | + if is_object_type(root_type): |
| 115 | + root_types_map[root_type].append(operation_type) |
| 116 | + else: |
| 117 | + operation_type_str = operation_type.value.capitalize() |
| 118 | + root_type_str = inspect(root_type) |
| 119 | + if_provided_str = ( |
| 120 | + "" if operation_type == operation_type.QUERY else " if provided" |
| 121 | + ) |
| 122 | + self.report_error( |
| 123 | + f"{operation_type_str} root type must be Object type" |
| 124 | + f"{if_provided_str}, it cannot be {root_type_str}.", |
| 125 | + get_operation_type_node(schema, operation_type) |
| 126 | + or root_type.ast_node, |
| 127 | + ) |
| 128 | + for root_type, operation_types in root_types_map.items(): |
| 129 | + if len(operation_types) > 1: |
| 130 | + operation_list = and_list( |
| 131 | + [operation_type.value for operation_type in operation_types] |
115 | 132 | ) |
116 | 133 | self.report_error( |
117 | | - f"{operation_type_str} root type must be Object type" |
118 | | - f"{if_provided_str}, it cannot be {root_type_str}.", |
119 | | - get_operation_type_node(schema, operation_type) |
120 | | - or root_type.ast_node, |
| 134 | + "All root types must be different," |
| 135 | + f" '{root_type.name}' type is used as {operation_list} root types.", |
| 136 | + [ |
| 137 | + get_operation_type_node(schema, operation_type) |
| 138 | + for operation_type in operation_types |
| 139 | + ], |
121 | 140 | ) |
122 | 141 |
|
123 | 142 | def validate_directives(self) -> None: |
|
0 commit comments