Skip to content

Commit c37661c

Browse files
committed
Reuse 'group_by' in validation rules
Replicates graphql/graphql-js@71c7a14
1 parent 91584ed commit c37661c

File tree

8 files changed

+117
-74
lines changed

8 files changed

+117
-74
lines changed

src/graphql/pyutils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
unregister_description,
1818
)
1919
from .did_you_mean import did_you_mean
20+
from .group_by import group_by
2021
from .identity_func import identity_func
2122
from .inspect import inspect
2223
from .is_awaitable import is_awaitable
@@ -38,6 +39,7 @@
3839
"cached_property",
3940
"did_you_mean",
4041
"Description",
42+
"group_by",
4143
"is_description",
4244
"register_description",
4345
"unregister_description",

src/graphql/pyutils/group_by.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from collections import defaultdict
2+
from typing import Callable, Collection, Dict, List, TypeVar
3+
4+
__all__ = ["group_by"]
5+
6+
K = TypeVar("K")
7+
T = TypeVar("T")
8+
9+
10+
def group_by(items: Collection[T], key_fn: Callable[[T], K]) -> Dict[K, List[T]]:
11+
"""Group an unsorted collection of items by a key derived via a function."""
12+
result: Dict[K, List[T]] = defaultdict(list)
13+
for item in items:
14+
key = key_fn(item)
15+
result[key].append(item)
16+
return result

src/graphql/validation/rules/unique_argument_definition_names.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from collections import defaultdict
21
from operator import attrgetter
3-
from typing import Any, Callable, Collection, Dict, List, TypeVar
2+
from typing import Any, Collection
43

54
from ...error import GraphQLError
65
from ...language import (
@@ -15,6 +14,7 @@
1514
VisitorAction,
1615
SKIP,
1716
)
17+
from ...pyutils import group_by
1818
from . import SDLValidationRule
1919

2020
__all__ = ["UniqueArgumentDefinitionNamesRule"]
@@ -81,16 +81,3 @@ def check_arg_uniqueness(
8181
)
8282
)
8383
return SKIP
84-
85-
86-
K = TypeVar("K")
87-
T = TypeVar("T")
88-
89-
90-
def group_by(items: Collection[T], key_fn: Callable[[T], K]) -> Dict[K, List[T]]:
91-
"""Group an unsorted collection of items by a key derived via a function."""
92-
result: Dict[K, List[T]] = defaultdict(list)
93-
for item in items:
94-
key = key_fn(item)
95-
result[key].append(item)
96-
return result
Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Dict
1+
from operator import attrgetter
2+
from typing import Any, Collection
23

34
from ...error import GraphQLError
4-
from ...language import ArgumentNode, NameNode, VisitorAction, SKIP
5-
from . import ASTValidationContext, ASTValidationRule
5+
from ...language import ArgumentNode, DirectiveNode, FieldNode
6+
from ...pyutils import group_by
7+
from . import ASTValidationRule
68

79
__all__ = ["UniqueArgumentNamesRule"]
810

@@ -16,26 +18,20 @@ class UniqueArgumentNamesRule(ASTValidationRule):
1618
See https://spec.graphql.org/draft/#sec-Argument-Names
1719
"""
1820

19-
def __init__(self, context: ASTValidationContext):
20-
super().__init__(context)
21-
self.known_arg_names: Dict[str, NameNode] = {}
21+
def enter_field(self, node: FieldNode, *_args: Any) -> None:
22+
self.check_arg_uniqueness(node.arguments)
2223

23-
def enter_field(self, *_args: Any) -> None:
24-
self.known_arg_names.clear()
24+
def enter_directive(self, node: DirectiveNode, *args: Any) -> None:
25+
self.check_arg_uniqueness(node.arguments)
2526

26-
def enter_directive(self, *_args: Any) -> None:
27-
self.known_arg_names.clear()
27+
def check_arg_uniqueness(self, argument_nodes: Collection[ArgumentNode]) -> None:
28+
seen_args = group_by(argument_nodes, attrgetter("name.value"))
2829

29-
def enter_argument(self, node: ArgumentNode, *_args: Any) -> VisitorAction:
30-
known_arg_names = self.known_arg_names
31-
arg_name = node.name.value
32-
if arg_name in known_arg_names:
33-
self.report_error(
34-
GraphQLError(
35-
f"There can be only one argument named '{arg_name}'.",
36-
[known_arg_names[arg_name], node.name],
30+
for arg_name, arg_nodes in seen_args.items():
31+
if len(arg_nodes) > 1:
32+
self.report_error(
33+
GraphQLError(
34+
f"There can be only one argument named '{arg_name}'.",
35+
[node.name for node in arg_nodes],
36+
)
3737
)
38-
)
39-
else:
40-
known_arg_names[arg_name] = node.name
41-
return SKIP
Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Dict
1+
from operator import attrgetter
2+
from typing import Any
23

34
from ...error import GraphQLError
4-
from ...language import NameNode, VariableDefinitionNode
5-
from . import ASTValidationContext, ASTValidationRule
5+
from ...language import OperationDefinitionNode
6+
from ...pyutils import group_by
7+
from . import ASTValidationRule
68

79
__all__ = ["UniqueVariableNamesRule"]
810

@@ -13,24 +15,20 @@ class UniqueVariableNamesRule(ASTValidationRule):
1315
A GraphQL operation is only valid if all its variables are uniquely named.
1416
"""
1517

