Skip to content

Commit 66b8383

Browse files
authored
handle multiple withitems checkpointing each other (#371)
* async100, async9xx: handle multiple withitems checkpointing each other, handle autofixing redundant timeouts with multiple withitems (although not when they also checkpoint each other.
1 parent 5bc2d1d commit 66b8383

File tree

14 files changed

+330
-128
lines changed

14 files changed

+330
-128
lines changed

docs/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ Changelog
44

55
`CalVer, YY.month.patch <https://calver.org/>`_
66

7+
25.4.3
8+
======
9+
- :ref:`ASYNC100 <async100>` can now autofix ``with`` statements with multiple items.
10+
- Fixed a bug where multiple ``with`` items would not interact, leading to ASYNC100 and ASYNC9xx false alarms. https://github.com/python-trio/flake8-async/issues/156
11+
712
25.4.2
813
======
914
- Add :ref:`ASYNC125 <async125>` constant-absolute-deadline

docs/usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``:
3333
minimum_pre_commit_version: '2.9.0'
3434
repos:
3535
- repo: https://github.com/python-trio/flake8-async
36-
rev: 25.4.2
36+
rev: 25.4.3
3737
hooks:
3838
- id: flake8-async
3939
# args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"]

flake8_async/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
41-
__version__ = "25.4.2"
41+
__version__ = "25.4.3"
4242

4343

4444
# taken from https://github.com/Zac-HD/shed

flake8_async/visitors/helpers.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import ast
99
from dataclasses import dataclass
1010
from fnmatch import fnmatch
11-
from typing import TYPE_CHECKING, NamedTuple, TypeVar, Union
11+
from typing import TYPE_CHECKING, Generic, TypeVar, Union
1212

1313
import libcst as cst
1414
import libcst.matchers as m
@@ -38,6 +38,8 @@
3838
"T_EITHER", bound=Union[Flake8AsyncVisitor, Flake8AsyncVisitor_cst]
3939
)
4040

41+
T_Call = TypeVar("T_Call", bound=Union[cst.Call, ast.Call])
42+
4143

4244
def error_class(error_class: type[T]) -> type[T]:
4345
assert error_class.error_codes
@@ -289,8 +291,8 @@ def has_exception(node: ast.expr) -> str | None:
289291

290292

291293
@dataclass
292-
class MatchingCall:
293-
node: ast.Call
294+
class MatchingCall(Generic[T_Call]):
295+
node: T_Call
294296
name: str
295297
base: str
296298

@@ -301,7 +303,7 @@ def __str__(self) -> str:
301303
# convenience function used in a lot of visitors
302304
def get_matching_call(
303305
node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio")
304-
) -> MatchingCall | None:
306+
) -> MatchingCall[ast.Call] | None:
305307
if isinstance(base, str):
306308
base = (base,)
307309
if (
@@ -316,6 +318,23 @@ def get_matching_call(
316318

317319

318320
# ___ CST helpers ___
321+
def get_matching_call_cst(
322+
node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio")
323+
) -> MatchingCall[cst.Call] | None:
324+
if isinstance(base, str):
325+
base = (base,)
326+
if (
327+
isinstance(node, cst.Call)
328+
and isinstance(node.func, cst.Attribute)
329+
and node.func.attr.value in names
330+
and isinstance(node.func.value, (cst.Name, cst.Attribute))
331+
):
332+
attr_base = identifier_to_string(node.func.value)
333+
if attr_base is not None and attr_base in base:
334+
return MatchingCall(node, node.func.attr.value, attr_base)
335+
return None
336+
337+
319338
def oneof_names(*names: str):
320339
return m.OneOf(*map(m.Name, names))
321340

@@ -329,12 +348,6 @@ def list_contains(
329348
yield from (item for item in seq if m.matches(item, matcher))
330349

331350

332-
class AttributeCall(NamedTuple):
333-
node: cst.Call
334-
base: str
335-
function: str
336-
337-
338351
# the custom __or__ in libcst breaks pyright type checking. It's possible to use
339352
# `Union` as a workaround ... except pyupgrade will automatically replace that.
340353
# So we have to resort to specifying one of the base classes.
@@ -365,7 +378,7 @@ def identifier_to_string(node: cst.CSTNode) -> str | None:
365378

366379
def with_has_call(
367380
node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio")
368-
) -> list[AttributeCall]:
381+
) -> list[MatchingCall[cst.Call]]:
369382
"""Check if a with statement has a matching call, returning a list with matches.
370383
371384
`names` specify the names of functions to match, `base` specifies the
@@ -396,7 +409,7 @@ def with_has_call(
396409
)
397410
)
398411

399-
res_list: list[AttributeCall] = []
412+
res_list: list[MatchingCall[cst.Call]] = []
400413
for item in node.items:
401414
if res := m.extract(item.item, matcher):
402415
assert isinstance(item.item, cst.Call)
@@ -405,7 +418,9 @@ def with_has_call(
405418
base_string = identifier_to_string(res["base"])
406419
assert base_string is not None, "subscripts should never get matched"
407420
res_list.append(
408-
AttributeCall(item.item, base_string, res["function"].value)
421+
MatchingCall(
422+
node=item.item, base=base_string, name=res["function"].value
423+
)
409424
)
410425
return res_list
411426

flake8_async/visitors/visitor91x.py

Lines changed: 110 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,21 @@
1919

2020
import libcst as cst
2121
import libcst.matchers as m
22-
from libcst.metadata import PositionProvider
22+
from libcst.metadata import CodeRange, PositionProvider
2323

2424
from ..base import Statement
2525
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
2626
from .helpers import (
27-
AttributeCall,
27+
MatchingCall,
2828
cancel_scope_names,
2929
disable_codes_by_default,
3030
error_class_cst,
3131
flatten_preserving_comments,
3232
fnmatch_qualified_name_cst,
3333
func_has_decorator,
34+
get_matching_call_cst,
3435
identifier_to_string,
3536
iter_guaranteed_once_cst,
36-
with_has_call,
3737
)
3838

3939
if TYPE_CHECKING:
@@ -374,6 +374,14 @@ def leave_Yield(
374374
disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912", "ASYNC913")
375375

376376

377+
@dataclass
378+
class ContextManager:
379+
has_checkpoint: bool | None = None
380+
call: MatchingCall[cst.Call] | None = None
381+
line: int | None = None
382+
column: int | None = None
383+
384+
377385
@error_class_cst
378386
class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
379387
error_codes: Mapping[str, str] = {
@@ -408,8 +416,7 @@ def __init__(self, *args: Any, **kwargs: Any):
408416
self.match_state = MatchState()
409417

410418
# ASYNC100
411-
self.has_checkpoint_stack: list[bool] = []
412-
self.node_dict: dict[cst.With, list[AttributeCall]] = {}
419+
self.has_checkpoint_stack: list[ContextManager] = []
413420
self.taskgroup_has_start_soon: dict[str, bool] = {}
414421

415422
# --exception-suppress-context-manager
@@ -429,7 +436,11 @@ def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
429436
)
430437

431438
def checkpoint_cancel_point(self) -> None:
432-
self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack)
439+
for cm in reversed(self.has_checkpoint_stack):
440+
if cm.has_checkpoint:
441+
# Everything further down in the stack is already True.
442+
break
443+
cm.has_checkpoint = True
433444
# don't need to look for any .start_soon() calls
434445
self.taskgroup_has_start_soon.clear()
435446

@@ -705,59 +716,106 @@ def _checkpoint_with(self, node: cst.With, entry: bool):
705716
# missing-checkpoint warning when there might in fact be one (i.e. a false alarm).
706717
def visit_With_body(self, node: cst.With):
707718
self.save_state(node, "taskgroup_has_start_soon", copy=True)
708-
self._checkpoint_with(node, entry=True)
719+
720+
is_suppressing = False
709721

710722
# if this might suppress exceptions, we cannot treat anything inside it as
711723
# checkpointing.
712724
if self._is_exception_suppressing_context_manager(node):
713725
self.save_state(node, "uncheckpointed_statements", copy=True)
714726

715-
if res := (
716-
with_has_call(node, *cancel_scope_names)
717-
or with_has_call(
718-
node, "timeout", "timeout_at", base=("asyncio", "asyncio.timeouts")
719-
)
720-
):
721-
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
722-
line: int = pos.line # pyright: ignore
723-
column: int = pos.column # pyright: ignore
724-
self.uncheckpointed_statements.add(
725-
ArtificialStatement("with", line, column)
726-
)
727-
self.node_dict[node] = res
728-
self.has_checkpoint_stack.append(False)
729-
else:
730-
self.has_checkpoint_stack.append(True)
727+
for withitem in node.items:
728+
self.has_checkpoint_stack.append(ContextManager())
729+
if get_matching_call_cst(
730+
withitem.item, "open_nursery", "create_task_group"
731+
):
732+
if withitem.asname is not None and isinstance(
733+
withitem.asname.name, cst.Name
734+
):
735+
self.taskgroup_has_start_soon[withitem.asname.name.value] = False
736+
self.checkpoint_schedule_point()
737+
# Technically somebody could set open_nursery or create_task_group as
738+
# suppressing context managers, but we're not add logic for that.
739+
continue
740+
741+
if bool(getattr(node, "asynchronous", False)):
742+
self.checkpoint()
743+
744+
# not a clean function call
745+
if not isinstance(withitem.item, cst.Call) or not isinstance(
746+
withitem.item.func, (cst.Name, cst.Attribute)
747+
):
748+
continue
749+
750+
if (
751+
fnmatch_qualified_name_cst(
752+
(withitem.item.func,),
753+
"contextlib.suppress",
754+
*self.suppress_imported_as,
755+
*self.options.exception_suppress_context_managers,
756+
)
757+
is not None
758+
):
759+
# Don't re-update state if there's several suppressing cm's.
760+
if not is_suppressing:
761+
self.save_state(node, "uncheckpointed_statements", copy=True)
762+
is_suppressing = True
763+
continue
764+
765+
if res := (
766+
get_matching_call_cst(withitem.item, *cancel_scope_names)
767+
or get_matching_call_cst(
768+
withitem.item,
769+
"timeout",
770+
"timeout_at",
771+
base="asyncio",
772+
)
773+
):
774+
# typing issue: https://github.com/Instagram/LibCST/issues/1107
775+
pos = cst.ensure_type(
776+
self.get_metadata(PositionProvider, withitem),
777+
CodeRange,
778+
).start
779+
self.uncheckpointed_statements.add(
780+
ArtificialStatement("withitem", pos.line, pos.column)
781+
)
782+
783+
cm = self.has_checkpoint_stack[-1]
784+
cm.line = pos.line
785+
cm.column = pos.column
786+
cm.call = res
787+
cm.has_checkpoint = False
731788

732789
def leave_With(self, original_node: cst.With, updated_node: cst.With):
733-
# Uses leave_With instead of leave_With_body because we need access to both
734-
# original and updated node
735-
# ASYNC100
736-
if not self.has_checkpoint_stack.pop():
737-
autofix = len(updated_node.items) == 1
738-
for res in self.node_dict[original_node]:
790+
withitems = list(updated_node.items)
791+
for i in reversed(range(len(updated_node.items))):
792+
cm = self.has_checkpoint_stack.pop()
793+
# ASYNC100
794+
if cm.has_checkpoint is False:
795+
res = cm.call
796+
assert res is not None
739797
# bypass 910 & 911's should_autofix logic, which excludes asyncio
740-
# (TODO: and uses self.noautofix ... which I don't remember what it's for)
741-
autofix &= self.error(
742-
res.node, res.base, res.function, error_code="ASYNC100"
743-
) and super().should_autofix(res.node, code="ASYNC100")
744-
745-
if autofix:
746-
return flatten_preserving_comments(updated_node)
747-
# ASYNC912
748-
else:
749-
pos = self.get_metadata( # pyright: ignore
750-
PositionProvider, original_node
751-
).start # pyright: ignore
752-
line: int = pos.line # pyright: ignore
753-
column: int = pos.column # pyright: ignore
754-
s = ArtificialStatement("with", line, column)
755-
if s in self.uncheckpointed_statements:
756-
self.uncheckpointed_statements.remove(s)
757-
for res in self.node_dict[original_node]:
758-
self.error(res.node, error_code="ASYNC912")
759-
760-
self.node_dict.pop(original_node, None)
798+
if self.error(
799+
res.node, res.base, res.name, error_code="ASYNC100"
800+
) and super().should_autofix(res.node, code="ASYNC100"):
801+
if len(withitems) == 1:
802+
# Remove this With node, bypassing later logic.
803+
return flatten_preserving_comments(updated_node)
804+
if i == len(withitems) - 1:
805+
# preserve trailing comma, or remove comma if there was none
806+
withitems[-2] = withitems[-2].with_changes(
807+
comma=withitems[-1].comma
808+
)
809+
withitems.pop(i)
810+
811+
# ASYNC912
812+
elif cm.call is not None:
813+
assert cm.line is not None
814+
assert cm.column is not None
815+
s = ArtificialStatement("withitem", cm.line, cm.column)
816+
if s in self.uncheckpointed_statements:
817+
self.uncheckpointed_statements.remove(s)
818+
self.error(cm.call.node, error_code="ASYNC912")
761819

762820
# if exception-suppressing, restore all uncheckpointed statements from
763821
# before the `with`.
@@ -767,7 +825,8 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With):
767825
self.uncheckpointed_statements.update(prev_checkpoints)
768826

769827
self._checkpoint_with(original_node, entry=False)
770-
return updated_node
828+
829+
return updated_node.with_changes(items=withitems)
771830

772831
# error if no checkpoint since earlier yield or function entry
773832
def leave_Yield(

0 commit comments

Comments
 (0)