Skip to content

Commit e7d6983

Browse files
SirJamesClarkMaxwellpre-commit-ci[bot]JasonGrace2282
authored
Ensure that start and end points are stored as float values in Line3D (#4080)
* fixed problem with type conversions in Line3D * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed broken test and Arrow3D * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update manim/mobject/three_d/three_dimensions.py Co-authored-by: Aarush Deshpande <110117391+JasonGrace2282@users.noreply.github.com> * Update manim/mobject/three_d/three_dimensions.py Co-authored-by: Aarush Deshpande <110117391+JasonGrace2282@users.noreply.github.com> * Delete unnecessary file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed typo in variable nam --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aarush Deshpande <110117391+JasonGrace2282@users.noreply.github.com>
1 parent 7ffdf04 commit e7d6983

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

manim/mobject/three_d/three_dimensions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,10 @@ def __init__(
936936
):
937937
self.thickness = thickness
938938
self.resolution = (2, resolution) if isinstance(resolution, int) else resolution
939+
940+
start = np.array(start, dtype=np.float64)
941+
end = np.array(end, dtype=np.float64)
942+
939943
self.set_start_and_end_attrs(start, end, **kwargs)
940944
if color is not None:
941945
self.set_color(color)
@@ -1193,8 +1197,9 @@ def __init__(
11931197
height=height,
11941198
**kwargs,
11951199
)
1196-
self.cone.shift(end)
1197-
self.end_point = VectorizedPoint(end)
1200+
np_end = np.asarray(end, dtype=np.float64)
1201+
self.cone.shift(np_end)
1202+
self.end_point = VectorizedPoint(np_end)
11981203
self.add(self.end_point, self.cone)
11991204
self.set_color(color)
12001205

tests/test_graphical_units/test_threed.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,34 @@ def param_surface(u, v):
164164

165165

166166
def test_get_start_and_end_Arrow3d():
167-
start, end = ORIGIN, np.array([2, 1, 0])
167+
start, end = ORIGIN, np.array([2, 1, 0], dtype=np.float64)
168168
arrow = Arrow3D(start, end)
169169
assert np.allclose(arrow.get_start(), start, atol=0.01), (
170170
"start points of Arrow3D do not match"
171171
)
172172
assert np.allclose(arrow.get_end(), end, atol=0.01), (
173173
"end points of Arrow3D do not match"
174174
)
175+
176+
177+
def test_type_conversion_in_Line3D():
178+
start, end = [0, 0, 0], [1, 1, 1]
179+
line = Line3D(start, end)
180+
type_table = [type(item) for item in [*line.get_start(), *line.get_end()]]
181+
bool_table = [t == np.float64 for t in type_table]
182+
assert all(bool_table), "Types of start and end points are not np.float64"
183+
184+
185+
def test_type_conversion_in_Arrow3D():
186+
start, end = [0, 0, 0], [1, 1, 1]
187+
arrow = Arrow3D(start, end)
188+
type_table = [type(item) for item in [*arrow.get_start(), *arrow.get_end()]]
189+
bool_table = [t == np.float64 for t in type_table]
190+
assert all(bool_table), "Types of start and end points are not np.float64"
191+
192+
assert np.allclose(arrow.get_start(), start, atol=0.01), (
193+
"start points of Arrow3D do not match"
194+
)
195+
assert np.allclose(arrow.get_end(), end, atol=0.01), (
196+
"end points of Arrow3D do not match"
197+
)

0 commit comments

Comments
 (0)