Skip to content

Commit 6996084

Browse files
committed
improve test coverage and simplify flatten_shapely_geometry
1 parent e5a71f6 commit 6996084

File tree

6 files changed

+224
-25
lines changed

6 files changed

+224
-25
lines changed

tests/test_components/test_geometry.py

Lines changed: 192 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212
import pytest
1313
import shapely
1414
import trimesh
15+
from shapely.geometry import (
16+
GeometryCollection,
17+
LineString,
18+
MultiLineString,
19+
MultiPoint,
20+
MultiPolygon,
21+
Point,
22+
Polygon,
23+
)
1524

1625
import tidy3d as td
1726
from tidy3d.compat import _shapely_is_older_than
@@ -22,6 +31,7 @@
2231
SnapLocation,
2332
SnappingSpec,
2433
flatten_groups,
34+
flatten_shapely_geometries,
2535
snap_box_to_grid,
2636
traverse_geometries,
2737
)
@@ -1137,7 +1147,14 @@ def test_subdivide():
11371147
@pytest.mark.parametrize("snap_location", [SnapLocation.Boundary, SnapLocation.Center])
11381148
@pytest.mark.parametrize(
11391149
"snap_behavior",
1140-
[SnapBehavior.Off, SnapBehavior.Closest, SnapBehavior.Expand, SnapBehavior.Contract],
1150+
[
1151+
SnapBehavior.Off,
1152+
SnapBehavior.Closest,
1153+
SnapBehavior.Expand,
1154+
SnapBehavior.Contract,
1155+
SnapBehavior.StrictExpand,
1156+
SnapBehavior.StrictContract,
1157+
],
11411158
)
11421159
def test_snap_box_to_grid(snap_location, snap_behavior):
11431160
""" "Test that all combinations of SnappingSpec correctly modify a test box without error."""
@@ -1158,12 +1175,78 @@ def test_snap_box_to_grid(snap_location, snap_behavior):
11581175
new_box = snap_box_to_grid(grid, box, snap_spec)
11591176

