Skip to content

Commit 183001f

Browse files
chebyshev interpolation option
1 parent 4349fc6 commit 183001f

File tree

4 files changed

+367
-24
lines changed

4 files changed

+367
-24
lines changed

tests/test_components/test_mode_interp.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from ..test_data.test_data_arrays import FS, MODE_SPEC, SIZE_2D
1313
from ..utils import AssertLogLevel
1414
from tidy3d.plugins.mode import ModeSolver
15+
from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC
16+
1517

1618
# Shared test constants
1719
FREQS_DENSE = np.linspace(1e14, 2e14, 20)
@@ -48,6 +50,52 @@ def test_interp_spec_cubic_needs_4_points():
4850
td.ModeInterpSpec(num_points=3, method="cubic")
4951

5052

53+
def test_interp_spec_valid_cheb():
54+
"""Test creating valid ModeInterpSpec with Chebyshev interpolation."""
55+
spec = td.ModeInterpSpec(num_points=10, method="cheb")
56+
assert spec.num_points == 10
57+
assert spec.method == "cheb"
58+
59+
60+
def test_interp_spec_cheb_needs_3_points():
61+
"""Test that Chebyshev interpolation requires at least 3 points."""
62+
with pytest.raises(pydantic.ValidationError, match="Chebyshev interpolation requires at least 3"):
63+
td.ModeInterpSpec(num_points=2, method="cheb")
64+
65+
66+
def test_interp_spec_sampling_points_linear():
67+
"""Test sampling_points for linear interpolation."""
68+
spec = td.ModeInterpSpec(num_points=5, method="linear")
69+
freqs = np.linspace(1e14, 2e14, 100)
70+
sampling = spec.sampling_points(freqs)
71+
72+
assert len(sampling) == 5
73+
assert np.isclose(sampling[0], 1e14)
74+
assert np.isclose(sampling[-1], 2e14)
75+
# Check uniform spacing
76+
diffs = np.diff(sampling)
77+
assert np.allclose(diffs, diffs[0])
78+
79+
80+
def test_interp_spec_sampling_points_cheb():
81+
"""Test sampling_points for Chebyshev interpolation."""
82+
spec = td.ModeInterpSpec(num_points=5, method="cheb")
83+
freqs = np.linspace(1e14, 2e14, 100)
84+
sampling = spec.sampling_points(freqs)
85+
86+
assert len(sampling) == 5
87+
# Chebyshev nodes should include endpoints
88+
assert np.isclose(sampling.min(), 1e14)
89+
assert np.isclose(sampling.max(), 2e14)
90+
91+
# Verify they are Chebyshev nodes
92+
f_min, f_max = 1e14, 2e14
93+
k = np.arange(5)
94+
expected_normalized = np.cos(k * np.pi / 4)
95+
expected = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * expected_normalized
96+
assert np.allclose(np.sort(sampling), np.sort(expected))
97+
98+
5199
def test_interp_spec_min_2_points():
52100
"""Test that at least 2 points are required."""
53101
with pytest.raises(pydantic.ValidationError):
@@ -380,6 +428,107 @@ def test_mode_solver_data_interp_cubic():
380428
assert data_interp.n_complex.shape[0] == 20
381429

382430

