From 5a642c9ceae6ac53b8ed21a59e76ff42a3543d65 Mon Sep 17 00:00:00 2001 From: Casey Wojcik Date: Thu, 6 Nov 2025 13:21:21 -0800 Subject: [PATCH] Add interpolation to EME --- CHANGELOG.md | 2 ++ tests/test_components/test_eme.py | 4 +++- tidy3d/components/data/dataset.py | 4 ++++ tidy3d/components/eme/data/sim_data.py | 28 ++++++++++++++++++++++++-- tidy3d/components/eme/grid.py | 19 ++++++++--------- tidy3d/components/eme/simulation.py | 12 +++++++++++ tidy3d/components/mode/mode_solver.py | 4 ++-- tidy3d/components/mode_spec.py | 2 +- tidy3d/components/monitor.py | 11 ++++++++-- 9 files changed, 69 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6cfda5301..392ffc7895 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `tidy3d.plugins.design.DesignSpace.run(..., fn_post=...)` now accepts a `priority` keyword to propagate vGPU queue priority to all automatically batched simulations. - Introduced `BroadbandPulse` for exciting simulations across a wide frequency spectrum. - Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency. +- Added `interp_spec` in `EMEModeSpec` to enable faster multi-frequency EME simulations. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. @@ -62,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion. - Port names in `ModalComponentModeler` and `TerminalComponentModeler` can no longer include the `@` symbol. - Improved speed of convolutions for large inputs. +- Default value of `EMEModeSpec.interp_spec` is `ModeInterpSpec.cheb(num_points=3, reduce_data=True)` for faster multi-frequency EME simulations. ### Fixed - Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL. diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py index c2c5d5a4af..dbe5e69e15 100644 --- a/tests/test_components/test_eme.py +++ b/tests/test_components/test_eme.py @@ -911,7 +911,9 @@ def _get_mode_solver_data(modes_out=False, num_modes=3): size=(td.inf, td.inf, 0), center=(0, 0, offset), freqs=[td.C_0], - mode_spec=td.ModeSpec(num_modes=num_modes), + mode_spec=td.ModeSpec( + num_modes=num_modes, interp_spec=td.ModeInterpSpec.cheb(num_points=3, reduce_data=True) + ), name=name, ) eme_mode_data = _get_eme_mode_solver_data() diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index e3e4307fc2..aed00f8aa6 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -139,6 +139,10 @@ def _interp_dataarray_in_freq( DataArray Interpolated data array with the same structure but new frequency points. """ + # if dataarray is already stored at the correct frequencies, do nothing + if np.array_equal(freqs, data.f): + return data + # Map 'poly' to xarray's 'barycentric' method xr_method = "barycentric" if method == "poly" else method diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index 03d16e65d9..d9dda28b1f 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -197,11 +197,27 @@ def smatrix_in_basis( modes1 = port_modes1 if not modes2_provided: modes2 = port_modes2 - f1 = list(modes1.field_components.values())[0].f.values - f2 = list(modes2.field_components.values())[0].f.values + f1 = list(modes1.monitor.freqs) + f2 = list(modes2.monitor.freqs) f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs))) + interp_spec1 = ( + modes1.monitor.mode_spec.interp_spec if isinstance(modes1, ModeData) else None + ) + interp_spec2 = ( + modes2.monitor.mode_spec.interp_spec if isinstance(modes2, ModeData) else None + ) + + interp_overlaps = False + if interp_spec1 is not None and interp_spec2 is not None and interp_spec1 == interp_spec2: + interp_overlaps = True + else: + if interp_spec1 is not None: + modes1 = modes1.interpolated_copy + if interp_spec2 is not None: + modes2 = modes2.interpolated_copy + modes_in_1 = "mode_index" in list(modes1.field_components.values())[0].coords modes_in_2 = "mode_index" in list(modes2.field_components.values())[0].coords @@ -259,6 +275,10 @@ def smatrix_in_basis( overlaps1 = modes1.outer_dot(port_modes1, conjugate=False) if not modes_in_1: overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1) + if interp_overlaps: + overlaps1 = modes1._interp_dataarray_in_freq( + overlaps1, freqs=f, method=interp_spec1.method + ) O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1) O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") @@ -288,6 +308,10 @@ def smatrix_in_basis( overlaps2 = modes2.outer_dot(port_modes2, conjugate=False) if not modes_in_2: overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1) + if interp_overlaps: + overlaps2 = modes2._interp_dataarray_in_freq( + overlaps2, freqs=f, method=interp_spec2.method + ) O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2) O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index 2ed22e26ca..7b1b548784 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -11,9 +11,9 @@ from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords1D -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.structure import Structure -from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq +from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size from tidy3d.constants import RADIAN, fp_eps, inf from tidy3d.exceptions import SetupError, ValidationError @@ -26,13 +26,14 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" - track_freq: Union[TrackFreq, None] = pd.Field( - None, - title="Mode Tracking Frequency", - description="Parameter that turns on/off mode tracking based on their similarity. " - "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " - "mode tracking based on the lowest, central, or highest frequency. " - "If ``None`` no mode tracking is performed, which is the default for best performance.", + interp_spec: Optional[ModeInterpSpec] = pd.Field( + ModeInterpSpec.cheb(num_points=3, reduce_data=True), + title="Mode frequency interpolation specification", + description="Specification for computing modes at a reduced set of frequencies and " + "interpolating to obtain results at all requested frequencies. This can significantly " + "reduce computational cost for broadband simulations where modes vary smoothly with " + "frequency. Requires frequency tracking to be enabled (``sort_spec.track_freq`` must " + "not be ``None``) to ensure consistent mode ordering across frequencies.", ) angle_theta: Literal[0.0] = pd.Field( diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index 84b3fec261..aaef8aa047 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -1007,6 +1007,18 @@ def _monitor_freqs(self, monitor: Monitor) -> list[pd.NonNegativeFloat]: return list(self.freqs) return list(monitor.freqs) + def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[pd.NonNegativeFloat]: + """Monitor frequencies.""" + freqs = set() + cell_inds = self._monitor_eme_cell_indices(monitor=monitor) + for cell_ind in cell_inds: + interp_spec = self.eme_grid.mode_specs[cell_ind].interp_spec + if interp_spec is None: + freqs |= set(self.freqs) + else: + freqs |= set(interp_spec.sampling_points(self.freqs)) + return list(freqs) + def _monitor_num_freqs(self, monitor: Monitor) -> int: """Total number of freqs included in monitor.""" return len(self._monitor_freqs(monitor=monitor)) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 74218e96b8..055037a416 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -515,8 +515,8 @@ def data_raw(self) -> ModeSolverDataType: A mode solver data type object containing the effective index and mode fields. """ - if self.mode_spec.interp_spec is not None: - _warn_interp_num_points(self.mode_spec.interp_spec, self.freqs) + # if self.mode_spec.interp_spec is not None: + # _warn_interp_num_points(self.mode_spec.interp_spec, self.freqs) if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: return self.rotated_mode_solver_data diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 0ae0256c80..002bf8e6f9 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -433,7 +433,7 @@ def sampling_points(self, freqs: FreqArray) -> FreqArray: >>> interp_spec = ModeInterpSpec.cheb(num_points=10) >>> sampling_freqs = interp_spec.sampling_points(freqs) """ - if self.num_points > len(freqs): + if self.num_points >= len(freqs): return freqs return self.sampling_spec.sampling_points(freqs) diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 1739dbe40f..52e9859284 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -429,15 +429,22 @@ def _warn_num_modes(cls, val, values): ) return val + @property + def _stored_freqs(self) -> list[float]: + """Return actually stored frequencies of the data.""" + return self.mode_spec._sampling_freqs_mode_solver_data(freqs=self.freqs) + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" # Need to store all fields on the mode surface - bytes_single = BYTES_COMPLEX * num_cells * len(self.freqs) * self.mode_spec.num_modes * 6 + bytes_single = ( + BYTES_COMPLEX * num_cells * len(self._stored_freqs) * self.mode_spec.num_modes * 6 + ) if self.mode_spec.precision == "double": return 2 * bytes_single return bytes_single - _warn_interp_num_points = validate_interp_num_points() + # _warn_interp_num_points = validate_interp_num_points() class FieldMonitor(AbstractFieldMonitor, FreqMonitor):