11601177
if snap_behavior != SnapBehavior.Off and snap_location == SnapLocation.Boundary:
1161-
# Check that the box boundary slightly off from 0.1 was correctly snapped to 0.1
1162-
assert math.isclose(new_box.bounds[0][1], xyz[1])
1163-
# Check that the box boundary slightly off from 0.3 was correctly snapped to 0.3
1164-
assert math.isclose(new_box.bounds[1][1], xyz[3])
1165-
# Check that the box boundary outside the grid was snapped to the smallest grid coordinate
1166-
assert math.isclose(new_box.bounds[0][2], xyz[0])
1178+
# Strict behaviors have different snapping rules, so skip these specific assertions
1179+
if snap_behavior not in (SnapBehavior.StrictExpand, SnapBehavior.StrictContract):
1180+
# Check that the box boundary slightly off from 0.1 was correctly snapped to 0.1
1181+
assert math.isclose(new_box.bounds[0][1], xyz[1])
1182+
# Check that the box boundary slightly off from 0.3 was correctly snapped to 0.3
1183+
assert math.isclose(new_box.bounds[1][1], xyz[3])
1184+
# Check that the box boundary outside the grid was snapped to the smallest grid coordinate
1185+
assert math.isclose(new_box.bounds[0][2], xyz[0])
1186+
1187+
1188+
def test_snap_box_to_grid_strict_behaviors():
1189+
"""Test StrictExpand and StrictContract behaviors specifically."""
1190+
xyz = np.linspace(0, 1, 11) # Grid points at 0.0, 0.1, 0.2, ..., 1.0
1191+
coords = td.Coords(x=xyz, y=xyz, z=xyz)
1192+
grid = td.Grid(boundaries=coords)
1193+
1194+
# Test StrictExpand: should always move endpoints outwards, even if coincident
1195+
box_coincident = td.Box(
1196+
center=(0.1, 0.2, 0.3), size=(0, 0, 0)
1197+
) # Centered exactly on grid points
1198+
snap_spec_strict_expand = SnappingSpec(
1199+
location=[SnapLocation.Boundary] * 3, behavior=[SnapBehavior.StrictExpand] * 3
1200+
)
1201+
1202+
expanded_box = snap_box_to_grid(grid, box_coincident, snap_spec_strict_expand)
1203+
1204+
# StrictExpand should move bounds outwards even when already on grid
1205+
assert expanded_box.bounds[0][0] < 0.1 # Left bound moved left from 0.1
1206+
assert expanded_box.bounds[1][0] > 0.1 # Right bound moved right from 0.1
1207+
assert expanded_box.bounds[0][1] < 0.2 # Bottom bound moved down from 0.2
1208+
assert expanded_box.bounds[1][1] > 0.2 # Top bound moved up from 0.2
1209+
1210+
# Test StrictContract: should always move endpoints inwards, even if coincident
1211+
box_large = td.Box(center=(0.5, 0.5, 0.5), size=(0.4, 0.4, 0.4)) # Spans multiple grid cells
1212+
snap_spec_strict_contract = SnappingSpec(
1213+
location=[SnapLocation.Boundary] * 3, behavior=[SnapBehavior.StrictContract] * 3
1214+
)
1215+
1216+
contracted_box = snap_box_to_grid(grid, box_large, snap_spec_strict_contract)
1217+
1218+
# StrictContract should make the box smaller than the original
1219+
assert contracted_box.size[0] < box_large.size[0]
1220+
assert contracted_box.size[1] < box_large.size[1]
1221+
assert contracted_box.size[2] < box_large.size[2]
1222+
1223+
# Test edge case: box coincident with grid boundaries
1224+
box_on_grid = td.Box(
1225+
center=(0.15, 0.25, 0.35), size=(0.1, 0.1, 0.1)
1226+
) # Boundaries at 0.1,0.2 and 0.2,0.3
1227+
1228+
# Regular Expand shouldn't change a box already coincident with grid
1229+
snap_spec_regular_expand = SnappingSpec(
1230+
location=[SnapLocation.Boundary] * 3, behavior=[SnapBehavior.Expand] * 3
1231+
)
1232+
regular_expanded = snap_box_to_grid(grid, box_on_grid, snap_spec_regular_expand)
1233+
assert np.allclose(regular_expanded.bounds, box_on_grid.bounds) # Should be unchanged
1234+
1235+
# StrictExpand should still expand even when coincident
1236+
strict_expanded = snap_box_to_grid(grid, box_on_grid, snap_spec_strict_expand)
1237+
assert not np.allclose(strict_expanded.bounds, box_on_grid.bounds) # Should be changed
1238+
assert strict_expanded.size[0] > box_on_grid.size[0] # Should be larger
1239+
1240+
# Test with margin parameter for strict behaviors
1241+
snap_spec_strict_expand_margin = SnappingSpec(
1242+
location=[SnapLocation.Boundary] * 3,
1243+
behavior=[SnapBehavior.StrictExpand] * 3,
1244+
margin=(1, 1, 1), # Consider 1 additional grid point when expanding
1245+
)
1246+
1247+
margin_expanded = snap_box_to_grid(grid, box_coincident, snap_spec_strict_expand_margin)
1248+
# With margin=1, should expand even further than without margin
1249+
assert margin_expanded.size[0] >= expanded_box.size[0]
11671250

11681251