431+
def test_mode_solver_data_interp_cheb():
432+
"""Test Chebyshev interpolation on ModeSolverData."""
433+
# Create data with frequencies at Chebyshev nodes
434+
interp_spec = td.ModeInterpSpec(num_points=5, method="cheb")
435+
freqs_all = np.linspace(1e14, 2e14, 50)
436+
freqs_cheb = interp_spec.sampling_points(freqs_all)
437+
438+
mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"))
439+
monitor = td.ModeSolverMonitor(
440+
center=(0, 0, 0),
441+
size=SIZE_2D,
442+
freqs=freqs_cheb,
443+
mode_spec=mode_spec,
444+
name="test_cheb",
445+
)
446+
447+
from ..test_data.test_data_arrays import make_scalar_mode_field_data_array
448+
from ..test_data.test_monitor_data import N_COMPLEX
449+
450+
mode_data = td.ModeSolverData(
451+
monitor=monitor,
452+
Ex=make_scalar_mode_field_data_array("Ex"),
453+
Ey=make_scalar_mode_field_data_array("Ey"),
454+
Ez=make_scalar_mode_field_data_array("Ez"),
455+
Hx=make_scalar_mode_field_data_array("Hx"),
456+
Hy=make_scalar_mode_field_data_array("Hy"),
457+
Hz=make_scalar_mode_field_data_array("Hz"),
458+
n_complex=N_COMPLEX.copy(),
459+
symmetry=(0, 0, 0),
460+
symmetry_center=(0, 0, 0),
461+
grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])),
462+
)
463+
464+
# Interpolate to 50 frequencies
465+
data_interp = mode_data.interp(freqs=freqs_all, method="cheb")
466+
467+
# Check frequency dimension
468+
assert len(data_interp.monitor.freqs) == 50
469+
assert data_interp.n_complex.shape[0] == 50
470+
471+
472+
def test_mode_solver_data_interp_cheb_needs_3_source():
473+
"""Test that Chebyshev interpolation fails with too few source frequencies."""
474+
# Create data with only 2 frequencies
475+
freqs = np.linspace(1e14, 2e14, 2)
476+
mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"))
477+
monitor = td.ModeSolverMonitor(
478+
center=(0, 0, 0),
479+
size=SIZE_2D,
480+
freqs=freqs,
481+
mode_spec=mode_spec,
482+
name="test",
483+
)
484+
485+
from ..test_data.test_data_arrays import make_scalar_mode_field_data_array
486+
from ..test_data.test_monitor_data import N_COMPLEX
487+
488+
mode_data = td.ModeSolverData(
489+
monitor=monitor,
490+
Ex=make_scalar_mode_field_data_array("Ex"),
491+
n_complex=N_COMPLEX.copy(),
492+
symmetry=(0, 0, 0),
493+
symmetry_center=(0, 0, 0),
494+
grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])),
495+
)
496+
497+
freqs_dense = np.linspace(1e14, 2e14, 10)
498+
with pytest.raises(td.exceptions.DataError, match="at least 3 source"):
499+
mode_data.interp(freqs=freqs_dense, method="cheb")
500+
501+
502+
def test_mode_solver_data_interp_cheb_validates_nodes():
503+
"""Test that Chebyshev interpolation validates source frequencies are Chebyshev nodes."""
504+
# Create data with uniform (not Chebyshev) nodes
505+
freqs_uniform = np.linspace(1e14, 2e14, 5)
506+
mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"))
507+
monitor = td.ModeSolverMonitor(
508+
center=(0, 0, 0),
509+
size=SIZE_2D,
510+
freqs=freqs_uniform,
511+
mode_spec=mode_spec,
512+
name="test",
513+
)
514+
515+
from ..test_data.test_data_arrays import make_scalar_mode_field_data_array
516+
from ..test_data.test_monitor_data import N_COMPLEX
517+
518+
mode_data = td.ModeSolverData(
519+
monitor=monitor,
520+
Ex=make_scalar_mode_field_data_array("Ex"),
521+
n_complex=N_COMPLEX.copy(),
522+
symmetry=(0, 0, 0),
523+
symmetry_center=(0, 0, 0),
524+
grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])),
525+
)
526+
527+
freqs_dense = np.linspace(1e14, 2e14, 10)
528+
with pytest.raises(td.exceptions.DataError, match="must be at Chebyshev nodes"):
529+
mode_data.interp(freqs=freqs_dense, method="cheb")
530+
531+
383532
def test_mode_solver_data_interp_preserves_modes():
384533
"""Test that interpolation preserves mode count."""
385534
mode_data = get_mode_solver_data()
@@ -573,6 +722,32 @@ def test_mode_solver_interp_cubic():
573722
assert data.n_complex.shape[0] == 10
574723

575724

