Skip to content

Commit 9c26271

Browse files
randolf-scholzsterliakovpre-commit-ci[bot]ilevkivskyi
authored
[match-case] fix matching against typing.Callable and Protocol types. (#19471)
- Fixes #14014 - Partially addresses #19470 Added extra logic in `checker.py:conditional_types` function to deal with structural types such as `typing.Callable` or protocols. ## new tests - `testMatchClassPatternCallable`: tests `case Callable() as fn` usage - `testMatchClassPatternProtocol`: tests `case Proto()` usage, where `Proto` is a Protocol - `testMatchClassPatternCallbackProtocol`: tests `case Proto()` usage, where `Proto` is a Callback-Protocol - `testGenericAliasIsinstanceUnreachable`: derived from a mypy-primer failure in mesonbuild. Tests that `isinstance(x, Proto)` can produce unreachable error. - `testGenericAliasRedundantExprCompoundIfExpr`: derived from a CI failure of `python runtest.py self` of an earlier version of this PR. ## modified tests - `testOverloadOnProtocol` added annotations to overload implementation, which wasn't getting checked. Added missing return. Fixed return type in second branch. --------- Co-authored-by: Stanislav Terliakov <50529348+sterliakov@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Levkivskyi <levkivskyi@gmail.com>
1 parent 72b0fca commit 9c26271

File tree

5 files changed

+225
-11
lines changed

5 files changed

+225
-11
lines changed

mypy/checker.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8174,11 +8174,15 @@ def conditional_types(
81748174
) -> tuple[Type | None, Type | None]:
81758175
"""Takes in the current type and a proposed type of an expression.
81768176
8177-
Returns a 2-tuple: The first element is the proposed type, if the expression
8178-
can be the proposed type. The second element is the type it would hold
8179-
if it was not the proposed type, if any. UninhabitedType means unreachable.
8180-
None means no new information can be inferred. If default is set it is returned
8181-
instead."""
8177+
Returns a 2-tuple:
8178+
The first element is the proposed type, if the expression can be the proposed type.
8179+
(or default, if default is set and the expression is a subtype of the proposed type).
8180+
The second element is the type it would hold if it was not the proposed type, if any.
8181+
(or default, if default is set and the expression is not a subtype of the proposed type).
8182+
8183+
UninhabitedType means unreachable.
8184+
None means no new information can be inferred.
8185+
"""
81828186
if proposed_type_ranges:
81838187
if len(proposed_type_ranges) == 1:
81848188
target = proposed_type_ranges[0].item
@@ -8190,14 +8194,25 @@ def conditional_types(
81908194
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
81918195
proposed_items = [type_range.item for type_range in proposed_type_ranges]
81928196
proposed_type = make_simplified_union(proposed_items)
8193-
if isinstance(proposed_type, AnyType):
8197+
if isinstance(get_proper_type(current_type), AnyType):
8198+
return proposed_type, current_type
8199+
elif isinstance(proposed_type, AnyType):
81948200
# We don't really know much about the proposed type, so we shouldn't
81958201
# attempt to narrow anything. Instead, we broaden the expr to Any to
81968202
# avoid false positives
81978203
return proposed_type, default
8198-
elif not any(
8199-
type_range.is_upper_bound for type_range in proposed_type_ranges
8200-
) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
8204+
elif not any(type_range.is_upper_bound for type_range in proposed_type_ranges) and (
8205+
# concrete subtypes
8206+
is_proper_subtype(current_type, proposed_type, ignore_promotions=True)
8207+
# structural subtypes
8208+
or (
8209+
(
8210+
isinstance(proposed_type, CallableType)
8211+
or (isinstance(proposed_type, Instance) and proposed_type.type.is_protocol)
8212+
)
8213+
and is_subtype(current_type, proposed_type, ignore_promotions=True)
8214+
)
8215+
):
82018216
# Expression is always of one of the types in proposed_type_ranges
82028217
return default, UninhabitedType()
82038218
elif not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):

mypy/checkpattern.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
UninhabitedType,
5252
UnionType,
5353
UnpackType,
54+
callable_with_ellipsis,
5455
find_unpack_in_list,
5556
get_proper_type,
5657
split_with_prefix_and_suffix,
@@ -546,6 +547,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
546547
return self.early_non_match()
547548
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
548549
typ = fill_typevars_with_any(p_typ.type_object())
550+
elif (
551+
isinstance(type_info, Var)
552+
and type_info.type is not None
553+
and type_info.fullname == "typing.Callable"
554+
):
555+
# Create a `Callable[..., Any]`
556+
fallback = self.chk.named_type("builtins.function")
557+
any_type = AnyType(TypeOfAny.unannotated)
558+
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
549559
elif not isinstance(p_typ, AnyType):
550560
self.msg.fail(
551561
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(

test-data/unit/check-generic-alias.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,38 @@ t23: collections.abc.ValuesView[str]
149149
# reveal_type(t23) # Nx Revealed type is "collections.abc.ValuesView[builtins.str]"
150150
[builtins fixtures/tuple.pyi]
151151

152+
[case testGenericAliasIsinstanceUnreachable]
153+
# flags: --warn-unreachable --python-version 3.10
154+
from collections.abc import Iterable
155+
156+
class A: ...
157+
158+
def test(dependencies: list[A] | None) -> None:
159+
if dependencies is None:
160+
dependencies = []
161+
elif not isinstance(dependencies, Iterable):
162+
dependencies = [dependencies] # E: Statement is unreachable
163+
164+
[builtins fixtures/isinstancelist.pyi]
165+
[typing fixtures/typing-full.pyi]
166+
167+
[case testGenericAliasRedundantExprCompoundIfExpr]
168+
# flags: --warn-unreachable --enable-error-code=redundant-expr --python-version 3.10
169+
170+
from typing import Any, reveal_type
171+
from collections.abc import Iterable
172+
173+
def test_example(x: Iterable[Any]) -> None:
174+
if isinstance(x, Iterable) and not isinstance(x, str): # E: Left operand of "and" is always true
175+
reveal_type(x) # N: Revealed type is "typing.Iterable[Any]"
176+
177+
def test_counterexample(x: Any) -> None:
178+
if isinstance(x, Iterable) and not isinstance(x, str):
179+
reveal_type(x) # N: Revealed type is "typing.Iterable[Any]"
180+
181+
[builtins fixtures/isinstancelist.pyi]
182+
[typing fixtures/typing-full.pyi]
183+
152184

153185
[case testGenericBuiltinTupleTyping]
154186
from typing import Tuple

test-data/unit/check-protocols.test

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,11 +1506,12 @@ class C: pass
15061506
def f(x: P1) -> int: ...
15071507
@overload
15081508
def f(x: P2) -> str: ...
1509-
def f(x):
1509+
def f(x: object) -> object:
15101510
if isinstance(x, P1):
15111511
return P1.attr1
15121512
if isinstance(x, P2): # E: Only @runtime_checkable protocols can be used with instance and class checks
1513-
return P1.attr2
1513+
return P2.attr2
1514+
return None
15141515

15151516
reveal_type(f(C1())) # N: Revealed type is "builtins.int"
15161517
reveal_type(f(C2())) # N: Revealed type is "builtins.str"

test-data/unit/check-python310.test

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,30 @@ match m:
1111
-- Literal Pattern --
1212

1313
[case testMatchLiteralPatternNarrows]
14+
# flags: --warn-unreachable
1415
m: object
1516

1617
match m:
1718
case 1:
1819
reveal_type(m) # N: Revealed type is "Literal[1]"
20+
case 2:
21+
reveal_type(m) # N: Revealed type is "Literal[2]"
22+
case other:
23+
reveal_type(other) # N: Revealed type is "builtins.object"
24+
25+
[case testMatchLiteralPatternNarrows2]
26+
# flags: --warn-unreachable
27+
from typing import Any
28+
29+
m: Any
30+
31+
match m:
32+
case 1:
33+
reveal_type(m) # N: Revealed type is "Literal[1]"
34+
case 2:
35+
reveal_type(m) # N: Revealed type is "Literal[2]"
36+
case other:
37+
reveal_type(other) # N: Revealed type is "Any"
1938

2039
[case testMatchLiteralPatternAlreadyNarrower-skip]
2140
m: bool
@@ -1079,6 +1098,143 @@ match m:
10791098
case Foo():
10801099
pass
10811100

1101+
[case testMatchClassPatternCallable]
1102+
# flags: --warn-unreachable
1103+
from typing import Callable, Any
1104+
1105+
class FnImpl:
1106+
def __call__(self, x: object, /) -> int: ...
1107+
1108+
def test_any(x: Any) -> None:
1109+
match x:
1110+
case Callable() as fn:
1111+
reveal_type(fn) # N: Revealed type is "def (*Any, **Any) -> Any"
1112+
case other:
1113+
reveal_type(other) # N: Revealed type is "Any"
1114+
1115+
def test_object(x: object) -> None:
1116+
match x:
1117+
case Callable() as fn:
1118+
reveal_type(fn) # N: Revealed type is "def (*Any, **Any) -> Any"
1119+
case other:
1120+
reveal_type(other) # N: Revealed type is "builtins.object"
1121+
1122+
def test_impl(x: FnImpl) -> None:
1123+
match x:
1124+
case Callable() as fn:
1125+
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
1126+
case other:
1127+
reveal_type(other) # E: Statement is unreachable
1128+
1129+
def test_callable(x: Callable[[object], int]) -> None:
1130+
match x:
1131+
case Callable() as fn:
1132+
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
1133+
case other:
1134+
reveal_type(other) # E: Statement is unreachable
1135+
1136+
[case testMatchClassPatternCallbackProtocol]
1137+
# flags: --warn-unreachable
1138+
from typing import Any, Callable
1139+
from typing_extensions import Protocol, runtime_checkable
1140+
1141+
@runtime_checkable
1142+
class FnProto(Protocol):
1143+
def __call__(self, x: int, /) -> object: ...
1144+
1145+
class FnImpl:
1146+
def __call__(self, x: object, /) -> int: ...
1147+
1148+
def test_any(x: Any) -> None:
1149+
match x:
1150+
case FnProto() as fn:
1151+
reveal_type(fn) # N: Revealed type is "__main__.FnProto"
1152+
case other:
1153+
reveal_type(other) # N: Revealed type is "Any"
1154+
1155+
def test_object(x: object) -> None:
1156+
match x:
1157+
case FnProto() as fn:
1158+
reveal_type(fn) # N: Revealed type is "__main__.FnProto"
1159+
case other:
1160+
reveal_type(other) # N: Revealed type is "builtins.object"
1161+
1162+
def test_impl(x: FnImpl) -> None:
1163+
match x:
1164+
case FnProto() as fn:
1165+
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
1166+
case other:
1167+
reveal_type(other) # E: Statement is unreachable
1168+
1169+
def test_callable(x: Callable[[object], int]) -> None:
1170+
match x:
1171+
case FnProto() as fn:
1172+
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
1173+
case other:
1174+
reveal_type(other) # E: Statement is unreachable
1175+
1176+
[builtins fixtures/dict.pyi]
1177+
1178+
[case testMatchClassPatternAnyCallableProtocol]
1179+
# flags: --warn-unreachable
1180+
from typing import Any, Callable
1181+
from typing_extensions import Protocol, runtime_checkable
1182+
1183+
@runtime_checkable
1184+
class AnyCallable(Protocol):
1185+
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
1186+
1187+
class FnImpl:
1188+
def __call__(self, x: object, /) -> int: ...
1189+
1190+
def test_object(x: object) -> None:
1191+
match x:
1192+
case AnyCallable() as fn:
1193+
reveal_type(fn) # N: Revealed type is "__main__.AnyCallable"
1194+
case other:
1195+
reveal_type(other) # N: Revealed type is "builtins.object"
1196+
1197+
def test_impl(x: FnImpl) -> None:
1198+
match x:
1199+
case AnyCallable() as fn:
1200+
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
1201+
case other:
1202+
reveal_type(other) # E: Statement is unreachable
1203+
1204+
def test_callable(x: Callable[[object], int]) -> None:
1205+
match x:
1206+
case AnyCallable() as fn:
1207+
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
1208+
case other:
1209+
reveal_type(other) # E: Statement is unreachable
1210+
1211+
[builtins fixtures/dict.pyi]
1212+
1213+
1214+
[case testMatchClassPatternProtocol]
1215+
from typing import Any
1216+
from typing_extensions import Protocol, runtime_checkable
1217+
1218+
@runtime_checkable
1219+
class Proto(Protocol):
1220+
def foo(self, x: int, /) -> object: ...
1221+
1222+
class Impl:
1223+
def foo(self, x: object, /) -> int: ...
1224+
1225+
def test_object(x: object) -> None:
1226+
match x:
1227+
case Proto() as y:
1228+
reveal_type(y) # N: Revealed type is "__main__.Proto"
1229+
1230+
def test_impl(x: Impl) -> None:
1231+
match x:
1232+
case Proto() as y:
1233+
reveal_type(y) # N: Revealed type is "__main__.Impl"
1234+
1235+
[builtins fixtures/dict.pyi]
1236+
1237+
10821238
[case testMatchClassPatternNestedGenerics]
10831239
# From cpython test_patma.py
10841240
x = [[{0: 0}]]

0 commit comments

Comments
 (0)