Skip to content

Commit afb012d

Browse files
committed
Add type annotations to transform_matching_parts.py
1 parent 855ea86 commit afb012d

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

manim/animation/transform.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,14 @@ def construct(self):
834834
835835
"""
836836

837-
def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwargs):
837+
def __init__(
838+
self,
839+
mobject: Mobject,
840+
target_mobject: Mobject,
841+
stretch: bool = True,
842+
dim_to_match: int = 1,
843+
**kwargs: Any,
844+
):
838845
self.to_add_on_completion = target_mobject
839846
self.stretch = stretch
840847
self.dim_to_match = dim_to_match

manim/animation/transform_matching_parts.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
__all__ = ["TransformMatchingShapes", "TransformMatchingTex"]
66

7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Any
88

99
import numpy as np
1010

1111
from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject
1212
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject
13+
from manim.mobject.text.tex_mobject import SingleStringMathTex
1314

1415
from .._config import config
1516
from ..constants import RendererType
@@ -74,10 +75,10 @@ def __init__(
7475
transform_mismatches: bool = False,
7576
fade_transform_mismatches: bool = False,
7677
key_map: dict | None = None,
77-
**kwargs,
78+
**kwargs: Any,
7879
):
7980
if isinstance(mobject, OpenGLVMobject):
80-
group_type = OpenGLVGroup
81+
group_type: type[OpenGLVGroup | OpenGLGroup | VGroup | Group] = OpenGLVGroup
8182
elif isinstance(mobject, OpenGLMobject):
8283
group_type = OpenGLGroup
8384
elif isinstance(mobject, VMobject):
@@ -141,31 +142,33 @@ def __init__(
141142
self.to_add = target_mobject
142143

143144
def get_shape_map(self, mobject: Mobject) -> dict:
144-
shape_map = {}
145+
shape_map: dict[int | str, VGroup | OpenGLVGroup] = {}
145146
for sm in self.get_mobject_parts(mobject):
146147
key = self.get_mobject_key(sm)
147148
if key not in shape_map:
148149
if config["renderer"] == RendererType.OPENGL:
149150
shape_map[key] = OpenGLVGroup()
150151
else:
151152
shape_map[key] = VGroup()
153+
# error: Argument 1 to "add" of "OpenGLVGroup" has incompatible type "Mobject"; expected "OpenGLVMobject" [arg-type]
152154
shape_map[key].add(sm)
153155
return shape_map
154156

155157
def clean_up_from_scene(self, scene: Scene) -> None:
156158
# Interpolate all animations back to 0 to ensure source mobjects remain unchanged.
157159
for anim in self.animations:
158160
anim.interpolate(0)
161+
# error: Argument 1 to "remove" of "Scene" has incompatible type "OpenGLMobject"; expected "Mobject" [arg-type]
159162
scene.remove(self.mobject)
160163
scene.remove(*self.to_remove)
161164
scene.add(self.to_add)
162165

163166
@staticmethod
164-
def get_mobject_parts(mobject: Mobject):
167+
def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
165168
raise NotImplementedError("To be implemented in subclass.")
166169

167170
@staticmethod
168-
def get_mobject_key(mobject: Mobject):
171+
def get_mobject_key(mobject: Mobject) -> int | str:
169172
raise NotImplementedError("To be implemented in subclass.")
170173

171174

@@ -205,7 +208,7 @@ def __init__(
205208
transform_mismatches: bool = False,
206209
fade_transform_mismatches: bool = False,
207210
key_map: dict | None = None,
208-
**kwargs,
211+
**kwargs: Any,
209212
):
210213
super().__init__(
211214
mobject,
@@ -269,7 +272,7 @@ def __init__(
269272
transform_mismatches: bool = False,
270273
fade_transform_mismatches: bool = False,
271274
key_map: dict | None = None,
272-
**kwargs,
275+
**kwargs: Any,
273276
):
274277
super().__init__(
275278
mobject,
@@ -294,4 +297,5 @@ def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
294297

295298
@staticmethod
296299
def get_mobject_key(mobject: Mobject) -> str:
300+
assert isinstance(mobject, SingleStringMathTex)
297301
return mobject.tex_string

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ ignore_errors = True
6767
[mypy-manim.animation.speedmodifier]
6868
ignore_errors = True
6969

70-
[mypy-manim.animation.transform_matching_parts]
71-
ignore_errors = True
72-
7370
[mypy-manim.animation.transform]
7471
ignore_errors = True
7572

0 commit comments

Comments
 (0)