725+
def test_mode_solver_interp_cheb():
726+
"""Test that ModeSolver works with Chebyshev interpolation."""
727+
sim = get_simple_sim()
728+
729+
freqs = np.linspace(1e14, 2e14, 20)
730+
mode_spec = td.ModeSpec(
731+
num_modes=2,
732+
sort_spec=td.ModeSortSpec(track_freq="central")
733+
)
734+
735+
# Chebyshev interpolation requires at least 3 points
736+
interp_spec = td.ModeInterpSpec(num_points=5, method="cheb")
737+
738+
solver = ModeSolver(
739+
simulation=sim,
740+
plane=td.Box(center=(0, 0, 0), size=SIZE_2D),
741+
freqs=freqs,
742+
mode_spec=mode_spec,
743+
interp_spec=interp_spec,
744+
)
745+
746+
data = solver.data_raw
747+
assert len(data.monitor.freqs) == 20
748+
assert data.n_complex.shape[0] == 20
749+
750+
576751
def test_mode_solver_without_interp_returns_full_data():
577752
"""Test that solver without interp_spec computes at all frequencies."""
578753
sim = get_simple_sim()
@@ -755,6 +930,74 @@ def test_mode_monitor_interp_spec_none():
755930
assert monitor.interp_spec is None
756931

757932

933+
# ============================================================================
934+
# WavePort interp_spec Tests
935+
# ============================================================================
936+
937+
def make_wave_port():
938+
"""Make a WavePort."""
939+
from tidy3d.plugins.smatrix.ports.wave import WavePort
940+
from tidy3d.components.microwave.path_integrals.integrals.current import AxisAlignedCurrentIntegral
941+
return WavePort(
942+
center=(0, 0, 0),
943+
size=(1, 1, 0),
944+
direction="+",
945+
name="port1",
946+
current_integral=AxisAlignedCurrentIntegral(
947+
center=(0, 0, 0),
948+
size=(1, 1, 0),
949+
sign="+",
950+
extrapolate_to_endpoints=True,
951+
snap_contour_to_grid=True,
952+
)
953+
)
954+
955+
956+
def test_wave_port_to_monitors_propagates_default_interp_spec():
957+
"""Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor."""
958+
959+
port = make_wave_port()
960+
961+
freqs = np.linspace(1e14, 2e14, 20)
962+
monitors = port.to_monitors(freqs=freqs)
963+
964+
assert len(monitors) == 1
965+
monitor = monitors[0]
966+
assert isinstance(monitor, td.ModeMonitor)
967+
assert monitor.interp_spec is not None
968+
assert monitor.interp_spec.num_points == DEFAULT_WAVE_PORT_INTERP_SPEC.num_points
969+
assert monitor.interp_spec.method == DEFAULT_WAVE_PORT_INTERP_SPEC.method
970+
971+
972+
def test_wave_port_to_monitors_propagates_custom_interp_spec():
973+
"""Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor."""
974+
custom_interp = td.ModeInterpSpec(num_points=8, method="cheb")
975+
port = make_wave_port().updated_copy(interp_spec=custom_interp)
976+
977+
freqs = np.linspace(1e14, 2e14, 50)
978+
monitors = port.to_monitors(freqs=freqs)
979+
980+
assert len(monitors) == 1
981+
monitor = monitors[0]
982+
assert isinstance(monitor, td.ModeMonitor)
983+
assert monitor.interp_spec is not None
984+
assert monitor.interp_spec.num_points == 8
985+
assert monitor.interp_spec.method == "cheb"
986+
987+
988+
def test_wave_port_to_monitors_propagates_none_interp_spec():
989+
"""Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor."""
990+
port = make_wave_port().updated_copy(interp_spec=None)
991+
992+
freqs = np.linspace(1e14, 2e14, 20)
993+
monitors = port.to_monitors(freqs=freqs)
994+
995+
assert len(monitors) == 1
996+
monitor = monitors[0]
997+
assert isinstance(monitor, td.ModeMonitor)
998+
assert monitor.interp_spec is None
999+
1000+
7581001
# ============================================================================
7591002
# Placeholder tests for future phases
7601003
# ============================================================================

0 commit comments

Comments
 (0)