11691252
def test_triangulation_with_collinear_vertices():
@@ -1431,3 +1514,105 @@ def test_trim_dims_and_bounds_edge():
14311514
assert np.all(np.array(expected_trimmed_bounds) == np.array(trimmed_bounds)), (
14321515
"Unexpected trimmed bounds"
14331516
)
1517+
1518+
1519+
def test_flatten_shapely_geometries():
1520+
"""Test the flatten_shapely_geometries utility function comprehensively."""
1521+
# Test 1: Single polygon (should be wrapped in list and returned)
1522+
single_polygon = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
1523+
result = flatten_shapely_geometries(single_polygon)
1524+
assert len(result) == 1
1525+
assert result[0] == single_polygon
1526+
1527+
# Test 2: List of polygons (should return as-is)
1528+
poly1 = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
1529+
poly2 = Polygon([(2, 0), (3, 0), (3, 1), (2, 1)])
1530+
polygon_list = [poly1, poly2]
1531+
result = flatten_shapely_geometries(polygon_list)
1532+
assert len(result) == 2
1533+
assert result == polygon_list
1534+
1535+
# Test 3: MultiPolygon (should be flattened)
1536+
multi_polygon = MultiPolygon([poly1, poly2])
1537+
result = flatten_shapely_geometries(multi_polygon)
1538+
assert len(result) == 2
1539+
assert result[0] == poly1
1540+
assert result[1] == poly2
1541+
1542+
# Test 4: Empty geometries (should be filtered out)
1543+
empty_polygon = Polygon()
1544+
mixed_list = [poly1, empty_polygon, poly2]
1545+
result = flatten_shapely_geometries(mixed_list)
1546+
assert len(result) == 2
1547+
assert empty_polygon not in result
1548+
1549+
# Test 5: GeometryCollection (should be recursively flattened)
1550+
line = LineString([(0, 0), (1, 1)])
1551+
point = Point(0, 0)
1552+
collection = GeometryCollection([poly1, line, point, poly2])
1553+
result = flatten_shapely_geometries(collection)
1554+
assert len(result) == 2 # Only polygons kept by default
1555+
assert poly1 in result
1556+
assert poly2 in result
1557+
1558+
# Test 6: Custom keep_types parameter
1559+
result_with_lines = flatten_shapely_geometries(collection, keep_types=(Polygon, LineString))
1560+
assert len(result_with_lines) == 3 # 2 polygons + 1 line
1561+
assert poly1 in result_with_lines
1562+
assert poly2 in result_with_lines
1563+
assert line in result_with_lines
1564+
1565+
# Test 7: Nested collections and multi-geometries
1566+
line1 = LineString([(0, 0), (1, 1)])
1567+
line2 = LineString([(2, 2), (3, 3)])
1568+
multi_line = MultiLineString([line1, line2])
1569+
nested_collection = GeometryCollection(
1570+
[
1571+
collection, # Contains poly1, line, point, poly2
1572+
multi_line,
1573+
poly1,
1574+
]
1575+
)
1576+
result = flatten_shapely_geometries(nested_collection)
1577+
assert len(result) == 3 # poly1 (from collection), poly2 (from collection), poly1 (direct)
1578+
1579+
# Test 8: MultiPoint (should be handled)
1580+
point1 = Point(0, 0)
1581+
point2 = Point(1, 1)
1582+
multi_point = MultiPoint([point1, point2])
1583+
result = flatten_shapely_geometries(multi_point, keep_types=(Point,))
1584+
assert len(result) == 2
1585+
assert point1 in result
1586+
assert point2 in result
1587+
1588+
# Test 9: MultiLineString (should be handled)
1589+
result = flatten_shapely_geometries(multi_line, keep_types=(LineString,))
1590+
assert len(result) == 2
1591+
assert line1 in result
1592+
assert line2 in result
1593+
1594+
# Test 10: Mixed empty and non-empty geometries
1595+
empty_multi = MultiPolygon([])
1596+
mixed_with_empty = [poly1, empty_multi, empty_polygon, poly2]
1597+
result = flatten_shapely_geometries(mixed_with_empty)
1598+
assert len(result) == 2
1599+
assert poly1 in result
1600+
assert poly2 in result
1601+
1602+
# Test 11: Deeply nested structure
1603+
inner_collection = GeometryCollection([poly1, line])
1604+
outer_multi = MultiPolygon([poly2])
1605+
deep_collection = GeometryCollection([inner_collection, outer_multi])
1606+
result = flatten_shapely_geometries(deep_collection)
1607+
assert len(result) == 2
1608+
assert poly1 in result
1609+
assert poly2 in result
1610+
1611+
# Test 12: All geometry types filtered out
1612+
points_and_lines = GeometryCollection([Point(0, 0), LineString([(0, 0), (1, 1)])])
1613+
result = flatten_shapely_geometries(points_and_lines) # Default keeps only Polygons
1614+
assert len(result) == 0
1615+
1616+
# Test 13: Edge case - single empty geometry
1617+
result = flatten_shapely_geometries(empty_polygon)
1618+
assert len(result) == 0

tests/test_components/test_microwave.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ def test_antenna_parameters():
326326
antenna_params.partial_realized_gain("invalid")
327327

328328

329-
def test_path_integral_plotting():
330-
"""Test that all types of path integrals correctly plot themselves."""
329+
def test_path_spec_plotting():
330+
"""Test that all types of path specification correctly plot themselves."""
331331

332332
mean_radius = (COAX_R2 + COAX_R1) * 0.5
333333
size = [COAX_R2 - COAX_R1, 0, 0]
@@ -372,6 +372,28 @@ def test_path_integral_plotting():
372372
current_integral.plot(y=2, ax=ax)
373373
plt.close()
374374

375+
current_integral = td.CompositeCurrentIntegralSpec(
376+
path_specs=(
377+
td.CurrentIntegralAxisAlignedSpec(center=(-1, -1, 0), size=(1, 1, 0), sign="-"),
378+
td.CurrentIntegralAxisAlignedSpec(center=(1, 1, 0), size=(1, 1, 0), sign="-"),
379+
),
380+
sum_spec="sum",
381+
)
382+
ax = current_integral.plot(z=0)
383+
plt.close()
384+
385+
386+
@pytest.mark.parametrize("clockwise", [False, True])
387+
def test_custom_current_specification_sign(clockwise):
388+
"""Make sure the sign is correctly calculated for custom current specs."""
389+
current_integral = td.CustomCurrentIntegral2DSpec.from_circular_path(
390+
center=(0, 0, 0), radius=0.4, num_points=31, normal_axis=2, clockwise=clockwise
391+
)
392+
if clockwise:
393+
assert current_integral.sign == "-"
394+
else:
395+
assert current_integral.sign == "+"
396+
375397

