Skip to content

Commit 3b0e463

Browse files
committed
assertion: add a Protocol for rewrite hook
Mostly to fix the TODO, doesn't have much semantic meaning. Also fixes a related typo in import inside `TYPE_CHECKING` block.
1 parent 52dccfd commit 3b0e463

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

src/_pytest/assertion/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Generator
77
import sys
88
from typing import Any
9+
from typing import Protocol
910
from typing import TYPE_CHECKING
1011

1112
from _pytest.assertion import rewrite
@@ -82,15 +83,18 @@ def register_assert_rewrite(*names: str) -> None:
8283
if not isinstance(name, str):
8384
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
8485
raise TypeError(msg.format(repr(names)))
86+
rewrite_hook: RewriteHook
8587
for hook in sys.meta_path:
8688
if isinstance(hook, rewrite.AssertionRewritingHook):
87-
importhook = hook
89+
rewrite_hook = hook
8890
break
8991
else:
90-
# TODO(typing): Add a protocol for mark_rewrite() and use it
91-
# for importhook and for PytestPluginManager.rewrite_hook.
92-
importhook = DummyRewriteHook() # type: ignore
93-
importhook.mark_rewrite(*names)
92+
rewrite_hook = DummyRewriteHook()
93+
rewrite_hook.mark_rewrite(*names)
94+
95+
96+
class RewriteHook(Protocol):
97+
def mark_rewrite(self, *names: str) -> None: ...
9498

9599

96100
class DummyRewriteHook:

src/_pytest/config/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171

7272
if TYPE_CHECKING:
73-
from _pytest.assertions.rewrite import AssertionRewritingHook
73+
from _pytest.assertion.rewrite import AssertionRewritingHook
7474
from _pytest.cacheprovider import Cache
7575
from _pytest.terminal import TerminalReporter
7676

@@ -397,7 +397,8 @@ class PytestPluginManager(PluginManager):
397397
"""
398398

399399
def __init__(self) -> None:
400-
import _pytest.assertion
400+
from _pytest.assertion import DummyRewriteHook
401+
from _pytest.assertion import RewriteHook
401402

402403
super().__init__("pytest")
403404

@@ -443,7 +444,7 @@ def __init__(self) -> None:
443444
self.enable_tracing()
444445

445446
# Config._consider_importhook will set a real object if required.
446-
self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
447+
self.rewrite_hook: RewriteHook = DummyRewriteHook()
447448
# Used to know when we are importing conftests after the pytest_configure stage.
448449
self._configured = False
449450

0 commit comments

Comments
 (0)