diff --git a/mypy/cache.py b/mypy/cache.py index 900815b9f7e7..5037262aa32a 100644 --- a/mypy/cache.py +++ b/mypy/cache.py @@ -236,6 +236,7 @@ def read(cls, data: Buffer, data_file: str) -> CacheMeta | None: # Misc classes. EXTRA_ATTRS: Final[Tag] = 150 DT_SPEC: Final[Tag] = 151 +PLUGIN_FLAGS: Final[Tag] = 152 END_TAG: Final[Tag] = 255 diff --git a/mypy/checker.py b/mypy/checker.py index f3a93d1eeda1..52a43898f32e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -122,6 +122,7 @@ OverloadedFuncDef, OverloadPart, PassStmt, + PluginFlags, PromoteExpr, RaiseStmt, RefExpr, @@ -2205,6 +2206,11 @@ def check_method_override( defn.name != "__replace__" or defn.info.metadata.get("dataclass_tag") is None ) + and not ( + defn.info + and (node := defn.info.get(defn.name)) + and PluginFlags.should_skip_override_checks(node) + ) ) found_method_base_classes: list[TypeInfo] = [] for base in defn.info.mro[1:]: @@ -3547,6 +3553,10 @@ def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) -> and lvalue.kind in (MDEF, None) # None for Vars defined via self and len(lvalue_node.info.bases) > 0 ): + if not ( + sym := lvalue_node.info.names.get(lvalue_node.name) + ) or PluginFlags.should_skip_override_checks(sym): + return for base in lvalue_node.info.mro[1:]: tnode = base.names.get(lvalue_node.name) if tnode is not None: diff --git a/mypy/nodes.py b/mypy/nodes.py index 539995ce9229..37d78b2f732f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -28,6 +28,7 @@ LIST_STR, LITERAL_COMPLEX, LITERAL_NONE, + PLUGIN_FLAGS, Buffer, Tag, read_bool, @@ -4469,6 +4470,57 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_placeholder_node(self) +class PluginFlags: + """Checking customization for plugin-generated nodes. + + This class is part of the public API. It can be used with the + `mypy.plugins.common.add_*_to_class` family of functions. + + Args: + skip_override_checks: Allow this node to be an incompatible override. + A node having this flag set to True will not be required to be + LSP-compatible with the superclasses of its enclosing class. + This is helpful when the plugin generates a precise signature, + overriding a fallback signature defined in the base class. + This flag does not affect checking overrides *of* this node in + further subclasses. + """ + + def __init__(self, *, skip_override_checks: bool = False) -> None: + self.skip_override_checks = skip_override_checks + + @staticmethod + def should_skip_override_checks(node: SymbolTableNode) -> bool: + if node.plugin_flags is None: + return False + return node.plugin_flags.skip_override_checks + + def serialize(self) -> JsonDict: + data: JsonDict = {".class": "PluginFlags"} + if self.skip_override_checks: + data["skip_override_checks"] = True + return data + + @classmethod + def deserialize(cls, data: JsonDict) -> PluginFlags: + flags = PluginFlags() + if data.get("skip_override_checks"): + flags.skip_override_checks = True + return flags + + def write(self, data: Buffer) -> None: + write_tag(data, PLUGIN_FLAGS) + write_bool(data, self.skip_override_checks) + write_tag(data, END_TAG) + + @classmethod + def read(cls, data: Buffer) -> PluginFlags: + flags = PluginFlags() + flags.skip_override_checks = read_bool(data) + assert read_tag(data) == END_TAG + return flags + + class SymbolTableNode: """Description of a name binding in a symbol table. @@ -4537,6 +4589,7 @@ class SymbolTableNode: "cross_ref", "implicit", "plugin_generated", + "plugin_flags", "no_serialize", ) @@ -4549,6 +4602,7 @@ def __init__( module_hidden: bool = False, *, plugin_generated: bool = False, + plugin_flags: PluginFlags | None = None, no_serialize: bool = False, ) -> None: self.kind = kind @@ -4558,6 +4612,7 @@ def __init__( self.module_hidden = module_hidden self.cross_ref: str | None = None self.plugin_generated = plugin_generated + self.plugin_flags = plugin_flags self.no_serialize = no_serialize @property @@ -4611,6 +4666,8 @@ def serialize(self, prefix: str, name: str) -> JsonDict: data["implicit"] = True if self.plugin_generated: data["plugin_generated"] = True + if self.plugin_flags: + data["plugin_flags"] = self.plugin_flags.serialize() if isinstance(self.node, MypyFile): data["cross_ref"] = self.node.fullname else: @@ -4650,6 +4707,8 @@ def deserialize(cls, data: JsonDict) -> SymbolTableNode: stnode.implicit = data["implicit"] if "plugin_generated" in data: stnode.plugin_generated = data["plugin_generated"] + if "plugin_flags" in data: + stnode.plugin_flags = PluginFlags.deserialize(data["plugin_flags"]) return stnode def write(self, data: Buffer, prefix: str, name: str) -> None: @@ -4681,6 +4740,10 @@ def write(self, data: Buffer, prefix: str, name: str) -> None: if cross_ref is None: assert self.node is not None self.node.write(data) + if self.plugin_flags is None: + write_literal(data, None) + else: + self.plugin_flags.write(data) write_tag(data, END_TAG) @classmethod @@ -4696,6 +4759,10 @@ def read(cls, data: Buffer) -> SymbolTableNode: sym.node = read_symbol(data) else: sym.cross_ref = cross_ref + if (tag := read_tag(data)) == PLUGIN_FLAGS: + sym.plugin_flags = PluginFlags.read(data) + else: + assert tag == LITERAL_NONE assert read_tag(data) == END_TAG return sym diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index ed2a91d102f4..3243d5e849cb 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -20,6 +20,7 @@ Node, OverloadedFuncDef, PassStmt, + PluginFlags, RefExpr, SymbolTableNode, TypeInfo, @@ -220,6 +221,7 @@ class MethodSpec(NamedTuple): return_type: Type self_type: Type | None = None tvar_defs: list[TypeVarType] | None = None + flags: PluginFlags | None = None def add_method_to_class( @@ -233,6 +235,7 @@ def add_method_to_class( tvar_def: list[TypeVarType] | TypeVarType | None = None, is_classmethod: bool = False, is_staticmethod: bool = False, + flags: PluginFlags | None = None, ) -> FuncDef | Decorator: """Adds a new method to a class definition.""" _prepare_class_namespace(cls, name) @@ -244,7 +247,13 @@ def add_method_to_class( api, cls.info, name, - MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def), + MethodSpec( + args=args, + return_type=return_type, + self_type=self_type, + tvar_defs=tvar_def, + flags=flags, + ), is_classmethod=is_classmethod, is_staticmethod=is_staticmethod, ) @@ -260,6 +269,7 @@ def add_overloaded_method_to_class( items: list[MethodSpec], is_classmethod: bool = False, is_staticmethod: bool = False, + flags: PluginFlags | None = None, ) -> OverloadedFuncDef: """Adds a new overloaded method to a class definition.""" assert len(items) >= 2, "Overloads must contain at least two cases" @@ -294,8 +304,7 @@ def add_overloaded_method_to_class( overload_def.info = cls.info overload_def.is_class = is_classmethod overload_def.is_static = is_staticmethod - sym = SymbolTableNode(MDEF, overload_def) - sym.plugin_generated = True + sym = SymbolTableNode(MDEF, overload_def, plugin_generated=True, plugin_flags=flags) cls.info.names[name] = sym cls.info.defn.defs.body.append(overload_def) @@ -330,7 +339,7 @@ def _add_method_by_spec( is_classmethod: bool, is_staticmethod: bool, ) -> tuple[FuncDef | Decorator, SymbolTableNode]: - args, return_type, self_type, tvar_defs = spec + args, return_type, self_type, tvar_defs, flags = spec assert not ( is_classmethod is True and is_staticmethod is True @@ -383,8 +392,7 @@ def _add_method_by_spec( sym.plugin_generated = True return dec, sym - sym = SymbolTableNode(MDEF, func) - sym.plugin_generated = True + sym = SymbolTableNode(MDEF, func, plugin_generated=True, plugin_flags=flags) return func, sym @@ -399,6 +407,7 @@ def add_attribute_to_class( fullname: str | None = None, is_classvar: bool = False, overwrite_existing: bool = False, + flags: PluginFlags | None = None, ) -> Var: """ Adds a new attribute to a class definition. @@ -428,7 +437,7 @@ def add_attribute_to_class( node._fullname = info.fullname + "." + name info.names[name] = SymbolTableNode( - MDEF, node, plugin_generated=True, no_serialize=no_serialize + MDEF, node, plugin_generated=True, no_serialize=no_serialize, plugin_flags=flags ) return node diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 0c157510cb34..a7fcf7f86312 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -1087,6 +1087,59 @@ plugins=/test-data/unit/plugins/add_method.py enable_error_code = explicit-override [typing fixtures/typing-override.pyi] +[case testAddMethodPluginExplicitOverrideIgnoreCompat] +# flags: --python-version 3.12 --config-file tmp/mypy.ini --debug-serialize +from typing import TypeVar + +T = TypeVar('T', bound=type) +def inject_foo(t: T) -> T: + return t + +class BaseWithoutFoo: + pass + +@inject_foo +class Child1(BaseWithoutFoo): ... + +class BaseWithSameFoo: + attr: None + def meth_ok(self) -> None: ... + def meth_bad(self) -> None: ... + +@inject_foo +class Child2(BaseWithSameFoo): ... + +class BaseWithOtherFoo: + attr: int + def meth_ok(self) -> int: ... + def meth_bad(self) -> int: ... + +# `attr` is not reported because add_attribute_to_class does not generate a statement (yet). +@inject_foo +class Child3(BaseWithOtherFoo): ... # E: Return type "None" of "meth_bad" incompatible with return type "int" in supertype "BaseWithOtherFoo" + +@inject_foo +class ImmediatelyOverridden: + attr: int + def meth_ok(self) -> int: ... + def meth_bad(self) -> int: ... + +@inject_foo +class Original: + ... +class FurtherOverridden(Original): + attr: int # E: Incompatible types in assignment (expression has type "int", base class "Original" defined the type as "None") + def meth_ok(self) -> int: ... # E: Return type "int" of "meth_ok" incompatible with return type "None" in supertype "Original" \ + # E: Method "meth_ok" is not using @override but is overriding a method in class "__main__.Original" + def meth_bad(self) -> int: ... # E: Return type "int" of "meth_bad" incompatible with return type "None" in supertype "Original" \ + # E: Method "meth_bad" is not using @override but is overriding a method in class "__main__.Original" + +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/add_method_ignore_compat.py +enable_error_code = explicit-override +[typing fixtures/typing-override.pyi] + [case testCustomErrorCodePlugin] # flags: --config-file tmp/mypy.ini --show-error-codes def main() -> int: diff --git a/test-data/unit/plugins/add_method_ignore_compat.py b/test-data/unit/plugins/add_method_ignore_compat.py new file mode 100644 index 000000000000..db94febb3dbc --- /dev/null +++ b/test-data/unit/plugins/add_method_ignore_compat.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Callable + +from mypy.nodes import PluginFlags +from mypy.plugin import ClassDefContext, Plugin +from mypy.plugins.common import add_attribute_to_class, add_method_to_class +from mypy.types import NoneType + + +class AddOverrideMethodPlugin(Plugin): + def get_class_decorator_hook_2(self, fullname: str) -> Callable[[ClassDefContext], bool] | None: + if fullname == "__main__.inject_foo": + return add_extra_methods_hook + return None + + +def add_extra_methods_hook(ctx: ClassDefContext) -> bool: + add_method_to_class( + ctx.api, + ctx.cls, + "meth_ok", + [], + NoneType(), + flags=PluginFlags(skip_override_checks=True) + ) + add_method_to_class( + ctx.api, + ctx.cls, + "meth_bad", + [], + NoneType(), + flags=PluginFlags(skip_override_checks=False) + ) + add_attribute_to_class( + ctx.api, + ctx.cls, + "attr", + NoneType(), + flags=PluginFlags(skip_override_checks=True) + ) + return True + + +def plugin(version: str) -> type[AddOverrideMethodPlugin]: + return AddOverrideMethodPlugin