376398
def test_composite_current_integral_validation():
377399
"""Ensures that the CompositeCurrentIntegralSpec is validated correctly."""
@@ -704,6 +726,6 @@ def test_mode_solver_with_microwave_mode_spec():
704726
# mms_data.field_components["Ez"].isel(mode_index=0, f=0).real.plot(ax=ax)
705727
# ax.set_aspect("equal")
706728
# plt.show()
707-
# print(mms_data.to_dataframe())
729+
mms_data.to_dataframe()
708730

709731
assert np.all(np.isclose(mms_data.microwave_data.Z0.real, 28.6, 0.2))

tidy3d/components/geometry/utils.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
import pydantic.v1 as pydantic
1313
import shapely
1414
from shapely.geometry import (
15-
GeometryCollection,
16-
MultiLineString,
17-
MultiPoint,
18-
MultiPolygon,
1915
Polygon,
2016
)
21-
from shapely.geometry.base import BaseGeometry
17+
from shapely.geometry.base import (
18+
BaseMultipartGeometry,
19+
)
2220

2321
from tidy3d.components.autograd.utils import get_static
2422
from tidy3d.components.base import Tidy3dBaseModel
@@ -85,9 +83,7 @@ def flatten_shapely_geometries(
8583
continue
8684
if isinstance(geom, keep_types):
8785
flat.append(geom)
88-
elif isinstance(geom, (MultiPolygon, MultiLineString, MultiPoint, GeometryCollection)):
89-
flat.extend(flatten_shapely_geometries(geom.geoms, keep_types))
90-
elif isinstance(geom, BaseGeometry) and hasattr(geom, "geoms"):
86+
elif isinstance(geom, BaseMultipartGeometry):
9187
flat.extend(flatten_shapely_geometries(geom.geoms, keep_types))
9288
return flat
9389

@@ -114,7 +110,7 @@ def merging_geometries_on_plane(
114110
115111
Returns
116112
-------
117-
List[Tuple[Any, shapely]]
113+
List[Tuple[Any, Shapely]]
118114
List of shapes and their property value on the plane after merging.
119115
"""
120116

tidy3d/components/microwave/path_integrals/base_spec.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tidy3d.components.types.base import Axis, Direction
1717
from tidy3d.components.validators import assert_line
1818
from tidy3d.constants import MICROMETER, fp_eps
19-
from tidy3d.exceptions import SetupError, Tidy3dError
19+
from tidy3d.exceptions import SetupError
2020
from tidy3d.log import log
2121

2222

@@ -88,7 +88,6 @@ def main_axis(self) -> Axis:
8888
for index, value in enumerate(self.size):
8989
if value != 0:
9090
return index
91-
raise Tidy3dError("Failed to identify axis.")
9291

9392
def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]:
9493
"""Returns the two vertices of this path in the plane defined by ``axis``."""

tidy3d/components/microwave/path_integrals/current_spec.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tidy3d.components.validators import assert_plane
2121
from tidy3d.components.viz import add_ax_if_none
2222
from tidy3d.constants import fp_eps
23-
from tidy3d.exceptions import SetupError, Tidy3dError
23+
from tidy3d.exceptions import SetupError
2424

2525

2626
class CurrentIntegralAxisAlignedSpec(AbstractAxesRH, Box):
@@ -52,7 +52,6 @@ def main_axis(self) -> Axis:
5252
for index, value in enumerate(self.size):
5353
if value == 0:
5454
return index
55-
raise Tidy3dError("Failed to identify axis.")
5655

5756
def _to_path_integral_specs(
5857
self, h_horizontal=None, h_vertical=None

tidy3d/components/mode/mode_solver.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,8 +1383,6 @@ def _make_path_integrals(
13831383
def _add_microwave_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
13841384
"""Calculate and add microwave data to ``mode_solver_data`` which uses the path specifications.
13851385
Note: this modifies ``mode_solver_data`` in-place."""
1386-
if self.mode_spec.microwave_mode_spec is None:
1387-
return mode_solver_data
13881386
voltage_integrals, current_integrals = self._make_path_integrals()
13891387
# Need to operate on the full symmetry expanded fields
13901388
mode_solver_data_expanded = mode_solver_data.symmetry_expanded_copy

0 commit comments

Comments
 (0)