1919
2020import libcst as cst
2121import libcst .matchers as m
22- from libcst .metadata import PositionProvider
22+ from libcst .metadata import CodeRange , PositionProvider
2323
2424from ..base import Statement
2525from .flake8asyncvisitor import Flake8AsyncVisitor_cst
2626from .helpers import (
27- AttributeCall ,
27+ MatchingCall ,
2828 cancel_scope_names ,
2929 disable_codes_by_default ,
3030 error_class_cst ,
3131 flatten_preserving_comments ,
3232 fnmatch_qualified_name_cst ,
3333 func_has_decorator ,
34+ get_matching_call_cst ,
3435 identifier_to_string ,
3536 iter_guaranteed_once_cst ,
36- with_has_call ,
3737)
3838
3939if TYPE_CHECKING :
@@ -374,6 +374,14 @@ def leave_Yield(
374374disable_codes_by_default ("ASYNC910" , "ASYNC911" , "ASYNC912" , "ASYNC913" )
375375
376376
377+ @dataclass
378+ class ContextManager :
379+ has_checkpoint : bool | None = None
380+ call : MatchingCall [cst .Call ] | None = None
381+ line : int | None = None
382+ column : int | None = None
383+
384+
377385@error_class_cst
378386class Visitor91X (Flake8AsyncVisitor_cst , CommonVisitors ):
379387 error_codes : Mapping [str , str ] = {
@@ -408,8 +416,7 @@ def __init__(self, *args: Any, **kwargs: Any):
408416 self .match_state = MatchState ()
409417
410418 # ASYNC100
411- self .has_checkpoint_stack : list [bool ] = []
412- self .node_dict : dict [cst .With , list [AttributeCall ]] = {}
419+ self .has_checkpoint_stack : list [ContextManager ] = []
413420 self .taskgroup_has_start_soon : dict [str , bool ] = {}
414421
415422 # --exception-suppress-context-manager
@@ -429,7 +436,11 @@ def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
429436 )
430437
431438 def checkpoint_cancel_point (self ) -> None :
432- self .has_checkpoint_stack = [True ] * len (self .has_checkpoint_stack )
439+ for cm in reversed (self .has_checkpoint_stack ):
440+ if cm .has_checkpoint :
441+ # Everything further down in the stack is already True.
442+ break
443+ cm .has_checkpoint = True
433444 # don't need to look for any .start_soon() calls
434445 self .taskgroup_has_start_soon .clear ()
435446
@@ -705,59 +716,106 @@ def _checkpoint_with(self, node: cst.With, entry: bool):
705716 # missing-checkpoint warning when there might in fact be one (i.e. a false alarm).
706717 def visit_With_body (self , node : cst .With ):
707718 self .save_state (node , "taskgroup_has_start_soon" , copy = True )
708- self ._checkpoint_with (node , entry = True )
719+
720+ is_suppressing = False
709721
710722 # if this might suppress exceptions, we cannot treat anything inside it as
711723 # checkpointing.
712724 if self ._is_exception_suppressing_context_manager (node ):
713725 self .save_state (node , "uncheckpointed_statements" , copy = True )
714726
715- if res := (
716- with_has_call (node , * cancel_scope_names )
717- or with_has_call (
718- node , "timeout" , "timeout_at" , base = ("asyncio" , "asyncio.timeouts" )
719- )
720- ):
721- pos = self .get_metadata (PositionProvider , node ).start # pyright: ignore
722- line : int = pos .line # pyright: ignore
723- column : int = pos .column # pyright: ignore
724- self .uncheckpointed_statements .add (
725- ArtificialStatement ("with" , line , column )
726- )
727- self .node_dict [node ] = res
728- self .has_checkpoint_stack .append (False )
729- else :
730- self .has_checkpoint_stack .append (True )
727+ for withitem in node .items :
728+ self .has_checkpoint_stack .append (ContextManager ())
729+ if get_matching_call_cst (
730+ withitem .item , "open_nursery" , "create_task_group"
731+ ):
732+ if withitem .asname is not None and isinstance (
733+ withitem .asname .name , cst .Name
734+ ):
735+ self .taskgroup_has_start_soon [withitem .asname .name .value ] = False
736+ self .checkpoint_schedule_point ()
737+ # Technically somebody could set open_nursery or create_task_group as
738+ # suppressing context managers, but we're not add logic for that.
739+ continue
740+
741+ if bool (getattr (node , "asynchronous" , False )):
742+ self .checkpoint ()
743+
744+ # not a clean function call
745+ if not isinstance (withitem .item , cst .Call ) or not isinstance (
746+ withitem .item .func , (cst .Name , cst .Attribute )
747+ ):
748+ continue
749+
750+ if (
751+ fnmatch_qualified_name_cst (
752+ (withitem .item .func ,),
753+ "contextlib.suppress" ,
754+ * self .suppress_imported_as ,
755+ * self .options .exception_suppress_context_managers ,
756+ )
757+ is not None
758+ ):
759+ # Don't re-update state if there's several suppressing cm's.
760+ if not is_suppressing :
761+ self .save_state (node , "uncheckpointed_statements" , copy = True )
762+ is_suppressing = True
763+ continue
764+
765+ if res := (
766+ get_matching_call_cst (withitem .item , * cancel_scope_names )
767+ or get_matching_call_cst (
768+ withitem .item ,
769+ "timeout" ,
770+ "timeout_at" ,
771+ base = "asyncio" ,
772+ )
773+ ):
774+ # typing issue: https://github.com/Instagram/LibCST/issues/1107
775+ pos = cst .ensure_type (
776+ self .get_metadata (PositionProvider , withitem ),
777+ CodeRange ,
778+ ).start
779+ self .uncheckpointed_statements .add (
780+ ArtificialStatement ("withitem" , pos .line , pos .column )
781+ )
782+
783+ cm = self .has_checkpoint_stack [- 1 ]
784+ cm .line = pos .line
785+ cm .column = pos .column
786+ cm .call = res
787+ cm .has_checkpoint = False
731788
732789 def leave_With (self , original_node : cst .With , updated_node : cst .With ):
733- # Uses leave_With instead of leave_With_body because we need access to both
734- # original and updated node
735- # ASYNC100
736- if not self .has_checkpoint_stack .pop ():
737- autofix = len (updated_node .items ) == 1
738- for res in self .node_dict [original_node ]:
790+ withitems = list (updated_node .items )
791+ for i in reversed (range (len (updated_node .items ))):
792+ cm = self .has_checkpoint_stack .pop ()
793+ # ASYNC100
794+ if cm .has_checkpoint is False :
795+ res = cm .call
796+ assert res is not None
739797 # bypass 910 & 911's should_autofix logic, which excludes asyncio
740- # (TODO: and uses self.noautofix ... which I don't remember what it's for)
741- autofix &= self . error (
742- res .node , res . base , res . function , error_code = "ASYNC100"
743- ) and super (). should_autofix ( res . node , code = "ASYNC100" )
744-
745- if autofix :
746- return flatten_preserving_comments ( updated_node )
747- # ASYNC912
748- else :
749- pos = self . get_metadata ( # pyright: ignore
750- PositionProvider , original_node
751- ). start # pyright: ignore
752- line : int = pos . line # pyright: ignore
753- column : int = pos . column # pyright: ignore
754- s = ArtificialStatement ( "with" , line , column )
755- if s in self . uncheckpointed_statements :
756- self . uncheckpointed_statements . remove ( s )
757- for res in self . node_dict [ original_node ]:
758- self .error ( res . node , error_code = "ASYNC912" )
759-
760- self .node_dict . pop ( original_node , None )
798+ if self .error (
799+ res . node , res . base , res . name , error_code = "ASYNC100"
800+ ) and super (). should_autofix ( res .node , code = "ASYNC100" ):
801+ if len ( withitems ) == 1 :
802+ # Remove this With node, bypassing later logic.
803+ return flatten_preserving_comments ( updated_node )
804+ if i == len ( withitems ) - 1 :
805+ # preserve trailing comma, or remove comma if there was none
806+ withitems [ - 2 ] = withitems [ - 2 ]. with_changes (
807+ comma = withitems [ - 1 ]. comma
808+ )
809+ withitems . pop ( i )
810+
811+ # ASYNC912
812+ elif cm . call is not None :
813+ assert cm . line is not None
814+ assert cm . column is not None
815+ s = ArtificialStatement ( "withitem" , cm . line , cm . column )
816+ if s in self .uncheckpointed_statements :
817+ self . uncheckpointed_statements . remove ( s )
818+ self .error ( cm . call . node , error_code = "ASYNC912" )
761819
762820 # if exception-suppressing, restore all uncheckpointed statements from
763821 # before the `with`.
@@ -767,7 +825,8 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
767825 self .uncheckpointed_statements .update (prev_checkpoints )
768826
769827 self ._checkpoint_with (original_node , entry = False )
770- return updated_node
828+
829+ return updated_node .with_changes (items = withitems )
771830
772831 # error if no checkpoint since earlier yield or function entry
773832 def leave_Yield (
0 commit comments