diff --git a/manim/mobject/svg/svg_mobject.py b/manim/mobject/svg/svg_mobject.py index bd494c0211..a616ce2438 100644 --- a/manim/mobject/svg/svg_mobject.py +++ b/manim/mobject/svg/svg_mobject.py @@ -21,7 +21,7 @@ from ..geometry.line import Line from ..geometry.polygram import Polygon, Rectangle, RoundedRectangle from ..opengl.opengl_compatibility import ConvertToOpenGL -from ..types.vectorized_mobject import VMobject +from ..types.vectorized_mobject import VGroup, VMobject __all__ = ["SVGMobject", "VMobjectFromSVGPath"] @@ -127,12 +127,13 @@ def __init__( self.stroke_color = stroke_color self.stroke_opacity = stroke_opacity # type: ignore[assignment] self.stroke_width = stroke_width # type: ignore[assignment] + self.id_to_vgroup_dict: dict[str, VGroup] = {} if self.stroke_width is None: self.stroke_width = 0 if svg_default is None: svg_default = { - "color": None, + "color": VMobject().color, "opacity": None, "fill_color": None, "fill_opacity": None, @@ -203,8 +204,11 @@ def generate_mobject(self) -> None: svg = se.SVG.parse(modified_file_path) modified_file_path.unlink() - mobjects = self.get_mobjects_from(svg) - self.add(*mobjects) + mobjects_dict = self.get_mobjects_from(svg) + for key, value in mobjects_dict.items(): + self.id_to_vgroup_dict[key] = value + self.add(value) + self.flip(RIGHT) # Flip y def get_file_path(self) -> Path: @@ -258,7 +262,7 @@ def generate_config_style_dict(self) -> dict[str, str]: result[svg_key] = str(svg_default_dict[style_key]) return result - def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]: + def get_mobjects_from(self, svg: se.SVG) -> dict[str, VGroup]: """Convert the elements of the SVG to a list of mobjects. Parameters @@ -266,37 +270,84 @@ def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]: svg The parsed SVG file. """ - result: list[VMobject] = [] - for shape in svg.elements(): - # can we combine the two continue cases into one? - if isinstance(shape, se.Group): # noqa: SIM114 - continue - elif isinstance(shape, se.Path): - mob: VMobject = self.path_to_mobject(shape) - elif isinstance(shape, se.SimpleLine): - mob = self.line_to_mobject(shape) - elif isinstance(shape, se.Rect): - mob = self.rect_to_mobject(shape) - elif isinstance(shape, (se.Circle, se.Ellipse)): - mob = self.ellipse_to_mobject(shape) - elif isinstance(shape, se.Polygon): - mob = self.polygon_to_mobject(shape) - elif isinstance(shape, se.Polyline): - mob = self.polyline_to_mobject(shape) - elif isinstance(shape, se.Text): - mob = self.text_to_mobject(shape) - elif isinstance(shape, se.Use) or type(shape) is se.SVGElement: - continue - else: - logger.warning(f"Unsupported element type: {type(shape)}") - continue - if mob is None or not mob.has_points(): - continue - self.apply_style_to_mobject(mob, shape) - if isinstance(shape, se.Transformable) and shape.apply: - self.handle_transform(mob, shape.transform) - result.append(mob) - return result + stack: list[tuple[se.SVGElement, int]] = [] + stack.append((svg, 1)) + group_id_number = 0 + vgroup_stack: list[str] = ["root"] + vgroup_names: list[str] = ["root"] + vgroups: dict[str, VGroup] = {"root": VGroup()} + while len(stack) > 0: + element, depth = stack.pop() + # Reduce stack heights + vgroup_stack = vgroup_stack[0:(depth)] + try: + group_name = str(element.values["id"]) + except Exception: + group_name = f"numbered_group_{group_id_number}" + group_id_number += 1 + if isinstance(element, se.Group): + vg = VGroup() + vgroups[group_name] = vg + vgroup_names.append(group_name) + vgroup_stack.append(group_name) + parent_name = vgroup_stack[-2] + assert parent_name != group_name + vgroups[parent_name].add(vgroups[group_name]) + + if isinstance(element, (se.Group, se.Use)): + for subelement in element[::-1]: + stack.append((subelement, depth + 1)) + # Add element to the parent vgroup + try: + parent_name = vgroup_stack[depth - 2] + if isinstance( + element, + ( + se.Path, + se.SimpleLine, + se.Rect, + se.Circle, + se.Ellipse, + se.Polygon, + se.Polyline, + se.Text, + ), + ): + mob = self.get_mob_from_shape_element(element) + vgroups[parent_name].add(mob) + except Exception as e: + print(e) + + return vgroups + + def get_mob_from_shape_element(self, shape: se.SVGElement) -> VMobject: + if isinstance(shape, se.Group): # noqa: SIM114 + raise Exception("Should never get here") + elif isinstance(shape, se.Path): + mob: VMobject = self.path_to_mobject(shape) + elif isinstance(shape, se.SimpleLine): + mob = self.line_to_mobject(shape) + elif isinstance(shape, se.Rect): + mob = self.rect_to_mobject(shape) + elif isinstance(shape, (se.Circle, se.Ellipse)): + mob = self.ellipse_to_mobject(shape) + elif isinstance(shape, se.Polygon): + mob = self.polygon_to_mobject(shape) + elif isinstance(shape, se.Polyline): + mob = self.polyline_to_mobject(shape) + elif isinstance(shape, se.Text): + mob = self.text_to_mobject(shape) + elif isinstance(shape, se.Use) or type(shape) is se.SVGElement: + raise Exception("Should never get here - se.Use or se.SVGElement") + else: + logger.warning(f"Unsupported element type: {type(shape)}") + raise Exception(f"Unsupported element type: {type(shape)}") + if mob is None or not mob.has_points(): + raise Exception("mob is empty or have no points") + self.apply_style_to_mobject(mob, shape) + if isinstance(shape, se.Transformable) and shape.apply: + self.handle_transform(mob, shape.transform) + return mob @staticmethod def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject: