Skip to content

Commit 57c6e2a

Browse files
committed
Add enum type for visitor return values (#96)
1 parent ffdf1b3 commit 57c6e2a

38 files changed

+231
-134
lines changed

docs/modules/language.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,17 @@ The module also exports the following special symbols which can be used as
103103
return values in the :class:`Visitor` methods to signal particular actions:
104104

105105
.. data:: BREAK
106-
:annotation: = True
106+
:annotation: (same as ``True``)
107107

108108
This return value signals that no further nodes shall be visited.
109109

110110
.. data:: SKIP
111-
:annotation: = False
111+
:annotation: (same as ``False``)
112112

113113
This return value signals that the current node shall be skipped.
114114

115115
.. data:: REMOVE
116-
:annotation: = Ellipsis
116+
:annotation: (same as``Ellipsis``)
117117

118118
This return value signals that the current node shall be deleted.
119119

src/graphql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
visit,
182182
ParallelVisitor,
183183
Visitor,
184+
VisitorAction,
184185
BREAK,
185186
SKIP,
186187
REMOVE,
@@ -532,6 +533,7 @@
532533
"ParallelVisitor",
533534
"TypeInfoVisitor",
534535
"Visitor",
536+
"VisitorAction",
535537
"BREAK",
536538
"SKIP",
537539
"REMOVE",

src/graphql/language/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
visit,
2323
Visitor,
2424
ParallelVisitor,
25+
VisitorAction,
2526
BREAK,
2627
SKIP,
2728
REMOVE,
@@ -115,6 +116,7 @@
115116
"visit",
116117
"Visitor",
117118
"ParallelVisitor",
119+
"VisitorAction",
118120
"BREAK",
119121
"SKIP",
120122
"REMOVE",

src/graphql/language/visitor.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import copy
2+
from enum import Enum
23
from typing import (
34
Any,
45
Callable,
@@ -19,6 +20,7 @@
1920
__all__ = [
2021
"Visitor",
2122
"ParallelVisitor",
23+
"VisitorAction",
2224
"visit",
2325
"BREAK",
2426
"SKIP",
@@ -28,10 +30,26 @@
2830
]
2931

3032

31-
# Special return values for the visitor methods:
33+
class VisitorActionEnum(Enum):
34+
"""Special return values for the visitor methods.
35+
36+
You can also use the values of this enum directly.
37+
"""
38+
39+
BREAK = True
40+
SKIP = False
41+
REMOVE = Ellipsis
42+
43+
44+
VisitorAction = Optional[VisitorActionEnum]
45+
3246
# Note that in GraphQL.js these are defined differently:
3347
# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
34-
BREAK, SKIP, REMOVE, IDLE = True, False, Ellipsis, None
48+
49+
BREAK = VisitorActionEnum.BREAK
50+
SKIP = VisitorActionEnum.SKIP
51+
REMOVE = VisitorActionEnum.REMOVE
52+
IDLE = None
3553

3654
# Default map from visitor kinds to their traversable node attributes:
3755
QUERY_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = {
@@ -253,7 +271,7 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
253271
for edit_key, edit_value in edits:
254272
if in_array:
255273
edit_key -= edit_offset
256-
if in_array and edit_value is REMOVE:
274+
if in_array and (edit_value is REMOVE or edit_value is Ellipsis):
257275
node.pop(edit_key)
258276
edit_offset += 1
259277
else:
@@ -292,10 +310,10 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
292310
if visit_fn:
293311
result = visit_fn(visitor, node, key, parent, path, ancestors)
294312

295-
if result is BREAK:
313+
if result is BREAK or result is True:
296314
break
297315

298-
if result is SKIP:
316+
if result is SKIP or result is False:
299317
if not is_leaving:
300318
path_pop()
301319
continue
@@ -356,9 +374,9 @@ def enter(self, node, *args):
356374
fn = visitor.get_visit_fn(node.kind)
357375
if fn:
358376
result = fn(visitor, node, *args)
359-
if result is SKIP:
377+
if result is SKIP or result is False:
360378
skipping[i] = node
361-
elif result == BREAK:
379+
elif result is BREAK or result is True:
362380
skipping[i] = BREAK
363381
elif result is not None:
364382
return result
@@ -370,9 +388,13 @@ def leave(self, node, *args):
370388
fn = visitor.get_visit_fn(node.kind, is_leaving=True)
371389
if fn:
372390
result = fn(visitor, node, *args)
373-
if result == BREAK:
391+
if result is BREAK or result is True:
374392
skipping[i] = BREAK
375-
elif result is not None and result is not SKIP:
393+
elif (
394+
result is not None
395+
and result is not SKIP
396+
and result is not False
397+
):
376398
return result
377399
elif skipping[i] is node:
378400
skipping[i] = None

src/graphql/validation/rules/executable_definitions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
SchemaDefinitionNode,
99
SchemaExtensionNode,
1010
TypeDefinitionNode,
11+
VisitorAction,
12+
SKIP,
1113
)
1214
from . import ASTValidationRule
1315

@@ -21,7 +23,7 @@ class ExecutableDefinitionsRule(ASTValidationRule):
2123
operation or fragment definitions.
2224
"""
2325

24-
def enter_document(self, node: DocumentNode, *_args):
26+
def enter_document(self, node: DocumentNode, *_args) -> VisitorAction:
2527
for definition in node.definitions:
2628
if not isinstance(definition, ExecutableDefinitionNode):
2729
def_name = (
@@ -41,4 +43,4 @@ def enter_document(self, node: DocumentNode, *_args):
4143
f"The {def_name} definition is not executable.", definition,
4244
)
4345
)
44-
return self.SKIP
46+
return SKIP

src/graphql/validation/rules/fields_on_correct_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FieldsOnCorrectTypeRule(ValidationRule):
2727
type, or are an allowed meta field such as ``__typename``.
2828
"""
2929

30-
def enter_field(self, node: FieldNode, *_args):
30+
def enter_field(self, node: FieldNode, *_args) -> None:
3131
type_ = self.context.get_parent_type()
3232
if not type_:
3333
return

src/graphql/validation/rules/fragments_on_composite_types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from ...error import GraphQLError
2-
from ...language import FragmentDefinitionNode, InlineFragmentNode, print_ast
2+
from ...language import (
3+
FragmentDefinitionNode,
4+
InlineFragmentNode,
5+
print_ast,
6+
)
37
from ...type import is_composite_type
48
from ...utilities import type_from_ast
59
from . import ValidationRule
@@ -15,7 +19,7 @@ class FragmentsOnCompositeTypesRule(ValidationRule):
1519
must also be a composite type.
1620
"""
1721

18-
def enter_inline_fragment(self, node: InlineFragmentNode, *_args):
22+
def enter_inline_fragment(self, node: InlineFragmentNode, *_args) -> None:
1923
type_condition = node.type_condition
2024
if type_condition:
2125
type_ = type_from_ast(self.context.schema, type_condition)
@@ -29,7 +33,7 @@ def enter_inline_fragment(self, node: InlineFragmentNode, *_args):
2933
)
3034
)
3135

32-
def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args):
36+
def enter_fragment_definition(self, node: FragmentDefinitionNode, *_args) -> None:
3337
type_condition = node.type_condition
3438
type_ = type_from_ast(self.context.schema, type_condition)
3539
if type_ and not is_composite_type(type_):

src/graphql/validation/rules/known_argument_names.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from typing import cast, Dict, List, Union
22

33
from ...error import GraphQLError
4-
from ...language import ArgumentNode, DirectiveDefinitionNode, DirectiveNode, SKIP
4+
from ...language import (
5+
ArgumentNode,
6+
DirectiveDefinitionNode,
7+
DirectiveNode,
8+
SKIP,
9+
VisitorAction,
10+
)
511
from ...pyutils import did_you_mean, suggestion_list
612
from ...type import specified_directives
713
from . import ASTValidationRule, SDLValidationContext, ValidationContext
@@ -37,7 +43,7 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
3743

3844
self.directive_args = directive_args
3945

40-
def enter_directive(self, directive_node: DirectiveNode, *_args):
46+
def enter_directive(self, directive_node: DirectiveNode, *_args) -> VisitorAction:
4147
directive_name = directive_node.name.value
4248
known_args = self.directive_args.get(directive_name)
4349
if directive_node.arguments and known_args is not None:
@@ -67,7 +73,7 @@ class KnownArgumentNamesRule(KnownArgumentNamesOnDirectivesRule):
6773
def __init__(self, context: ValidationContext):
6874
super().__init__(context)
6975

70-
def enter_argument(self, arg_node: ArgumentNode, *args):
76+
def enter_argument(self, arg_node: ArgumentNode, *args) -> None:
7177
context = self.context
7278
arg_def = context.get_argument()
7379
field_def = context.get_field_def()

src/graphql/validation/rules/known_directives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
4141
]
4242
self.locations_map = locations_map
4343

44-
def enter_directive(self, node: DirectiveNode, _key, _parent, _path, ancestors):
44+
def enter_directive(
45+
self, node: DirectiveNode, _key, _parent, _path, ancestors
46+
) -> None:
4547
name = node.name.value
4648
locations = self.locations_map.get(name)
4749
if locations:

src/graphql/validation/rules/known_fragment_names.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class KnownFragmentNamesRule(ValidationRule):
1212
fragments defined in the same document.
1313
"""
1414

15-
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args):
15+
def enter_fragment_spread(self, node: FragmentSpreadNode, *_args) -> None:
1616
fragment_name = node.name.value
1717
fragment = self.context.get_fragment(fragment_name)
1818
if not fragment:

0 commit comments

Comments
 (0)