diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 03ebc5058cee..1f21a84e7162 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -143,6 +143,7 @@ validate_instance, ) from mypy.typeops import ( + bind_self, callable_type, custom_special_method, erase_to_union_or_bound, @@ -1001,6 +1002,7 @@ def typeddict_callable_from_context( self.named_type("builtins.type"), variables=variables, is_bound=True, + original_self_type=AnyType(TypeOfAny.implementation_artifact), ) def check_typeddict_call_with_kwargs( @@ -2911,10 +2913,61 @@ def infer_overload_return_type( args_contain_any = any(map(has_any_type, arg_types)) type_maps: list[dict[Expression, Type]] = [] + # If we have a selftype overload, it should contribute to `any_causes_overload_ambiguity` + # check. Pretend that we're checking `Foo.func(instance, ...)` instead of + # `instance.func(...)`. + + def is_trivial_self(t: CallableType) -> bool: + if isinstance(t.definition, FuncDef): + return t.definition.is_trivial_self + if isinstance(t.definition, Decorator): + return t.definition.func.is_trivial_self + return False + + prepend_self = ( + object_type is not None + and has_any_type(object_type) + and any( + typ.is_bound and typ.original_self_type is not None and not is_trivial_self(typ) + for typ in plausible_targets + ) + ) + if prepend_self: + assert object_type is not None + + args = [TempNode(object_type)] + args + arg_types = [object_type] + arg_types + arg_kinds = [ARG_POS] + arg_kinds + arg_names = [None, *arg_names] if arg_names is not None else None + + def maybe_bind_self(t: Type) -> Type: + if prepend_self and isinstance(t, ProperType) and isinstance(t, FunctionLike): + return bind_self(t, object_type) + return t + for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map as m: + with self.msg.filter_errors() as w, self.chk.local_type_map as m: + if prepend_self: + param = typ.original_self_type + assert param is not None, "Overload bound only partially?" + typ = typ.copy_modified( + arg_types=[param] + typ.arg_types, + arg_kinds=[ARG_POS] + typ.arg_kinds, + arg_names=[None, *typ.arg_names], + is_bound=False, + original_self_type=None, + ) + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=None, + ) + else: ret_type, infer_type = self.check_call( callee=typ, args=args, @@ -2928,9 +2981,9 @@ def infer_overload_return_type( if is_match: # Return early if possible; otherwise record info, so we can # check for ambiguity due to 'Any' below. - if not args_contain_any: + if not args_contain_any and not prepend_self: self.chk.store_types(m) - return ret_type, infer_type + return ret_type, maybe_bind_self(infer_type) p_infer_type = get_proper_type(infer_type) if isinstance(p_infer_type, CallableType): # Prefer inferred types if possible, this will avoid false triggers for @@ -2949,10 +3002,10 @@ def infer_overload_return_type( # We try returning a precise type if we can. If not, we give up and just return 'Any'. if all_same_types(return_types): self.chk.store_types(type_maps[0]) - return return_types[0], inferred_types[0] + return return_types[0], maybe_bind_self(inferred_types[0]) elif all_same_types([erase_type(typ) for typ in return_types]): self.chk.store_types(type_maps[0]) - return erase_type(return_types[0]), erase_type(inferred_types[0]) + return erase_type(return_types[0]), maybe_bind_self(erase_type(inferred_types[0])) else: return self.check_call( callee=AnyType(TypeOfAny.special_form), @@ -2966,7 +3019,7 @@ def infer_overload_return_type( else: # Success! No ambiguity; return the first match. self.chk.store_types(type_maps[0]) - return return_types[0], inferred_types[0] + return return_types[0], maybe_bind_self(inferred_types[0]) def overload_erased_call_targets( self, @@ -5017,6 +5070,7 @@ def apply_type_arguments_to_callable( name="tuple", definition=tp.definition, is_bound=tp.is_bound, + original_self_type=tp.original_self_type, ) self.msg.incompatible_type_application( min_arg_count, len(type_vars), len(args), ctx diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 719b48b14e07..65b547e5f235 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1519,6 +1519,7 @@ def bind_self_fast(method: F, original_type: Type | None = None) -> F: arg_kinds=method.arg_kinds[1:], arg_names=method.arg_names[1:], is_bound=True, + original_self_type=method.arg_types[0], ) diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 500d8fd5ae08..3ff030c06df9 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -79,14 +79,14 @@ def visit_deleted_type(self, t: DeletedType) -> ProperType: return t def visit_instance(self, t: Instance) -> ProperType: - args = erased_vars(t.type.defn.type_vars, TypeOfAny.special_form) + args = erased_vars(t.type.defn.type_vars, TypeOfAny.explicit) return Instance(t.type, args, t.line) def visit_type_var(self, t: TypeVarType) -> ProperType: - return AnyType(TypeOfAny.special_form) + return AnyType(TypeOfAny.explicit) def visit_param_spec(self, t: ParamSpecType) -> ProperType: - return AnyType(TypeOfAny.special_form) + return AnyType(TypeOfAny.explicit) def visit_parameters(self, t: Parameters) -> ProperType: raise RuntimeError("Parameters should have been bound to a class") @@ -94,14 +94,14 @@ def visit_parameters(self, t: Parameters) -> ProperType: def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: # Likely, we can never get here because of aggressive erasure of types that # can contain this, but better still return a valid replacement. - return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)]) + return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.explicit)]) def visit_unpack_type(self, t: UnpackType) -> ProperType: - return AnyType(TypeOfAny.special_form) + return AnyType(TypeOfAny.explicit) def visit_callable_type(self, t: CallableType) -> ProperType: # We must preserve the fallback type for overload resolution to work. - any_type = AnyType(TypeOfAny.special_form) + any_type = AnyType(TypeOfAny.explicit) return CallableType( arg_types=[any_type, any_type], arg_kinds=[ARG_STAR, ARG_STAR2], diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 891ea4d89a80..8c0482bb63cc 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -381,6 +381,11 @@ def visit_callable_type(self, t: CallableType) -> CallableType: arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds, arg_names=t.arg_names[:-2] + repl.arg_names, ret_type=t.ret_type.accept(self), + original_self_type=( + t.original_self_type.accept(self) + if t.original_self_type is not None + else None + ), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), type_is=(t.type_is.accept(self) if t.type_is is not None else None), imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds), @@ -420,6 +425,9 @@ def visit_callable_type(self, t: CallableType) -> CallableType: expanded = t.copy_modified( arg_types=arg_types, ret_type=t.ret_type.accept(self), + original_self_type=( + t.original_self_type.accept(self) if t.original_self_type is not None else None + ), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), type_is=(t.type_is.accept(self) if t.type_is is not None else None), ) diff --git a/mypy/exportjson.py b/mypy/exportjson.py index 09945f0ef28f..896bfd211c32 100644 --- a/mypy/exportjson.py +++ b/mypy/exportjson.py @@ -467,6 +467,9 @@ def convert_callable_type(self: CallableType) -> Json: "is_ellipsis_args": self.is_ellipsis_args, "implicit": self.implicit, "is_bound": self.is_bound, + "original_self_type": ( + convert_type(self.original_self_type) if self.original_self_type is not None else None + ), "type_guard": convert_type(self.type_guard) if self.type_guard is not None else None, "type_is": convert_type(self.type_is) if self.type_is is not None else None, "from_concatenate": self.from_concatenate, diff --git a/mypy/fixup.py b/mypy/fixup.py index d0205f64b720..d1e943604687 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -283,6 +283,8 @@ def visit_callable_type(self, ct: CallableType) -> None: ct.ret_type.accept(self) for v in ct.variables: v.accept(self) + if ct.original_self_type is not None: + ct.original_self_type.accept(self) if ct.type_guard is not None: ct.type_guard.accept(self) if ct.type_is is not None: diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 15d472b64886..75abccdc5e6f 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -471,6 +471,7 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem: typ.is_ellipsis_args, snapshot_types(typ.variables), typ.is_bound, + snapshot_optional_type(typ.original_self_type), ) def normalize_callable_variables(self, typ: CallableType) -> CallableType: diff --git a/mypy/typeops.py b/mypy/typeops.py index d2f9f4da44e4..63825b9a828b 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -205,6 +205,7 @@ def type_object_type(info: TypeInfo, named_type: Callable[[str], Instance]) -> P arg_names=["_args", "_kwds"], ret_type=any_type, is_bound=True, + original_self_type=any_type, fallback=named_type("builtins.function"), ) result: FunctionLike = class_callable(sig, info, fallback, None, is_new=False) @@ -490,6 +491,7 @@ class B(A): pass arg_names=func.arg_names[1:], variables=variables, is_bound=True, + original_self_type=func.arg_types[0], ) return cast(F, res) diff --git a/mypy/types.py b/mypy/types.py index 7a8343097204..ae98f24ce3d5 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2169,6 +2169,7 @@ class CallableType(FunctionLike): "from_type_type", # Was this callable generated by analyzing Type[...] # instantiation? "is_bound", # Is this a bound method? + "original_self_type", # If bound, what was the type of `self` before? "type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case). "type_is", # T, if -> TypeIs[T] (ret_type is bool in this case). "from_concatenate", # whether this callable is from a concatenate object @@ -2195,6 +2196,7 @@ def __init__( special_sig: str | None = None, from_type_type: bool = False, is_bound: bool = False, + original_self_type: Type | None = None, type_guard: Type | None = None, type_is: Type | None = None, from_concatenate: bool = False, @@ -2232,6 +2234,7 @@ def __init__( self.from_concatenate = from_concatenate self.imprecise_arg_kinds = imprecise_arg_kinds self.is_bound = is_bound + self.original_self_type = original_self_type self.type_guard = type_guard self.type_is = type_is self.unpack_kwargs = unpack_kwargs @@ -2253,6 +2256,7 @@ def copy_modified( special_sig: Bogus[str | None] = _dummy, from_type_type: Bogus[bool] = _dummy, is_bound: Bogus[bool] = _dummy, + original_self_type: Bogus[Type | None] = _dummy, type_guard: Bogus[Type | None] = _dummy, type_is: Bogus[Type | None] = _dummy, from_concatenate: Bogus[bool] = _dummy, @@ -2277,6 +2281,9 @@ def copy_modified( special_sig=special_sig if special_sig is not _dummy else self.special_sig, from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type, is_bound=is_bound if is_bound is not _dummy else self.is_bound, + original_self_type=( + original_self_type if original_self_type is not _dummy else self.original_self_type + ), type_guard=type_guard if type_guard is not _dummy else self.type_guard, type_is=type_is if type_is is not _dummy else self.type_is, from_concatenate=( @@ -2598,6 +2605,11 @@ def serialize(self) -> JsonDict: "is_ellipsis_args": self.is_ellipsis_args, "implicit": self.implicit, "is_bound": self.is_bound, + "original_self_type": ( + self.original_self_type.serialize() + if self.original_self_type is not None + else None + ), "type_guard": self.type_guard.serialize() if self.type_guard is not None else None, "type_is": (self.type_is.serialize() if self.type_is is not None else None), "from_concatenate": self.from_concatenate, @@ -2620,6 +2632,11 @@ def deserialize(cls, data: JsonDict) -> CallableType: is_ellipsis_args=data["is_ellipsis_args"], implicit=data["implicit"], is_bound=data["is_bound"], + original_self_type=( + deserialize_type(data["original_self_type"]) + if data["original_self_type"] is not None + else None + ), type_guard=( deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None ), @@ -2641,6 +2658,7 @@ def write(self, data: Buffer) -> None: write_bool(data, self.is_ellipsis_args) write_bool(data, self.implicit) write_bool(data, self.is_bound) + write_type_opt(data, self.original_self_type) write_type_opt(data, self.type_guard) write_type_opt(data, self.type_is) write_bool(data, self.from_concatenate) @@ -2663,6 +2681,7 @@ def read(cls, data: Buffer) -> CallableType: is_ellipsis_args=read_bool(data), implicit=read_bool(data), is_bound=read_bool(data), + original_self_type=read_type_opt(data), type_guard=read_type_opt(data), type_is=read_type_opt(data), from_concatenate=read_bool(data), diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index abd0f6bf3bdf..de175e9db157 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -87,9 +87,11 @@ def visit_callable_type(self, t: CallableType, /) -> None: t.ret_type.accept(self) t.fallback.accept(self) + if t.original_self_type is not None: + t.original_self_type.accept(self) + if t.type_guard is not None: t.type_guard.accept(self) - if t.type_is is not None: t.type_is.accept(self) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index be55a182b87b..4d485bc99e8d 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6852,3 +6852,20 @@ if isinstance(headers, dict): reveal_type(headers) # N: Revealed type is "Union[__main__.Headers, typing.Iterable[tuple[builtins.bytes, builtins.bytes]]]" [builtins fixtures/isinstancelist.pyi] + +[case testSelfOverloadWithAnySelf] +from typing import Any, Generic, TypeVar, overload + +T = TypeVar("T") + +class A(Generic[T]): + @overload + def run(self: A[int]) -> int: ... + @overload + def run(self: A[str]) -> str: ... + def run(self: "A[int] | A[str]") -> "int | str": ... + +foo: A[Any] +reveal_type(foo.run()) # N: Revealed type is "Any" +reveal_type(A.run(foo)) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/exportjson.test b/test-data/unit/exportjson.test index 14295281a48f..3f9759adfaa0 100644 --- a/test-data/unit/exportjson.test +++ b/test-data/unit/exportjson.test @@ -170,6 +170,7 @@ def foo(a: int) -> None: ... "is_ellipsis_args": false, "implicit": false, "is_bound": false, + "original_self_type": null, "type_guard": null, "type_is": null, "from_concatenate": false,