|
10 | 10 | from types import TracebackType |
11 | 11 | from typing import Any |
12 | 12 | from typing import Callable |
| 13 | +from typing import cast |
| 14 | +from typing import Generic |
13 | 15 | from typing import Optional |
14 | 16 | from typing import overload |
15 | 17 | from typing import Pattern |
16 | 18 | from typing import Tuple |
| 19 | +from typing import TypeVar |
17 | 20 | from typing import Union |
18 | 21 |
|
19 | 22 | from more_itertools.more import always_iterable |
@@ -537,33 +540,35 @@ def _is_numpy_array(obj): |
537 | 540 |
|
538 | 541 | # builtin pytest.raises helper |
539 | 542 |
|
| 543 | +_E = TypeVar("_E", bound=BaseException) |
| 544 | + |
540 | 545 |
|
541 | 546 | @overload |
542 | 547 | def raises( |
543 | | - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 548 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
544 | 549 | *, |
545 | 550 | match: Optional[Union[str, Pattern]] = ... |
546 | | -) -> "RaisesContext": |
| 551 | +) -> "RaisesContext[_E]": |
547 | 552 | ... # pragma: no cover |
548 | 553 |
|
549 | 554 |
|
550 | 555 | @overload |
551 | 556 | def raises( |
552 | | - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 557 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
553 | 558 | func: Callable, |
554 | 559 | *args: Any, |
555 | 560 | match: Optional[str] = ..., |
556 | 561 | **kwargs: Any |
557 | | -) -> Optional[_pytest._code.ExceptionInfo]: |
| 562 | +) -> Optional[_pytest._code.ExceptionInfo[_E]]: |
558 | 563 | ... # pragma: no cover |
559 | 564 |
|
560 | 565 |
|
561 | 566 | def raises( |
562 | | - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 567 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
563 | 568 | *args: Any, |
564 | 569 | match: Optional[Union[str, Pattern]] = None, |
565 | 570 | **kwargs: Any |
566 | | -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: |
| 571 | +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: |
567 | 572 | r""" |
568 | 573 | Assert that a code block/function call raises ``expected_exception`` |
569 | 574 | or raise a failure exception otherwise. |
@@ -703,28 +708,30 @@ def raises( |
703 | 708 | try: |
704 | 709 | func(*args[1:], **kwargs) |
705 | 710 | except expected_exception: |
706 | | - return _pytest._code.ExceptionInfo.from_current() |
| 711 | + # Cast to narrow the type to expected_exception (_E). |
| 712 | + return cast( |
| 713 | + _pytest._code.ExceptionInfo[_E], |
| 714 | + _pytest._code.ExceptionInfo.from_current(), |
| 715 | + ) |
707 | 716 | fail(message) |
708 | 717 |
|
709 | 718 |
|
710 | 719 | raises.Exception = fail.Exception # type: ignore |
711 | 720 |
|
712 | 721 |
|
713 | | -class RaisesContext: |
| 722 | +class RaisesContext(Generic[_E]): |
714 | 723 | def __init__( |
715 | 724 | self, |
716 | | - expected_exception: Union[ |
717 | | - "Type[BaseException]", Tuple["Type[BaseException]", ...] |
718 | | - ], |
| 725 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
719 | 726 | message: str, |
720 | 727 | match_expr: Optional[Union[str, Pattern]] = None, |
721 | 728 | ) -> None: |
722 | 729 | self.expected_exception = expected_exception |
723 | 730 | self.message = message |
724 | 731 | self.match_expr = match_expr |
725 | | - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] |
| 732 | + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] |
726 | 733 |
|
727 | | - def __enter__(self) -> _pytest._code.ExceptionInfo: |
| 734 | + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: |
728 | 735 | self.excinfo = _pytest._code.ExceptionInfo.for_later() |
729 | 736 | return self.excinfo |
730 | 737 |
|
|
0 commit comments