16-
def __init__(self, context: ASTValidationContext):
17-
super().__init__(context)
18-
self.known_variable_names: Dict[str, NameNode] = {}
19-
20-
def enter_operation_definition(self, *_args: Any) -> None:
21-
self.known_variable_names.clear()
22-
23-
def enter_variable_definition(
24-
self, node: VariableDefinitionNode, *_args: Any
18+
def enter_operation_definition(
19+
self, node: OperationDefinitionNode, *_args: Any
2520
) -> None:
26-
known_variable_names = self.known_variable_names
27-
variable_name = node.variable.name.value
28-
if variable_name in known_variable_names:
29-
self.report_error(
30-
GraphQLError(
31-
f"There can be only one variable named '${variable_name}'.",
32-
[known_variable_names[variable_name], node.variable.name],
21+
variable_definitions = node.variable_definitions
22+
23+
seen_variable_definitions = group_by(
24+
variable_definitions, attrgetter("variable.name.value")
25+
)
26+
27+
for variable_name, variable_nodes in seen_variable_definitions.items():
28+
if len(variable_nodes) > 1:
29+
self.report_error(
30+
GraphQLError(
31+
f"There can be only one variable named '${variable_name}'.",
32+
[node.variable.name for node in variable_nodes],
33+
)
3334
)
34-
)
35-
else:
36-
known_variable_names[variable_name] = node.variable.name

tests/pyutils/test_group_by.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from graphql.pyutils import group_by
2+
3+
4+
def describe_group_by():
5+
def does_accept_an_empty_list():
6+
def key_fn(_x: str) -> str:
7+
raise TypeError("Unexpected call of key function.")
8+
9+
assert group_by([], key_fn) == {}
10+
11+
def does_not_change_order():
12+
def key_fn(_x: int) -> str:
13+
return "all"
14+
15+
assert group_by([3, 1, 5, 4, 2, 6], key_fn) == {
16+
"all": [3, 1, 5, 4, 2, 6],
17+
}
18+
19+
def can_group_by_odd_and_even():
20+
def key_fn(x: int) -> str:
21+
return "odd" if x % 2 else "even"
22+
23+
assert group_by([3, 1, 5, 4, 2, 6], key_fn) == {
24+
"odd": [3, 1, 5],
25+
"even": [4, 2, 6],
26+
}
27+
28+
def can_group_by_string_length():
29+
def key_fn(s: str) -> int:
30+
return len(s)
31+
32+
assert group_by(
33+
[
34+
"alpha",
35+
"beta",
36+
"gamma",
37+
"delta",
38+
"epsilon",
39+
"zeta",
40+
"eta",
41+
"iota",
42+
"kapp",
43+
"lambda",
44+
"my",
45+
"ny",
46+
"omikron",
47+
],
48+
key_fn,
49+
) == {
50+
2: ["my", "ny"],
51+
3: ["eta"],
52+
4: ["beta", "zeta", "iota", "kapp"],
53+
5: ["alpha", "gamma", "delta"],
54+
6: ["lambda"],
55+
7: ["epsilon", "omikron"],
56+
}

tests/validation/test_unique_argument_names.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,7 @@ def many_duplicate_field_arguments():
116116
[
117117
{
118118
"message": "There can be only one argument named 'arg1'.",
119-
"locations": [(3, 21), (3, 36)],
120-
},
121-
{
122-
"message": "There can be only one argument named 'arg1'.",
123-
"locations": [(3, 21), (3, 51)],
119+
"locations": [(3, 21), (3, 36), (3, 51)],
124120
},
125121
],
126122
)
@@ -150,11 +146,7 @@ def many_duplicate_directive_arguments():
150146
[
151147
{
152148
"message": "There can be only one argument named 'arg1'.",
153-
"locations": [(3, 32), (3, 47)],
154-
},
155-
{
156-
"message": "There can be only one argument named 'arg1'.",
157-
"locations": [(3, 32), (3, 62)],
149+
"locations": [(3, 32), (3, 47), (3, 62)],
158150
},
159151
],
160152
)

tests/validation/test_unique_variable_names.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@ def duplicate_variable_names():
2828
[
2929
{
3030
"message": "There can be only one variable named '$x'.",
31-
"locations": [(2, 22), (2, 31)],
32-
},
33-
{
34-
"message": "There can be only one variable named '$x'.",
35-
"locations": [(2, 22), (2, 40)],
31+
"locations": [(2, 22), (2, 31), (2, 40)],
3632
},
3733
{
3834
"message": "There can be only one variable named '$x'.",

0 commit comments

Comments
 (0)