From afb012ddaf008bac9a36880be683ded82416bbbd Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Sat, 16 Aug 2025 22:58:27 +0200 Subject: [PATCH] Add type annotations to `transform_matching_parts.py` --- manim/animation/transform.py | 9 ++++++++- manim/animation/transform_matching_parts.py | 20 ++++++++++++-------- mypy.ini | 3 --- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/manim/animation/transform.py b/manim/animation/transform.py index 5bf5b76936..4688e1c1c5 100644 --- a/manim/animation/transform.py +++ b/manim/animation/transform.py @@ -834,7 +834,14 @@ def construct(self): """ - def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwargs): + def __init__( + self, + mobject: Mobject, + target_mobject: Mobject, + stretch: bool = True, + dim_to_match: int = 1, + **kwargs: Any, + ): self.to_add_on_completion = target_mobject self.stretch = stretch self.dim_to_match = dim_to_match diff --git a/manim/animation/transform_matching_parts.py b/manim/animation/transform_matching_parts.py index 03305201f1..b8ce3e94cc 100644 --- a/manim/animation/transform_matching_parts.py +++ b/manim/animation/transform_matching_parts.py @@ -4,12 +4,13 @@ __all__ = ["TransformMatchingShapes", "TransformMatchingTex"] -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject +from manim.mobject.text.tex_mobject import SingleStringMathTex from .._config import config from ..constants import RendererType @@ -74,10 +75,10 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, + **kwargs: Any, ): if isinstance(mobject, OpenGLVMobject): - group_type = OpenGLVGroup + group_type: type[OpenGLVGroup | OpenGLGroup | VGroup | Group] = OpenGLVGroup elif isinstance(mobject, OpenGLMobject): group_type = OpenGLGroup elif isinstance(mobject, VMobject): @@ -141,7 +142,7 @@ def __init__( self.to_add = target_mobject def get_shape_map(self, mobject: Mobject) -> dict: - shape_map = {} + shape_map: dict[int | str, VGroup | OpenGLVGroup] = {} for sm in self.get_mobject_parts(mobject): key = self.get_mobject_key(sm) if key not in shape_map: @@ -149,6 +150,7 @@ def get_shape_map(self, mobject: Mobject) -> dict: shape_map[key] = OpenGLVGroup() else: shape_map[key] = VGroup() + # error: Argument 1 to "add" of "OpenGLVGroup" has incompatible type "Mobject"; expected "OpenGLVMobject" [arg-type] shape_map[key].add(sm) return shape_map @@ -156,16 +158,17 @@ def clean_up_from_scene(self, scene: Scene) -> None: # Interpolate all animations back to 0 to ensure source mobjects remain unchanged. for anim in self.animations: anim.interpolate(0) + # error: Argument 1 to "remove" of "Scene" has incompatible type "OpenGLMobject"; expected "Mobject" [arg-type] scene.remove(self.mobject) scene.remove(*self.to_remove) scene.add(self.to_add) @staticmethod - def get_mobject_parts(mobject: Mobject): + def get_mobject_parts(mobject: Mobject) -> list[Mobject]: raise NotImplementedError("To be implemented in subclass.") @staticmethod - def get_mobject_key(mobject: Mobject): + def get_mobject_key(mobject: Mobject) -> int | str: raise NotImplementedError("To be implemented in subclass.") @@ -205,7 +208,7 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, + **kwargs: Any, ): super().__init__( mobject, @@ -269,7 +272,7 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, + **kwargs: Any, ): super().__init__( mobject, @@ -294,4 +297,5 @@ def get_mobject_parts(mobject: Mobject) -> list[Mobject]: @staticmethod def get_mobject_key(mobject: Mobject) -> str: + assert isinstance(mobject, SingleStringMathTex) return mobject.tex_string diff --git a/mypy.ini b/mypy.ini index ca5722a015..5472b6491f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -67,9 +67,6 @@ ignore_errors = True [mypy-manim.animation.speedmodifier] ignore_errors = True -[mypy-manim.animation.transform_matching_parts] -ignore_errors = True - [mypy-manim.animation.transform] ignore_errors = True