diff --git a/schemas/EMESimulation.json b/schemas/EMESimulation.json index 176d2a501e..4053a94c18 100644 --- a/schemas/EMESimulation.json +++ b/schemas/EMESimulation.json @@ -8140,6 +8140,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeSolverMonitor": { "additionalProperties": false, "properties": { @@ -8261,6 +8294,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/ModeSimulation.json b/schemas/ModeSimulation.json index e43308481a..c0fa04be8b 100644 --- a/schemas/ModeSimulation.json +++ b/schemas/ModeSimulation.json @@ -7353,6 +7353,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -7444,6 +7477,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -7705,6 +7745,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/Simulation.json b/schemas/Simulation.json index c5ee4e818f..e406b1edef 100644 --- a/schemas/Simulation.json +++ b/schemas/Simulation.json @@ -10737,6 +10737,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -10968,6 +10975,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11319,6 +11333,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -11410,6 +11457,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11671,6 +11725,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/TerminalComponentModeler.json b/schemas/TerminalComponentModeler.json index 9adbed89cc..82388e488a 100644 --- a/schemas/TerminalComponentModeler.json +++ b/schemas/TerminalComponentModeler.json @@ -11393,6 +11393,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11624,6 +11631,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11975,6 +11989,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -12066,6 +12113,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -12327,6 +12381,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -18047,6 +18108,19 @@ "type": "PECFrame" } }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ], + "default": { + "attrs": {}, + "method": "cheb", + "num_points": 21, + "type": "ModeInterpSpec" + } + }, "mode_index": { "default": 0, "minimum": 0, diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py new file mode 100644 index 0000000000..b2257bdc50 --- /dev/null +++ b/tests/test_components/test_mode_interp.py @@ -0,0 +1,1033 @@ +"""Tests for mode frequency interpolation.""" + +from __future__ import annotations + +import numpy as np +import pydantic.v1 as pydantic +import pytest + +import tidy3d as td +from tidy3d.plugins.mode import ModeSolver +from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC + +from ..test_data.test_data_arrays import MODE_SPEC, SIZE_2D +from ..utils import AssertLogLevel + +# Shared test constants +FREQS_DENSE = np.linspace(1e14, 2e14, 20) + + +# ============================================================================ +# ModeInterpSpec Tests +# ============================================================================ + + +def test_interp_spec_valid_linear(): + """Test creating valid ModeInterpSpec with linear interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="linear") + assert spec.num_points == 5 + assert spec.method == "linear" + + +def test_interp_spec_valid_cubic(): + """Test creating valid ModeInterpSpec with cubic interpolation.""" + spec = td.ModeInterpSpec(num_points=10, method="cubic") + assert spec.num_points == 10 + assert spec.method == "cubic" + + +def test_interp_spec_default_method(): + """Test that default method is 'linear'.""" + spec = td.ModeInterpSpec(num_points=5) + assert spec.method == "linear" + + +def test_interp_spec_cubic_needs_4_points(): + """Test that cubic interpolation requires at least 4 points.""" + with pytest.raises(pydantic.ValidationError, match="Cubic interpolation requires at least 4"): + td.ModeInterpSpec(num_points=3, method="cubic") + + +def test_interp_spec_valid_cheb(): + """Test creating valid ModeInterpSpec with Chebyshev interpolation.""" + spec = td.ModeInterpSpec(num_points=10, method="cheb") + assert spec.num_points == 10 + assert spec.method == "cheb" + + +def test_interp_spec_cheb_needs_3_points(): + """Test that Chebyshev interpolation requires at least 3 points.""" + with pytest.raises( + pydantic.ValidationError, match="Chebyshev interpolation requires at least 3" + ): + td.ModeInterpSpec(num_points=2, method="cheb") + + +def test_interp_spec_sampling_points_linear(): + """Test sampling_points for linear interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="linear") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + assert np.isclose(sampling[0], 1e14) + assert np.isclose(sampling[-1], 2e14) + # Check uniform spacing + diffs = np.diff(sampling) + assert np.allclose(diffs, diffs[0]) + + +def test_interp_spec_sampling_points_cheb(): + """Test sampling_points for Chebyshev interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="cheb") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + # Chebyshev nodes should include endpoints + assert np.isclose(sampling.min(), 1e14) + assert np.isclose(sampling.max(), 2e14) + + # Verify they are Chebyshev nodes + f_min, f_max = 1e14, 2e14 + k = np.arange(5) + expected_normalized = np.cos(k * np.pi / 4) + expected = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * expected_normalized + assert np.allclose(np.sort(sampling), np.sort(expected)) + + +def test_interp_spec_min_2_points(): + """Test that at least 2 points are required.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=1, method="linear") + + +def test_interp_spec_positive_points(): + """Test that num_points must be positive.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=0, method="linear") + + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=-5, method="linear") + + +def test_interp_spec_invalid_method(): + """Test that invalid interpolation method is rejected.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=5, method="quadratic") + + +# ============================================================================ +# Monitor with interp_spec Tests +# ============================================================================ + + +def test_mode_monitor_requires_tracking(): + """Test that ModeMonitor with interp_spec requires track_freq.""" + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with pytest.raises(pydantic.ValidationError, match="tracking"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_valid_with_tracking(): + """Test that ModeMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + assert monitor.interp_spec.method == "linear" + + +def test_mode_solver_monitor_requires_tracking(): + """Test that ModeSolverMonitor with interp_spec requires track_freq.""" + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with pytest.raises(pydantic.ValidationError, match="tracking"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_valid_with_tracking(): + """Test that ModeSolverMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + + +def test_interp_num_points_less_than_freqs(): + """Test that num_points must be less than total freqs.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=25, method="linear") + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_interp_num_points_equal_to_freqs(): + """Test that num_points equal to freqs is rejected.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=20, method="linear") + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_interp_spec_none_allowed(): + """Test that interp_spec=None is allowed (no interpolation).""" + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + interp_spec=None, + name="test", + ) + assert monitor.interp_spec is None + + +def test_interp_deprecated_track_freq_still_works(): + """Test that deprecated track_freq on ModeSpec still enables interpolation.""" + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + # Using deprecated track_freq instead of sort_spec.track_freq + with AssertLogLevel("WARNING", contains_str="deprecated"): + mode_spec = td.ModeSpec(num_modes=2, track_freq="central") + + # Should still work since _track_freq property resolves it + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + + +# ============================================================================ +# ModeSolver with interp_spec Tests +# ============================================================================ + + +def get_simple_sim(): + """Create a simple simulation for ModeSolver tests.""" + return td.Simulation( + size=(10, 10, 10), + grid_spec=td.GridSpec(wavelength=1.0), + structures=[ + td.Structure( + geometry=td.Box(size=(1, 1, 10)), + medium=td.Medium(permittivity=4.0), + ) + ], + run_time=1e-12, + ) + + +def test_mode_solver_requires_tracking(): + """Test that ModeSolver with interp_spec requires track_freq.""" + sim = get_simple_sim() + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + with pytest.raises(pydantic.ValidationError, match="tracking"): + ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + ) + + +def test_mode_solver_valid_with_tracking(): + """Test that ModeSolver validates with tracking enabled.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + assert solver.interp_spec.num_points == 5 + assert solver.interp_spec.method == "linear" + + +def test_mode_solver_warns_num_points(): + """Test that ModeSolver warns when num_points >= num_freqs.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=25, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + +def test_mode_solver_interp_spec_none(): + """Test that ModeSolver accepts interp_spec=None.""" + sim = get_simple_sim() + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + interp_spec=None, + ) + assert solver.interp_spec is None + + +# ============================================================================ +# ModeSolverData.interp() Tests +# ============================================================================ + + +def get_mode_solver_data(): + """Create a simple ModeSolverData object for testing.""" + from tidy3d.components.data.data_array import GroupIndexDataArray, ModeDispersionDataArray + + from ..test_data.test_data_arrays import ( + make_scalar_mode_field_data_array, + ) + from ..test_data.test_monitor_data import N_COMPLEX + + freqs = np.linspace(1e14, 2e14, 5) + num_modes = len(N_COMPLEX.mode_index) + mode_indices = np.arange(num_modes) + mode_spec = td.ModeSpec(num_modes=num_modes, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test_monitor", + ) + + # Create n_group_raw and dispersion_raw with same shape as n_complex + n_group_values = 1.5 + 0.1 * np.random.random((len(freqs), num_modes)) + n_group_raw = GroupIndexDataArray( + n_group_values, coords={"f": freqs, "mode_index": mode_indices} + ) + + dispersion_values = 10.0 + 2.0 * np.random.random((len(freqs), num_modes)) + dispersion_raw = ModeDispersionDataArray( + dispersion_values, coords={"f": freqs, "mode_index": mode_indices} + ) + + # Create mode data with the right frequencies + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + Ey=make_scalar_mode_field_data_array("Ey"), + Ez=make_scalar_mode_field_data_array("Ez"), + Hx=make_scalar_mode_field_data_array("Hx"), + Hy=make_scalar_mode_field_data_array("Hy"), + Hz=make_scalar_mode_field_data_array("Hz"), + n_complex=N_COMPLEX.copy(), + n_group_raw=n_group_raw, + dispersion_raw=dispersion_raw, + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + return mode_data + + +def test_mode_solver_data_interp_linear(): + """Test linear interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert data_interp.n_complex.shape[0] == 20 + + # Check mode dimension is preserved + assert len(data_interp.n_complex.mode_index) == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 20 + + +def test_mode_solver_data_interp_cubic(): + """Test cubic interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Need at least 4 frequencies for cubic + assert len(mode_data.monitor.freqs) >= 4 + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp(freqs=freqs_dense, method="cubic") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert len(data_interp.n_complex.mode_index) == len(mode_data.n_complex.mode_index) + + +def test_mode_solver_data_interp_cheb(): + """Test Chebyshev interpolation on ModeSolverData.""" + # Create data with frequencies at Chebyshev nodes + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") + freqs_all = np.linspace(1e14, 2e14, 50) + freqs_cheb = interp_spec.sampling_points(freqs_all) + + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs_cheb, + mode_spec=mode_spec, + name="test_cheb", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + Ey=make_scalar_mode_field_data_array("Ey"), + Ez=make_scalar_mode_field_data_array("Ez"), + Hx=make_scalar_mode_field_data_array("Hx"), + Hy=make_scalar_mode_field_data_array("Hy"), + Hz=make_scalar_mode_field_data_array("Hz"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + # Interpolate to 50 frequencies + data_interp = mode_data.interp(freqs=freqs_all, method="cheb") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 50 + assert len(data_interp.n_complex.mode_index) == len(N_COMPLEX.mode_index) + + +def test_mode_solver_data_interp_cheb_needs_3_source(): + """Test that Chebyshev interpolation fails with too few source frequencies.""" + # Create data with only 2 frequencies + freqs = np.linspace(1e14, 2e14, 2) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 3 source"): + mode_data.interp(freqs=freqs_dense, method="cheb") + + +def test_mode_solver_data_interp_cheb_validates_nodes(): + """Test that Chebyshev interpolation validates source frequencies are Chebyshev nodes.""" + # Create data with uniform (not Chebyshev) nodes + freqs_uniform = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs_uniform, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="must be at Chebyshev nodes"): + mode_data.interp(freqs=freqs_dense, method="cheb") + + +def test_mode_solver_data_interp_preserves_modes(): + """Test that interpolation preserves mode count.""" + mode_data = get_mode_solver_data() + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to different number of frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Mode count should be unchanged + assert len(data_interp.n_complex.mode_index) == original_num_modes + + +def test_mode_solver_data_interp_includes_n_group_and_dispersion(): + """Test that interpolation includes n_group_raw and dispersion_raw.""" + mode_data = get_mode_solver_data() + + # Verify source data has n_group_raw and dispersion_raw + assert mode_data.n_group_raw is not None + assert mode_data.dispersion_raw is not None + assert mode_data.n_group_raw.shape == (5, len(mode_data.n_complex.mode_index)) + assert mode_data.dispersion_raw.shape == (5, len(mode_data.n_complex.mode_index)) + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Verify interpolated data has n_group_raw and dispersion_raw + assert data_interp.n_group_raw is not None + assert data_interp.dispersion_raw is not None + + # Check shapes are correct + assert data_interp.n_group_raw.shape == (20, len(data_interp.n_complex.mode_index)) + assert data_interp.dispersion_raw.shape == (20, len(data_interp.n_complex.mode_index)) + + # Check frequency coordinates are correct + assert len(data_interp.n_group_raw.coords["f"]) == 20 + assert len(data_interp.dispersion_raw.coords["f"]) == 20 + + +def test_mode_solver_data_interp_single_frequency(): + """Test that interpolation works with a single target frequency.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to a single frequency in the middle of the range + single_freq = np.array([1.5e14]) + data_interp = mode_data.interp(freqs=single_freq, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 1 + assert data_interp.n_complex.shape[0] == 1 + + # Check mode dimension is preserved + assert len(data_interp.n_complex.mode_index) == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 1 + assert float(field_data.coords["f"]) == 1.5e14 + + # Check n_group_raw and dispersion_raw if present + if data_interp.n_group_raw is not None: + print(data_interp.n_group_raw.shape) + print((1, original_num_modes)) + assert data_interp.n_group_raw.shape == (1, original_num_modes) + if data_interp.dispersion_raw is not None: + assert data_interp.dispersion_raw.shape == (1, original_num_modes) + + +def test_mode_solver_data_interp_cubic_needs_4_source(): + """Test that cubic interpolation fails with too few source frequencies.""" + # Create data with only 3 frequencies + freqs = np.linspace(1e14, 2e14, 3) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 4 source"): + mode_data.interp(freqs=freqs_dense, method="cubic") + + +def test_mode_solver_data_interp_invalid_method(): + """Test that invalid interpolation method raises error.""" + mode_data = get_mode_solver_data() + freqs_dense = np.linspace(1e14, 2e14, 10) + + with pytest.raises(td.exceptions.DataError, match="Invalid interpolation method"): + mode_data.interp(freqs=freqs_dense, method="quadratic") + + +def test_mode_solver_data_interp_extrapolation_warning(): + """Test that extrapolation triggers a warning.""" + mode_data = get_mode_solver_data() + + # Interpolate beyond original range + freqs_extrap = np.linspace(0.5e14, 2.5e14, 10) + + with AssertLogLevel("WARNING", contains_str="outside original range"): + mode_data.interp(freqs=freqs_extrap, method="linear") + + +# ============================================================================ +# ModeSolver Integration Tests (Phase 5) +# ============================================================================ + + +def test_mode_solver_with_interp(): + """Test that ModeSolver uses interpolation when interp_spec is provided.""" + sim = get_simple_sim() + + # Create solver with 10 frequencies + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # Create solver with interpolation: compute at 3 frequencies, interpolate to 10 + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + solver_with_interp = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + # The solver should have the original 10 frequencies + assert len(solver_with_interp.freqs) == 10 + + # The returned data should have 10 frequencies + data = solver_with_interp.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +def test_mode_solver_creates_reduced_freqs(): + """Test that solver creates correct reduced frequency set internally.""" + sim = get_simple_sim() + + # Create solver with 20 frequencies + freqs = np.linspace(1e14, 2e14, 20) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # Compute at 5 frequencies, interpolate to 20 + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + # The returned data should have all 20 frequencies + data = solver.data_raw + assert len(data.monitor.freqs) == 20 + + # The effective indices should be properly interpolated + assert data.n_complex.shape[0] == 20 + + +def test_mode_solver_interp_preserves_num_modes(): + """Test that interpolation preserves the number of modes.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 15) + mode_spec = td.ModeSpec(num_modes=3, sort_spec=td.ModeSortSpec(track_freq="central")) + + interp_spec = td.ModeInterpSpec(num_points=4, method="linear") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + + # Should have 3 modes at each of 15 frequencies + assert data.n_complex.shape == (15, 3) + + +def test_mode_solver_interp_cubic(): + """Test that ModeSolver works with cubic interpolation.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # Cubic interpolation requires at least 4 points + interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +def test_mode_solver_interp_cheb(): + """Test that ModeSolver works with Chebyshev interpolation.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 20) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # Chebyshev interpolation requires at least 3 points + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 20 + assert data.n_complex.shape[0] == 20 + + +def test_mode_solver_without_interp_returns_full_data(): + """Test that solver without interp_spec computes at all frequencies.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=None, # No interpolation + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +# ============================================================================ +# Monitor Integration Tests (Phase 6) +# ============================================================================ + + +def test_mode_monitor_with_interp_spec(): + """Test that ModeMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="mode_monitor", + ) + + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 3 + assert monitor.interp_spec.method == "linear" + + +def test_mode_solver_monitor_with_interp_spec(): + """Test that ModeSolverMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="mode_solver_monitor", + ) + + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 4 + assert monitor.interp_spec.method == "cubic" + + +def test_mode_monitor_interp_requires_tracking(): + """Test that ModeMonitor with interp_spec requires frequency tracking.""" + freqs = np.linspace(1e14, 2e14, 10) + + # Without tracking + mode_spec_no_track = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq=None), # No tracking + ) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_interp_requires_tracking(): + """Test that ModeSolverMonitor with interp_spec requires frequency tracking.""" + freqs = np.linspace(1e14, 2e14, 10) + + # Without tracking + mode_spec_no_track = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq=None), # No tracking + ) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # num_points >= len(freqs) should trigger warning + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeSolverMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + # num_points >= len(freqs) should trigger warning + interp_spec = td.ModeInterpSpec(num_points=6, method="linear") + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_interp_spec_none(): + """Test that ModeMonitor works without interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=None, + name="test", + ) + + assert monitor.interp_spec is None + + +# ============================================================================ +# WavePort interp_spec Tests +# ============================================================================ + + +def make_wave_port(): + """Make a WavePort.""" + from tidy3d.components.microwave.path_integrals.integrals.current import ( + AxisAlignedCurrentIntegral, + ) + from tidy3d.plugins.smatrix.ports.wave import WavePort + + return WavePort( + center=(0, 0, 0), + size=(1, 1, 0), + direction="+", + name="port1", + current_integral=AxisAlignedCurrentIntegral( + center=(0, 0, 0), + size=(1, 1, 0), + sign="+", + extrapolate_to_endpoints=True, + snap_contour_to_grid=True, + ), + ) + + +def test_wave_port_to_monitors_propagates_default_interp_spec(): + """Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor.""" + + port = make_wave_port() + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == DEFAULT_WAVE_PORT_INTERP_SPEC.num_points + assert monitor.interp_spec.method == DEFAULT_WAVE_PORT_INTERP_SPEC.method + + +def test_wave_port_to_monitors_propagates_custom_interp_spec(): + """Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor.""" + custom_interp = td.ModeInterpSpec(num_points=8, method="cheb") + port = make_wave_port().updated_copy(interp_spec=custom_interp) + + freqs = np.linspace(1e14, 2e14, 50) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 8 + assert monitor.interp_spec.method == "cheb" + + +def test_wave_port_to_monitors_propagates_none_interp_spec(): + """Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor.""" + port = make_wave_port().updated_copy(interp_spec=None) + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is None + + +# ============================================================================ +# Placeholder tests for future phases +# ============================================================================ diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index d82c1735cb..21071408d9 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -342,7 +342,7 @@ from .components.mode.simulation import ModeSimulation # modes -from .components.mode_spec import ModeSortSpec, ModeSpec +from .components.mode_spec import ModeInterpSpec, ModeSortSpec, ModeSpec # monitors from .components.monitor import ( @@ -693,6 +693,7 @@ def set_logging_level(level: str) -> None: "ModeAmpsDataArray", "ModeData", "ModeIndexDataArray", + "ModeInterpSpec", "ModeMonitor", "ModeSimulation", "ModeSimulationData", diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index b8f310561a..d7f5e0f407 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -18,7 +18,7 @@ from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType -from tidy3d.components.mode_spec import ModeSortSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSortSpec from tidy3d.components.monitor import ( AuxFieldTimeMonitor, DiffractionMonitor, @@ -47,6 +47,7 @@ Coordinate, EMField, EpsSpecType, + FreqArray, Numpy, PolarizationBasis, Size, @@ -101,6 +102,8 @@ AXIAL_RATIO_CAP = 100 # At this sampling rate, the computed area of a sphere is within ~1% of the true value. MIN_ANGULAR_SAMPLES_SPHERE = 10 +MODE_INTERP_EXTRAPOLATION_TOLERANCE = 1e-2 +CHEB_NODES_TOLERANCE = 1e-5 class MonitorData(AbstractMonitorData, ABC): @@ -1657,6 +1660,9 @@ def eps_spec_match_mode_spec(cls, val, values): """Raise validation error if frequencies in eps_spec does not match frequency list""" if val: mode_data_freqs = values["monitor"].freqs + interp_spec = values["monitor"].mode_spec.interp_spec + if interp_spec is not None: + return val if len(val) != len(mode_data_freqs): raise ValidationError( "eps_spec must be provided at the same frequencies as mode solver data." @@ -2388,6 +2394,214 @@ def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolve """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() + @staticmethod + def _validate_cheb_nodes(freqs: np.ndarray) -> None: + """Validate that frequencies are approximately at Chebyshev nodes. + + Parameters + ---------- + freqs : np.ndarray + Frequency array to validate. + + Raises + ------ + DataError + If frequencies are not close to Chebyshev nodes. + """ + + mode_interp_spec = ModeInterpSpec(method="cheb", num_points=len(freqs)) + expected_freqs = mode_interp_spec.sampling_points(freqs) + + # Sort both arrays for comparison (Chebyshev nodes are naturally sorted in descending order) + freqs_sorted = np.sort(freqs) + expected_sorted = np.sort(expected_freqs) + + # Check relative error + freq_range = np.abs(expected_freqs[-1] - expected_freqs[0]) + max_error = np.max(np.abs(freqs_sorted - expected_sorted)) / freq_range + + if max_error > CHEB_NODES_TOLERANCE: + raise DataError( + f"For Chebyshev interpolation ('cheb'), source frequencies must be at " + f"Chebyshev nodes of the second kind. Maximum relative error: {max_error:.2e}, " + f"tolerance: {CHEB_NODES_TOLERANCE:.2e}. Use ModeInterpSpec.sampling_points() to generate " + f"appropriate frequencies." + ) + + def interp( + self, + freqs: FreqArray, + method: Literal["linear", "cubic", "cheb"] = "linear", + assume_constant_modes: bool = False, + ) -> ModeSolverData: + """Interpolate mode data to new frequency points. + + Interpolates all stored mode data (effective indices, field components, group indices, + and dispersion) from the current frequency grid to a new set of frequencies. This is + useful for obtaining mode data at many frequencies from computations at fewer frequencies, + when modes vary smoothly with frequency. + + Parameters + ---------- + freqs : FreqArray + New frequency points to interpolate to. Should generally span a similar range + as the original frequencies to avoid extrapolation. + method : Literal["linear", "cubic", "cheb"] + Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source + frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source + frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric + formula (requires 3+ source frequencies at Chebyshev nodes). + For complex-valued data, real and imaginary parts are interpolated independently. + + Returns + ------- + ModeSolverData + New :class:`ModeSolverData` object with data interpolated to the requested frequencies. + + Raises + ------ + DataError + If interpolation parameters are invalid (e.g., too few source frequencies for the + chosen method, or source frequencies not at Chebyshev nodes for 'cheb' method). + + Note + ---- + Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate + near mode crossings or regions of rapid mode variation. Use frequency tracking + (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the + second kind within the frequency range. + + Example + ------- + >>> # Compute modes at 5 frequencies + >>> import numpy as np + >>> freqs_sparse = np.linspace(1e14, 2e14, 5) + >>> # ... create mode_solver and compute modes ... + >>> # mode_data = mode_solver.solve() + >>> # Interpolate to 50 frequencies + >>> freqs_dense = np.linspace(1e14, 2e14, 50) + >>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear') + """ + # Validate input + freqs = np.array(freqs) + source_freqs = np.array(self.monitor.freqs) + + # Validate method-specific requirements + if method == "cubic" and len(source_freqs) < 4: + raise DataError( + f"Cubic interpolation requires at least 4 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + + if method == "cheb": + if len(source_freqs) < 3: + raise DataError( + f"Chebyshev interpolation requires at least 3 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + # Validate that source frequencies are approximately Chebyshev nodes + self._validate_cheb_nodes(source_freqs) + + if method not in ["linear", "cubic", "cheb"]: + raise DataError( + f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'cheb'." + ) + + # Build update dictionary + update_dict = {} + + # Interpolate n_complex (required field) + update_dict["n_complex"] = self._interp_dataarray(self.n_complex, freqs, method) + + # Interpolate field components if present + if not assume_constant_modes: + for field_name, field_data in self.field_components.items(): + if field_data is not None: + update_dict[field_name] = self._interp_dataarray(field_data, freqs, method) + + # Interpolate n_group_raw if present + if self.n_group_raw is not None: + update_dict["n_group_raw"] = self._interp_dataarray(self.n_group_raw, freqs, method) + + # Interpolate dispersion_raw if present + if self.dispersion_raw is not None: + update_dict["dispersion_raw"] = self._interp_dataarray( + self.dispersion_raw, freqs, method + ) + + # Interpolate grid correction data if present + if not assume_constant_modes: + for key, data in self._grid_correction_dict.items(): + if isinstance(data, DataArray) and "f" in data.coords: + update_dict[key] = self._interp_dataarray(data, freqs, method) + + # Handle eps_spec if present - use nearest neighbor interpolation + if self.eps_spec is not None: + update_dict["eps_spec"] = list( + self._interp_dataarray( + FreqDataArray(self.eps_spec, coords={"f": np.array(self.monitor.freqs)}), + freqs, + "nearest", + ).data + ) + + # Update monitor with new frequencies + update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs)) + + return self.copy(update=update_dict) + + @staticmethod + def _interp_dataarray( + data: DataArray, + freqs: FreqArray, + method: str, + ) -> DataArray: + """Interpolate a DataArray along the frequency coordinate. + + Parameters + ---------- + data : DataArray + Data array to interpolate. Must have a frequency coordinate ``"f"``. + freqs : FreqArray + New frequency points. + method : str + Interpolation method (``"linear"``, ``"cubic"``, or ``"cheb"``). + For ``"cheb"``, uses barycentric formula for Chebyshev interpolation. + + Returns + ------- + DataArray + Interpolated data array with the same structure but new frequency points. + """ + # Map 'cheb' to xarray's 'barycentric' method + xr_method = "barycentric" if method == "cheb" else method + + # Use xarray's built-in interpolation + # For complex data, this automatically interpolates real and imaginary parts + interp_kwargs = {"method": xr_method} + + # Check if we're extrapolating significantly and warn + freq_min, freq_max = float(data.coords["f"].min()), float(data.coords["f"].max()) + new_freq_min, new_freq_max = float(freqs.min()), float(freqs.max()) + + if new_freq_min < freq_min * ( + 1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE + ) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE): + log.warning( + f"Interpolating to frequencies outside original range " + f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: " + f"[{new_freq_min:.3e}, {new_freq_max:.3e}] Hz. " + "Results may be inaccurate due to extrapolation." + ) + interp_kwargs["kwargs"] = {"fill_value": "extrapolate"} + + if method == "nearest": + return data.sel(f=freqs, method="nearest") + else: + return data.interp(f=freqs, **interp_kwargs) + @property def time_reversed_copy(self) -> FieldData: """Make a copy of the data with direction-reversed fields. In lossy or gyrotropic systems, diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index 03d16e65d9..1296457d02 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -197,8 +197,8 @@ def smatrix_in_basis( modes1 = port_modes1 if not modes2_provided: modes2 = port_modes2 - f1 = list(modes1.field_components.values())[0].f.values - f2 = list(modes2.field_components.values())[0].f.values + f1 = list(modes1.monitor.freqs) + f2 = list(modes2.monitor.freqs) f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs))) @@ -259,6 +259,11 @@ def smatrix_in_basis( overlaps1 = modes1.outer_dot(port_modes1, conjugate=False) if not modes_in_1: overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1) + interp_spec1 = modes1.monitor.mode_spec.interp_spec + if interp_spec1 is not None: + overlaps1 = modes1._interp_dataarray( + overlaps1, freqs=f, method=interp_spec1.method + ) O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1) O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") @@ -288,6 +293,11 @@ def smatrix_in_basis( overlaps2 = modes2.outer_dot(port_modes2, conjugate=False) if not modes_in_2: overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1) + interp_spec2 = modes2.monitor.mode_spec.interp_spec + if interp_spec2: + overlaps2 = modes2._interp_dataarray( + overlaps2, freqs=f, method=interp_spec2.method + ) O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2) O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index d197c60677..abba122cef 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -11,9 +11,9 @@ from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords1D -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.structure import Structure -from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq +from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size from tidy3d.constants import RADIAN, fp_eps, inf from tidy3d.exceptions import SetupError, ValidationError @@ -26,15 +26,19 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" - track_freq: Union[TrackFreq, None] = pd.Field( - None, - title="Mode Tracking Frequency", - description="Parameter that turns on/off mode tracking based on their similarity. " - "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " - "mode tracking based on the lowest, central, or highest frequency. " - "If ``None`` no mode tracking is performed, which is the default for best performance.", + interp_spec: ModeInterpSpec = pd.Field( + ModeInterpSpec(method="cheb", num_points=4), title="interp spec", description="interp spec" ) + # track_freq: Union[TrackFreq, None] = pd.Field( + # None, + # title="Mode Tracking Frequency", + # description="Parameter that turns on/off mode tracking based on their similarity. " + # "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " + # "mode tracking based on the lowest, central, or highest frequency. " + # "If ``None`` no mode tracking is performed, which is the default for best performance.", + # ) + angle_theta: Literal[0.0] = pd.Field( 0.0, title="Polar Angle", diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index b01b928e4f..d8fc0e9d4c 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -554,11 +554,12 @@ def mode_solver_monitors(self) -> list[ModeSolverMonitor]: mode_planes = self.eme_grid.mode_planes mode_specs = [eme_mode_spec._to_mode_spec() for eme_mode_spec in self.eme_grid.mode_specs] for i in range(self.eme_grid.num_cells): + freqs_curr = freqs monitor = ModeSolverMonitor( center=mode_planes[i].center, size=mode_planes[i].size, name=f"_eme_mode_solver_monitor_{i}", - freqs=freqs, + freqs=freqs_curr, mode_spec=mode_specs[i], colocate=False, ) @@ -579,6 +580,7 @@ def port_modes_monitor(self) -> EMEModeSolverMonitor: num_modes=self.max_port_modes, num_sweep=None, normalize=self.normalize, + freqs=self.freqs, ) def _post_init_validators(self) -> None: @@ -745,6 +747,9 @@ def _validate_sweep_spec(self): def _validate_monitor_setup(self): """Check monitor setup.""" + for i in range(len(self.monitors)): + if self.monitors[i].freqs is None: + self.monitors[i] = self.monitors[i].updated_copy(freqs=self.freqs) for i, monitor in enumerate(self.monitors): if isinstance(monitor, EMEMonitor): _ = self._monitor_eme_cell_indices(monitor=monitor) @@ -1009,6 +1014,18 @@ def _monitor_freqs(self, monitor: Monitor) -> list[pd.NonNegativeFloat]: return list(self.freqs) return list(monitor.freqs) + def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[pd.NonNegativeFloat]: + """Monitor frequencies.""" + freqs = set() + cell_inds = self._monitor_eme_cell_indices(monitor=monitor) + for cell_ind in cell_inds: + interp_spec = self.eme_grid.mode_specs[cell_ind].interp_spec + if interp_spec is None: + freqs |= set(self.freqs) + else: + freqs |= set(interp_spec.sampling_points(self.freqs)) + return list(freqs) + def _monitor_num_freqs(self, monitor: Monitor) -> int: """Total number of freqs included in monitor.""" return len(self._monitor_freqs(monitor=monitor)) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index f703ca812e..e862b85fcd 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -44,7 +44,7 @@ from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor from tidy3d.components.microwave.path_integrals.factory import make_path_integrals -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor from tidy3d.components.scene import Scene from tidy3d.components.simulation import Simulation @@ -200,6 +200,17 @@ class ModeSolver(Tidy3dBaseModel): "like ``mode_area`` require all E-field components.", ) + interp_spec: Optional[ModeInterpSpec] = pydantic.Field( + None, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of ``interp_spec.num_points`` " + "frequencies and interpolated to obtain results at all requested frequencies. " + "This can significantly reduce computational cost for broadband mode solving where " + "modes vary smoothly with frequency. Requires mode tracking to be enabled via " + "``mode_spec.sort_spec.track_freq``.", + ) + @pydantic.validator("simulation", pre=True, always=True) def _convert_to_simulation(cls, val): """Convert to regular Simulation if e.g. JaxSimulation given.""" @@ -257,6 +268,43 @@ def _warn_plane_crosses_symmetry(cls, val, values): ) return val + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["mode_spec"]) + def _validate_interp_requires_tracking(cls, val, values): + """Validate that frequency tracking is enabled when interpolation is requested.""" + if val is None: + return val + + mode_spec = values.get("mode_spec") + track_freq = mode_spec._track_freq + + if track_freq is None: + raise ValidationError( + "Mode frequency interpolation requires mode tracking to be enabled. " + "Set 'mode_spec.sort_spec.track_freq' to 'central', 'lowest', or 'highest'." + ) + return val + + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["freqs"]) + def _warn_interp_num_points(cls, val, values): + """Warn if num_points is greater than or equal to total frequencies.""" + if val is None: + return val + + freqs = values.get("freqs") + num_freqs = len(freqs) + + if val.num_points >= num_freqs: + log.warning( + f"interp_spec.num_points ({val.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " + f"modes will be computed at all {num_freqs} frequencies.", + custom_loc=["interp_spec", "num_points"], + ) + + return val + def _post_init_validators(self) -> None: self._validate_mode_plane_radius( mode_spec=self.mode_spec, @@ -498,6 +546,30 @@ def _get_data_with_group_index(self) -> ModeSolverData: return mode_solver.data_raw._group_index_post_process(self.mode_spec.group_index_step) + def _get_data_with_interp(self) -> ModeSolverData: + """:class:`.ModeSolverData` computed at reduced frequencies and interpolated back. + + Returns + ------- + ModeSolverData + :class:`.ModeSolverData` object with modes computed at a reduced set of frequencies + specified by ``interp_spec.num_points`` and interpolated to the original frequencies. + """ + + # Create reduced frequency set based on interpolation method + freqs_reduced = self.interp_spec.sampling_points(self.freqs) + + # Create a copy of the mode solver with reduced frequencies and no interp_spec + # (to prevent recursion) + mode_solver_reduced = self.copy(update={"freqs": freqs_reduced, "interp_spec": None}) + + # Get data at reduced frequencies + data_reduced = mode_solver_reduced.data_raw + return data_reduced + + # Interpolate back to original frequencies + return data_reduced.interp(freqs=self.freqs, method=self.interp_spec.method) + @cached_property def grid_snapped(self) -> Grid: """The solver grid snapped to the plane normal and to simulation 0-sized dims if any.""" @@ -531,6 +603,9 @@ def data_raw(self) -> ModeSolverDataType: if self.mode_spec.group_index_step > 0: return self._get_data_with_group_index() + if self.interp_spec is not None and self.interp_spec.num_points < len(self.freqs): + return self._get_data_with_interp() + if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: return self.rotated_mode_solver_data diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index c0d37a0bb1..b491b09cc9 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -8,6 +8,7 @@ import numpy as np import pydantic.v1 as pd +from numpy.typing import ArrayLike from tidy3d.constants import GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps from tidy3d.exceptions import SetupError, ValidationError @@ -86,6 +87,117 @@ class ModeSortSpec(Tidy3dBaseModel): ) +class ModeInterpSpec(Tidy3dBaseModel): + """Specification for mode frequency interpolation. + + Allows computing modes at a reduced set of frequencies and interpolating + to obtain results at all requested frequencies. This can significantly + reduce computational cost for broadband simulations where modes vary + smoothly with frequency. + + Note + ---- + Requires frequency tracking to be enabled (``mode_spec.sort_spec.track_freq`` + must not be ``None``) to ensure mode ordering is consistent across frequencies. + + Example + ------- + >>> interp_spec = ModeInterpSpec(num_points=10, method='linear') + + See Also + -------- + + :class:`ModeSolver`: + Mode solver that can use this specification for efficient broadband computation. + + :class:`ModeSolverMonitor`: + Monitor that can use this specification to reduce mode computation cost. + + :class:`ModeMonitor`: + Monitor that can use this specification to reduce mode computation cost. + """ + + num_points: int = pd.Field( + ..., + title="Number of Frequency Points", + description="Number of frequency points at which to actually compute modes. " + "Must be at least 2 and less than the total number of frequencies requested. " + "The mode solver will compute modes at these sampling frequencies " + "and interpolate to obtain results at all requested frequencies. " + "For 'linear' and 'cubic' methods, points are uniformly spaced. " + "For 'cheb' method, Chebyshev nodes are used.", + ge=2, + ) + + method: Literal["linear", "cubic", "cheb"] = pd.Field( + "linear", + title="Interpolation Method", + description="Method for interpolating mode data between computed frequencies. " + "'linear' uses linear interpolation (faster, requires 2+ points). " + "'cubic' uses cubic spline interpolation (smoother, more accurate, requires 4+ points). " + "'cheb' uses Chebyshev polynomial interpolation with barycentric formula " + "(optimal for smooth functions, requires 3+ points, samples at Chebyshev nodes). " + "For complex-valued data, real and imaginary parts are interpolated independently.", + ) + + @pd.validator("method", always=True) + @skip_if_fields_missing(["num_points"]) + def _validate_method_needs_points(cls, val, values): + """Validate that the method has enough points.""" + num_points = values.get("num_points", 0) + if val == "cubic" and num_points < 4: + raise ValidationError( + "Cubic interpolation requires at least 4 frequency points. " + f"Got num_points={num_points}. " + "Use method='linear' or increase num_points." + ) + if val == "cheb" and num_points < 3: + raise ValidationError( + "Chebyshev interpolation requires at least 3 frequency points. " + f"Got num_points={num_points}. " + "Use method='linear' or increase num_points." + ) + return val + + def sampling_points(self, freqs: ArrayLike) -> np.ndarray: + """Compute frequency sampling points based on the interpolation method. + + Parameters + ---------- + freqs : ArrayLike + Target frequency array. The sampling points will span from min(freqs) to max(freqs). + + Returns + ------- + np.ndarray + Array of ``num_points`` frequency sampling points. + For 'linear' and 'cubic' methods: uniformly spaced points. + For 'cheb' method: Chebyshev nodes of the second kind. + + Example + ------- + >>> import numpy as np + >>> freqs = np.linspace(1e14, 2e14, 100) + >>> interp_spec = ModeInterpSpec(num_points=10, method='cheb') + >>> sampling_freqs = interp_spec.sampling_points(freqs) + """ + freqs_array = np.asarray(freqs) + f_min, f_max = float(freqs_array.min()), float(freqs_array.max()) + + if self.method in ("linear", "cubic"): + # Uniformly spaced points + return np.linspace(f_min, f_max, self.num_points) + elif self.method == "cheb": + # Chebyshev nodes of the second kind: x_k = cos(k*pi/(n-1)) for k=0,...,n-1 + # Map from [-1, 1] to [f_min, f_max] + k = np.arange(self.num_points) + nodes_normalized = np.cos(k * np.pi / (self.num_points - 1)) + # Map from [-1, 1] to [f_min, f_max] + return 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * nodes_normalized + else: + raise ValueError(f"Unknown interpolation method: {self.method}") + + class AbstractModeSpec(Tidy3dBaseModel, ABC): """ Abstract base for mode specification data. @@ -362,3 +474,7 @@ class ModeSpec(AbstractModeSpec): * `Waveguide to ring coupling <../../notebooks/WaveguideToRingCoupling.html>`_ """ + + interp_spec: ModeInterpSpec = pd.Field( + ModeInterpSpec(method="cheb", num_points=4), title="interp spec", description="interp spec" + ) diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 64386a5d09..fcb332ba53 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -17,7 +17,7 @@ from .base_sim.monitor import AbstractMonitor from .medium import MediumType from .microwave.base import MicrowaveBaseModel -from .mode_spec import ModeSpec +from .mode_spec import ModeInterpSpec, ModeSpec from .types import ( ArrayFloat1D, AuxField, @@ -358,6 +358,17 @@ class AbstractModeMonitor(PlanarMonitor, FreqMonitor): description="Use conjugated or non-conjugated dot product for mode decomposition.", ) + interp_spec: Optional[ModeInterpSpec] = pydantic.Field( + None, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of frequencies specified by " + "``interp_spec.num_points`` and interpolated to obtain results at all monitor " + "frequencies. This can significantly reduce computational cost for broadband " + "simulations where modes vary smoothly with frequency. Requires mode tracking to be " + "enabled via ``mode_spec.sort_spec.track_freq``.", + ) + def plot( self, x: Optional[float] = None, @@ -424,6 +435,43 @@ def _warn_num_modes(cls, val, values): ) return val + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["mode_spec"]) + def _validate_interp_requires_tracking(cls, val, values): + """Validate that frequency tracking is enabled when interpolation is requested.""" + if val is None: + return val + + mode_spec = values.get("mode_spec") + track_freq = mode_spec._track_freq + + if track_freq is None: + raise ValidationError( + "Mode frequency interpolation requires mode tracking to be enabled. " + "Set 'mode_spec.sort_spec.track_freq' to 'central', 'lowest', or 'highest'.", + ) + return val + + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["freqs"]) + def _warn_interp_num_points(cls, val, values): + """Warn if num_points is greater than or equal to total frequencies.""" + if val is None: + return val + + freqs = values.get("freqs") + num_freqs = len(freqs) + + if val.num_points >= num_freqs: + log.warning( + f"'interp_spec.num_points' ({val.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " + f"modes will be computed at all {num_freqs} frequencies.", + custom_loc=["interp_spec", "num_points"], + ) + + return val + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" # Need to store all fields on the mode surface diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index 890c906446..fef7d81091 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -20,6 +20,7 @@ ImpedanceCalculator, VoltageIntegralType, ) +from tidy3d.components.mode_spec import ModeInterpSpec from tidy3d.components.monitor import ModeMonitor from tidy3d.components.simulation import Simulation from tidy3d.components.source.field import ModeSource, ModeSpec @@ -36,6 +37,7 @@ DEFAULT_WAVE_PORT_NUM_CELLS = 5 MIN_WAVE_PORT_NUM_CELLS = 3 DEFAULT_WAVE_PORT_FRAME = PECFrame() +DEFAULT_WAVE_PORT_INTERP_SPEC = ModeInterpSpec(num_points=21, method="cheb") class WavePort(AbstractTerminalPort, Box): @@ -108,6 +110,17 @@ class WavePort(AbstractTerminalPort, Box): description="Extrudes structures that intersect the wave port plane by a few grid cells when ``True``, improving mode injection accuracy.", ) + interp_spec: Optional[ModeInterpSpec] = pd.Field( + DEFAULT_WAVE_PORT_INTERP_SPEC, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of frequencies specified by " + "``interp_spec.num_points`` and interpolated to obtain results at all monitor " + "frequencies. This can significantly reduce computational cost for broadband " + "simulations where modes vary smoothly with frequency. Requires mode tracking to be " + "enabled via ``mode_spec.sort_spec.track_freq``.", + ) + def _mode_voltage_coefficients(self, mode_data: ModeData) -> FreqModeDataArray: """Calculates scaling coefficients to convert mode amplitudes to the total port voltage. @@ -185,6 +198,7 @@ def to_monitors( mode_spec=self.mode_spec, store_fields_direction=self.direction, conjugated_dot_product=self.conjugated_dot_product, + interp_spec=self.interp_spec, ) return [mode_mon] @@ -197,6 +211,7 @@ def to_mode_solver(self, simulation: Simulation, freqs: FreqArray) -> ModeSolver freqs=freqs, direction=self.direction, colocate=False, + interp_spec=self.interp_spec, ) return mode_solver