11from copy import copy
2+ from enum import Enum
23from typing import (
34 Any ,
45 Callable ,
1920__all__ = [
2021 "Visitor" ,
2122 "ParallelVisitor" ,
23+ "VisitorAction" ,
2224 "visit" ,
2325 "BREAK" ,
2426 "SKIP" ,
2830]
2931
3032
31- # Special return values for the visitor methods:
33+ class VisitorActionEnum (Enum ):
34+ """Special return values for the visitor methods.
35+
36+ You can also use the values of this enum directly.
37+ """
38+
39+ BREAK = True
40+ SKIP = False
41+ REMOVE = Ellipsis
42+
43+
44+ VisitorAction = Optional [VisitorActionEnum ]
45+
3246# Note that in GraphQL.js these are defined differently:
3347# BREAK = {}, SKIP = false, REMOVE = null, IDLE = undefined
34- BREAK , SKIP , REMOVE , IDLE = True , False , Ellipsis , None
48+
49+ BREAK = VisitorActionEnum .BREAK
50+ SKIP = VisitorActionEnum .SKIP
51+ REMOVE = VisitorActionEnum .REMOVE
52+ IDLE = None
3553
3654# Default map from visitor kinds to their traversable node attributes:
3755QUERY_DOCUMENT_KEYS : Dict [str , Tuple [str , ...]] = {
@@ -253,7 +271,7 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
253271 for edit_key , edit_value in edits :
254272 if in_array :
255273 edit_key -= edit_offset
256- if in_array and edit_value is REMOVE :
274+ if in_array and ( edit_value is REMOVE or edit_value is Ellipsis ) :
257275 node .pop (edit_key )
258276 edit_offset += 1
259277 else :
@@ -292,10 +310,10 @@ def visit(root: Node, visitor: Visitor, visitor_keys=None) -> Any:
292310 if visit_fn :
293311 result = visit_fn (visitor , node , key , parent , path , ancestors )
294312
295- if result is BREAK :
313+ if result is BREAK or result is True :
296314 break
297315
298- if result is SKIP :
316+ if result is SKIP or result is False :
299317 if not is_leaving :
300318 path_pop ()
301319 continue
@@ -356,9 +374,9 @@ def enter(self, node, *args):
356374 fn = visitor .get_visit_fn (node .kind )
357375 if fn :
358376 result = fn (visitor , node , * args )
359- if result is SKIP :
377+ if result is SKIP or result is False :
360378 skipping [i ] = node
361- elif result == BREAK :
379+ elif result is BREAK or result is True :
362380 skipping [i ] = BREAK
363381 elif result is not None :
364382 return result
@@ -370,9 +388,13 @@ def leave(self, node, *args):
370388 fn = visitor .get_visit_fn (node .kind , is_leaving = True )
371389 if fn :
372390 result = fn (visitor , node , * args )
373- if result == BREAK :
391+ if result is BREAK or result is True :
374392 skipping [i ] = BREAK
375- elif result is not None and result is not SKIP :
393+ elif (
394+ result is not None
395+ and result is not SKIP
396+ and result is not False
397+ ):
376398 return result
377399 elif skipping [i ] is node :
378400 skipping [i ] = None
0 commit comments