From 09e1e867ff740b6833a77d7db9cf629a2eeb49de Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Tue, 21 Oct 2025 21:16:46 -0700 Subject: [PATCH 1/8] working version --- tests/test_components/test_mode_interp.py | 761 ++++++++++++++++++++++ tidy3d/__init__.py | 3 +- tidy3d/components/data/monitor_data.py | 155 +++++ tidy3d/components/mode/mode_solver.py | 86 ++- tidy3d/components/mode_spec.py | 62 ++ tidy3d/components/monitor.py | 49 +- tidy3d/plugins/smatrix/ports/wave.py | 14 + 7 files changed, 1127 insertions(+), 3 deletions(-) create mode 100644 tests/test_components/test_mode_interp.py diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py new file mode 100644 index 0000000000..cdb1427e27 --- /dev/null +++ b/tests/test_components/test_mode_interp.py @@ -0,0 +1,761 @@ +"""Tests for mode frequency interpolation.""" + +from __future__ import annotations + +import numpy as np +import pydantic.v1 as pydantic +import pytest + +import tidy3d as td +td.config.use_local_subpixel = False + +from ..test_data.test_data_arrays import FS, MODE_SPEC, SIZE_2D +from ..utils import AssertLogLevel +from tidy3d.plugins.mode import ModeSolver + +# Shared test constants +FREQS_DENSE = np.linspace(1e14, 2e14, 20) + + +# ============================================================================ +# ModeInterpSpec Tests +# ============================================================================ + + +def test_interp_spec_valid_linear(): + """Test creating valid ModeInterpSpec with linear interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="linear") + assert spec.num_points == 5 + assert spec.method == "linear" + + +def test_interp_spec_valid_cubic(): + """Test creating valid ModeInterpSpec with cubic interpolation.""" + spec = td.ModeInterpSpec(num_points=10, method="cubic") + assert spec.num_points == 10 + assert spec.method == "cubic" + + +def test_interp_spec_default_method(): + """Test that default method is 'linear'.""" + spec = td.ModeInterpSpec(num_points=5) + assert spec.method == "linear" + + +def test_interp_spec_cubic_needs_4_points(): + """Test that cubic interpolation requires at least 4 points.""" + with pytest.raises(pydantic.ValidationError, match="Cubic interpolation requires at least 4"): + td.ModeInterpSpec(num_points=3, method="cubic") + + +def test_interp_spec_min_2_points(): + """Test that at least 2 points are required.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=1, method="linear") + + +def test_interp_spec_positive_points(): + """Test that num_points must be positive.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=0, method="linear") + + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=-5, method="linear") + + +def test_interp_spec_invalid_method(): + """Test that invalid interpolation method is rejected.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec(num_points=5, method="quadratic") + + +# ============================================================================ +# Monitor with interp_spec Tests +# ============================================================================ + + +def test_mode_monitor_requires_tracking(): + """Test that ModeMonitor with interp_spec requires track_freq.""" + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with pytest.raises(pydantic.ValidationError, match="tracking"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_valid_with_tracking(): + """Test that ModeMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + assert monitor.interp_spec.method == "linear" + + +def test_mode_solver_monitor_requires_tracking(): + """Test that ModeSolverMonitor with interp_spec requires track_freq.""" + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with pytest.raises(pydantic.ValidationError, match="tracking"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_valid_with_tracking(): + """Test that ModeSolverMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + + +def test_interp_num_points_less_than_freqs(): + """Test that num_points must be less than total freqs.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=25, method="linear") + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_interp_num_points_equal_to_freqs(): + """Test that num_points equal to freqs is rejected.""" + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=20, method="linear") + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_interp_spec_none_allowed(): + """Test that interp_spec=None is allowed (no interpolation).""" + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + interp_spec=None, + name="test", + ) + assert monitor.interp_spec is None + + +def test_interp_deprecated_track_freq_still_works(): + """Test that deprecated track_freq on ModeSpec still enables interpolation.""" + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + # Using deprecated track_freq instead of sort_spec.track_freq + with AssertLogLevel("WARNING", contains_str="deprecated"): + mode_spec = td.ModeSpec(num_modes=2, track_freq="central") + + # Should still work since _track_freq property resolves it + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + assert monitor.interp_spec.num_points == 5 + + +# ============================================================================ +# ModeSolver with interp_spec Tests +# ============================================================================ + + +def get_simple_sim(): + """Create a simple simulation for ModeSolver tests.""" + return td.Simulation( + size=(10, 10, 10), + grid_spec=td.GridSpec(wavelength=1.0), + structures=[ + td.Structure( + geometry=td.Box(size=(1, 1, 10)), + medium=td.Medium(permittivity=4.0), + ) + ], + run_time=1e-12, + ) + + +def test_mode_solver_requires_tracking(): + """Test that ModeSolver with interp_spec requires track_freq.""" + sim = get_simple_sim() + mode_spec_no_track = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq=None)) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + with pytest.raises(pydantic.ValidationError, match="tracking"): + ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + ) + + +def test_mode_solver_valid_with_tracking(): + """Test that ModeSolver validates with tracking enabled.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + assert solver.interp_spec.num_points == 5 + assert solver.interp_spec.method == "linear" + + +def test_mode_solver_warns_num_points(): + """Test that ModeSolver warns when num_points >= num_freqs.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=25, method="linear") + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + with AssertLogLevel("WARNING", contains_str="num_points"): + ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + +def test_mode_solver_interp_spec_none(): + """Test that ModeSolver accepts interp_spec=None.""" + sim = get_simple_sim() + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + interp_spec=None, + ) + assert solver.interp_spec is None + + +# ============================================================================ +# ModeSolverData.interp() Tests +# ============================================================================ + + +def get_mode_solver_data(): + """Create a simple ModeSolverData object for testing.""" + from ..test_data.test_data_arrays import ( + make_scalar_mode_field_data_array, + ) + from ..test_data.test_monitor_data import N_COMPLEX + + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test_monitor", + ) + + # Create mode data with the right frequencies + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + Ey=make_scalar_mode_field_data_array("Ey"), + Ez=make_scalar_mode_field_data_array("Ez"), + Hx=make_scalar_mode_field_data_array("Hx"), + Hy=make_scalar_mode_field_data_array("Hy"), + Hz=make_scalar_mode_field_data_array("Hz"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + return mode_data + + +def test_mode_solver_data_interp_linear(): + """Test linear interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = mode_data.n_complex.shape[1] + + # Interpolate to 20 frequencies + freqs_dense = np.linspace( + mode_data.monitor.freqs[0], + mode_data.monitor.freqs[-1], + 20 + ) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert data_interp.n_complex.shape[0] == 20 + + # Check mode dimension is preserved + assert data_interp.n_complex.shape[1] == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 20 + + +def test_mode_solver_data_interp_cubic(): + """Test cubic interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Need at least 4 frequencies for cubic + assert len(mode_data.monitor.freqs) >= 4 + + # Interpolate to 20 frequencies + freqs_dense = np.linspace( + mode_data.monitor.freqs[0], + mode_data.monitor.freqs[-1], + 20 + ) + data_interp = mode_data.interp(freqs=freqs_dense, method="cubic") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert data_interp.n_complex.shape[0] == 20 + + +def test_mode_solver_data_interp_preserves_modes(): + """Test that interpolation preserves mode count.""" + mode_data = get_mode_solver_data() + original_num_modes = mode_data.n_complex.shape[1] + + # Interpolate to different number of frequencies + freqs_dense = np.linspace( + mode_data.monitor.freqs[0], + mode_data.monitor.freqs[-1], + 20 + ) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Mode count should be unchanged + assert data_interp.n_complex.shape[1] == original_num_modes + + +def test_mode_solver_data_interp_too_few_target_freqs(): + """Test that interpolation fails with too few target frequencies.""" + mode_data = get_mode_solver_data() + + with pytest.raises(td.exceptions.DataError, match="fewer than 2"): + mode_data.interp(freqs=[1e14], method="linear") + + +def test_mode_solver_data_interp_cubic_needs_4_source(): + """Test that cubic interpolation fails with too few source frequencies.""" + # Create data with only 3 frequencies + freqs = np.linspace(1e14, 2e14, 3) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 4 source"): + mode_data.interp(freqs=freqs_dense, method="cubic") + + +def test_mode_solver_data_interp_invalid_method(): + """Test that invalid interpolation method raises error.""" + mode_data = get_mode_solver_data() + freqs_dense = np.linspace(1e14, 2e14, 10) + + with pytest.raises(td.exceptions.DataError, match="Invalid interpolation method"): + mode_data.interp(freqs=freqs_dense, method="quadratic") + + +def test_mode_solver_data_interp_extrapolation_warning(): + """Test that extrapolation triggers a warning.""" + mode_data = get_mode_solver_data() + + # Interpolate beyond original range + freqs_extrap = np.linspace(0.5e14, 2.5e14, 10) + + with AssertLogLevel("WARNING", contains_str="outside original range"): + mode_data.interp(freqs=freqs_extrap, method="linear") + + +# ============================================================================ +# ModeSolver Integration Tests (Phase 5) +# ============================================================================ + + +def test_mode_solver_with_interp(): + """Test that ModeSolver uses interpolation when interp_spec is provided.""" + sim = get_simple_sim() + + # Create solver with 10 frequencies + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # Create solver with interpolation: compute at 3 frequencies, interpolate to 10 + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + solver_with_interp = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + # The solver should have the original 10 frequencies + assert len(solver_with_interp.freqs) == 10 + + # The returned data should have 10 frequencies + data = solver_with_interp.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +def test_mode_solver_creates_reduced_freqs(): + """Test that solver creates correct reduced frequency set internally.""" + sim = get_simple_sim() + + # Create solver with 20 frequencies + freqs = np.linspace(1e14, 2e14, 20) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # Compute at 5 frequencies, interpolate to 20 + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + # The returned data should have all 20 frequencies + data = solver.data_raw + assert len(data.monitor.freqs) == 20 + + # The effective indices should be properly interpolated + assert data.n_complex.shape[0] == 20 + + +def test_mode_solver_interp_preserves_num_modes(): + """Test that interpolation preserves the number of modes.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 15) + mode_spec = td.ModeSpec( + num_modes=3, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + interp_spec = td.ModeInterpSpec(num_points=4, method="linear") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + + # Should have 3 modes at each of 15 frequencies + assert data.n_complex.shape == (15, 3) + + +def test_mode_solver_interp_cubic(): + """Test that ModeSolver works with cubic interpolation.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # Cubic interpolation requires at least 4 points + interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +def test_mode_solver_without_interp_returns_full_data(): + """Test that solver without interp_spec computes at all frequencies.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=None, # No interpolation + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape[0] == 10 + + +# ============================================================================ +# Monitor Integration Tests (Phase 6) +# ============================================================================ + + +def test_mode_monitor_with_interp_spec(): + """Test that ModeMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="mode_monitor", + ) + + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 3 + assert monitor.interp_spec.method == "linear" + + +def test_mode_solver_monitor_with_interp_spec(): + """Test that ModeSolverMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="mode_solver_monitor", + ) + + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 4 + assert monitor.interp_spec.method == "cubic" + + +def test_mode_monitor_interp_requires_tracking(): + """Test that ModeMonitor with interp_spec requires frequency tracking.""" + freqs = np.linspace(1e14, 2e14, 10) + + # Without tracking + mode_spec_no_track = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq=None) # No tracking + ) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_interp_requires_tracking(): + """Test that ModeSolverMonitor with interp_spec requires frequency tracking.""" + freqs = np.linspace(1e14, 2e14, 10) + + # Without tracking + mode_spec_no_track = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq=None) # No tracking + ) + interp_spec = td.ModeInterpSpec(num_points=3, method="linear") + + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec_no_track, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # num_points >= len(freqs) should trigger warning + interp_spec = td.ModeInterpSpec(num_points=5, method="linear") + + with AssertLogLevel("WARNING", contains_str="greater than or equal"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_solver_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeSolverMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # num_points >= len(freqs) should trigger warning + interp_spec = td.ModeInterpSpec(num_points=6, method="linear") + + with AssertLogLevel("WARNING", contains_str="greater than or equal"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + name="test", + ) + + +def test_mode_monitor_interp_spec_none(): + """Test that ModeMonitor works without interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + interp_spec=None, + name="test", + ) + + assert monitor.interp_spec is None + + +# ============================================================================ +# Placeholder tests for future phases +# ============================================================================ + diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index d82c1735cb..8135f224ad 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -342,7 +342,7 @@ from .components.mode.simulation import ModeSimulation # modes -from .components.mode_spec import ModeSortSpec, ModeSpec +from .components.mode_spec import ModeInterpSpec, ModeSortSpec, ModeSpec # monitors from .components.monitor import ( @@ -702,6 +702,7 @@ def set_logging_level(level: str) -> None: "ModeSortSpec", "ModeSource", "ModeSpec", + "ModeInterpSpec", "ModulationSpec", "Monitor", "MultiPhysicsMedium", diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index b8f310561a..d5c24e3c1f 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -47,6 +47,7 @@ Coordinate, EMField, EpsSpecType, + FreqArray, Numpy, PolarizationBasis, Size, @@ -101,6 +102,7 @@ AXIAL_RATIO_CAP = 100 # At this sampling rate, the computed area of a sphere is within ~1% of the true value. MIN_ANGULAR_SAMPLES_SPHERE = 10 +MODE_INTERP_EXTRAPOLATION_TOLERANCE = 1e-2 class MonitorData(AbstractMonitorData, ABC): @@ -2388,6 +2390,159 @@ def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolve """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() + def interp( + self, + freqs: FreqArray, + method: Literal["linear", "cubic"] = "linear", + ) -> ModeSolverData: + """Interpolate mode data to new frequency points. + + Interpolates all stored mode data (effective indices, field components, group indices, + and dispersion) from the current frequency grid to a new set of frequencies. This is + useful for obtaining mode data at many frequencies from computations at fewer frequencies, + when modes vary smoothly with frequency. + + Parameters + ---------- + freqs : FreqArray + New frequency points to interpolate to. Should generally span a similar range + as the original frequencies to avoid extrapolation. + method : Literal["linear", "cubic"] + Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source + frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source + frequencies). For complex-valued data, real and imaginary parts are interpolated + independently. + + Returns + ------- + ModeSolverData + New :class:`ModeSolverData` object with data interpolated to the requested frequencies. + + Raises + ------ + DataError + If interpolation parameters are invalid (e.g., too few source frequencies for the + chosen method). + + Note + ---- + Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate + near mode crossings or regions of rapid mode variation. Use frequency tracking + (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + Example + ------- + >>> # Compute modes at 5 frequencies + >>> import numpy as np + >>> freqs_sparse = np.linspace(1e14, 2e14, 5) + >>> # ... create mode_solver and compute modes ... + >>> # mode_data = mode_solver.solve() + >>> # Interpolate to 50 frequencies + >>> freqs_dense = np.linspace(1e14, 2e14, 50) + >>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear') + """ + # Validate input + freqs = np.array(freqs) + if len(freqs) < 2: + raise DataError("Cannot interpolate to fewer than 2 frequency points.") + + source_freqs = self.monitor.freqs + if method == "cubic" and len(source_freqs) < 4: + raise DataError( + f"Cubic interpolation requires at least 4 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + + if method not in ["linear", "cubic"]: + raise DataError(f"Invalid interpolation method '{method}'. Use 'linear' or 'cubic'.") + + # Build update dictionary + update_dict = {} + + # Interpolate n_complex (required field) + update_dict["n_complex"] = self._interp_dataarray( + self.n_complex, freqs, method + ) + + # Interpolate field components if present + for field_name, field_data in self.field_components.items(): + if field_data is not None: + update_dict[field_name] = self._interp_dataarray( + field_data, freqs, method + ) + + # Interpolate n_group_raw if present + if self.n_group_raw is not None: + update_dict["n_group_raw"] = self._interp_dataarray( + self.n_group_raw, freqs, method + ) + + # Interpolate dispersion_raw if present + if self.dispersion_raw is not None: + update_dict["dispersion_raw"] = self._interp_dataarray( + self.dispersion_raw, freqs, method + ) + + # Interpolate grid correction data if present + for key, data in self._grid_correction_dict.items(): + if isinstance(data, DataArray) and "f" in data.coords: + update_dict[key] = self._interp_dataarray(data, freqs, method) + + # Handle eps_spec if present - use nearest neighbor interpolation + if self.eps_spec is not None: + update_dict["eps_spec"] = list(self._interp_dataarray( + FreqDataArray(self.eps_spec, coords=dict(f=self.monitor.freqs)), freqs, "nearest" + ).data) + + # Update monitor with new frequencies + update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs)) + + return self.copy(update=update_dict) + + @staticmethod + def _interp_dataarray( + data: DataArray, + freqs: FreqArray, + method: str, + ) -> DataArray: + """Interpolate a DataArray along the frequency coordinate. + + Parameters + ---------- + data : DataArray + Data array to interpolate. Must have a frequency coordinate ``"f"``. + freqs : FreqArray + New frequency points. + method : str + Interpolation method (``"linear"`` or ``"cubic"``). + + Returns + ------- + DataArray + Interpolated data array with the same structure but new frequency points. + """ + # Use xarray's built-in interpolation + # For complex data, this automatically interpolates real and imaginary parts + interp_kwargs = {"method": method} + + # Check if we're extrapolating significantly and warn + freq_min, freq_max = float(data.coords["f"].min()), float(data.coords["f"].max()) + new_freq_min, new_freq_max = float(freqs.min()), float(freqs.max()) + + if new_freq_min < freq_min * (1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE): + log.warning( + f"Interpolating to frequencies outside original range " + f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: " + f"[{new_freq_min:.3e}, {new_freq_max:.3e}] Hz. " + "Results may be inaccurate due to extrapolation." + ) + interp_kwargs["kwargs"] = {"fill_value": "extrapolate"} + + if method == "nearest": + return data.sel(f=freqs, method="nearest") + else: + return data.interp(f=freqs, **interp_kwargs) + @property def time_reversed_copy(self) -> FieldData: """Make a copy of the data with direction-reversed fields. In lossy or gyrotropic systems, diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index f703ca812e..099c911e8d 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -44,7 +44,7 @@ from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor from tidy3d.components.microwave.path_integrals.factory import make_path_integrals -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor from tidy3d.components.scene import Scene from tidy3d.components.simulation import Simulation @@ -200,6 +200,17 @@ class ModeSolver(Tidy3dBaseModel): "like ``mode_area`` require all E-field components.", ) + interp_spec: Optional[ModeInterpSpec] = pydantic.Field( + None, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of ``interp_spec.num_points`` " + "frequencies and interpolated to obtain results at all requested frequencies. " + "This can significantly reduce computational cost for broadband mode solving where " + "modes vary smoothly with frequency. Requires mode tracking to be enabled via " + "``mode_spec.sort_spec.track_freq``.", + ) + @pydantic.validator("simulation", pre=True, always=True) def _convert_to_simulation(cls, val): """Convert to regular Simulation if e.g. JaxSimulation given.""" @@ -257,6 +268,42 @@ def _warn_plane_crosses_symmetry(cls, val, values): ) return val + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["mode_spec"]) + def _validate_interp_requires_tracking(cls, val, values): + """Validate that frequency tracking is enabled when interpolation is requested.""" + if val is None: + return val + + mode_spec = values.get("mode_spec") + track_freq = mode_spec._track_freq + + if track_freq is None: + raise ValidationError( + "Mode frequency interpolation requires mode tracking to be enabled. " + "Set 'mode_spec.sort_spec.track_freq' to 'central', 'lowest', or 'highest'." + ) + return val + + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["freqs"]) + def _warn_interp_num_points(cls, val, values): + """Warn if num_points is greater than or equal to total frequencies.""" + if val is None: + return val + + freqs = values.get("freqs") + num_freqs = len(freqs) + + if val.num_points >= num_freqs: + log.warning( + f"interp_spec.num_points ({val.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). No computational savings are achieved.", + custom_loc=["interp_spec", "num_points"], + ) + + return val + def _post_init_validators(self) -> None: self._validate_mode_plane_radius( mode_spec=self.mode_spec, @@ -498,6 +545,35 @@ def _get_data_with_group_index(self) -> ModeSolverData: return mode_solver.data_raw._group_index_post_process(self.mode_spec.group_index_step) + def _get_data_with_interp(self) -> ModeSolverData: + """:class:`.ModeSolverData` computed at reduced frequencies and interpolated back. + + Returns + ------- + ModeSolverData + :class:`.ModeSolverData` object with modes computed at a reduced set of frequencies + specified by ``interp_spec.num_points`` and interpolated to the original frequencies. + """ + + # Create reduced frequency set uniformly spaced over the original range + freqs_reduced = np.linspace( + self.freqs[0], + self.freqs[-1], + self.interp_spec.num_points + ) + + # Create a copy of the mode solver with reduced frequencies and no interp_spec + # (to prevent recursion) + mode_solver_reduced = self.copy( + update={"freqs": freqs_reduced, "interp_spec": None} + ) + + # Get data at reduced frequencies + data_reduced = mode_solver_reduced.data_raw + + # Interpolate back to original frequencies + return data_reduced.interp(freqs=self.freqs, method=self.interp_spec.method) + @cached_property def grid_snapped(self) -> Grid: """The solver grid snapped to the plane normal and to simulation 0-sized dims if any.""" @@ -531,23 +607,31 @@ def data_raw(self) -> ModeSolverDataType: if self.mode_spec.group_index_step > 0: return self._get_data_with_group_index() + if self.interp_spec is not None: + return self._get_data_with_interp() + if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: return self.rotated_mode_solver_data # Compute data on the Yee grid mode_solver_data = self._data_on_yee_grid() + # print("mode_solver_data: ", mode_solver_data) if self._has_microwave_mode_spec: mode_solver_data = MicrowaveModeSolverData(**mode_solver_data.dict(exclude={"type"})) + # print("mode_solver_data: ", mode_solver_data) # Colocate to grid boundaries if requested if self.colocate: mode_solver_data = self._colocate_data(mode_solver_data=mode_solver_data) + # print("mode_solver_data: ", mode_solver_data) # normalize modes self._normalize_modes(mode_solver_data=mode_solver_data) # filter polarization if requested + # print("mode_solver_data: ", mode_solver_data) mode_solver_data = self._filter_polarization(mode_solver_data=mode_solver_data) + # print("mode_solver_data: ", mode_solver_data) # filter and sort modes if requested by sort_spec mode_solver_data = mode_solver_data.sort_modes( diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index c0d37a0bb1..336393419e 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -86,6 +86,68 @@ class ModeSortSpec(Tidy3dBaseModel): ) +class ModeInterpSpec(Tidy3dBaseModel): + """Specification for mode frequency interpolation. + + Allows computing modes at a reduced set of frequencies and interpolating + to obtain results at all requested frequencies. This can significantly + reduce computational cost for broadband simulations where modes vary + smoothly with frequency. + + Note + ---- + Requires frequency tracking to be enabled (``mode_spec.sort_spec.track_freq`` + must not be ``None``) to ensure mode ordering is consistent across frequencies. + + Example + ------- + >>> interp_spec = ModeInterpSpec(num_points=10, method='linear') + + See Also + -------- + + :class:`ModeSolver`: + Mode solver that can use this specification for efficient broadband computation. + + :class:`ModeSolverMonitor`: + Monitor that can use this specification to reduce mode computation cost. + + :class:`ModeMonitor`: + Monitor that can use this specification to reduce mode computation cost. + """ + + num_points: int = pd.Field( + ..., + title="Number of Frequency Points", + description="Number of frequency points at which to actually compute modes. " + "Must be at least 2 and less than the total number of frequencies requested. " + "The mode solver will compute modes at this many uniformly-spaced frequencies " + "and interpolate to obtain results at all requested frequencies.", + ge=2, + ) + + method: Literal["linear", "cubic"] = pd.Field( + "linear", + title="Interpolation Method", + description="Method for interpolating mode data between computed frequencies. " + "'linear' uses linear interpolation (faster, requires 2+ points). " + "'cubic' uses cubic spline interpolation (smoother, more accurate, requires 4+ points). " + "For complex-valued data, real and imaginary parts are interpolated independently.", + ) + + @pd.validator("method", always=True) + @skip_if_fields_missing(["num_points"]) + def _validate_cubic_needs_points(cls, val, values): + """Cubic interpolation requires at least 4 points.""" + if val == "cubic" and values.get("num_points", 0) < 4: + raise ValidationError( + "Cubic interpolation requires at least 4 frequency points. " + f"Got num_points={values.get('num_points')}. " + "Use method='linear' or increase num_points." + ) + return val + + class AbstractModeSpec(Tidy3dBaseModel, ABC): """ Abstract base for mode specification data. diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 64386a5d09..b8d156f620 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -17,7 +17,7 @@ from .base_sim.monitor import AbstractMonitor from .medium import MediumType from .microwave.base import MicrowaveBaseModel -from .mode_spec import ModeSpec +from .mode_spec import ModeInterpSpec, ModeSpec from .types import ( ArrayFloat1D, AuxField, @@ -358,6 +358,17 @@ class AbstractModeMonitor(PlanarMonitor, FreqMonitor): description="Use conjugated or non-conjugated dot product for mode decomposition.", ) + interp_spec: Optional[ModeInterpSpec] = pydantic.Field( + None, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of frequencies specified by " + "``interp_spec.num_points`` and interpolated to obtain results at all monitor " + "frequencies. This can significantly reduce computational cost for broadband " + "simulations where modes vary smoothly with frequency. Requires mode tracking to be " + "enabled via ``mode_spec.sort_spec.track_freq``.", + ) + def plot( self, x: Optional[float] = None, @@ -424,6 +435,42 @@ def _warn_num_modes(cls, val, values): ) return val + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["mode_spec"]) + def _validate_interp_requires_tracking(cls, val, values): + """Validate that frequency tracking is enabled when interpolation is requested.""" + if val is None: + return val + + mode_spec = values.get("mode_spec") + track_freq = mode_spec._track_freq + + if track_freq is None: + raise ValidationError( + "Mode frequency interpolation requires mode tracking to be enabled. " + "Set 'mode_spec.sort_spec.track_freq' to 'central', 'lowest', or 'highest'.", + ) + return val + + @pydantic.validator("interp_spec", always=True) + @skip_if_fields_missing(["freqs"]) + def _warn_interp_num_points(cls, val, values): + """Warn if num_points is less than total frequencies.""" + if val is None: + return val + + freqs = values.get("freqs") + num_freqs = len(freqs) + + if val.num_points >= num_freqs: + log.warning( + f"interp_spec.num_points ({val.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). No savings are achieved.", + custom_loc=["interp_spec", "num_points"], + ) + + return val + def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" # Need to store all fields on the mode surface diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index 890c906446..ffaf187cb0 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -15,6 +15,7 @@ from tidy3d.components.geometry.base import Box from tidy3d.components.geometry.bound_ops import bounds_contains from tidy3d.components.grid.grid import Grid +from tidy3d.components.mode_spec import ModeInterpSpec from tidy3d.components.microwave.impedance_calculator import ( CurrentIntegralType, ImpedanceCalculator, @@ -36,6 +37,7 @@ DEFAULT_WAVE_PORT_NUM_CELLS = 5 MIN_WAVE_PORT_NUM_CELLS = 3 DEFAULT_WAVE_PORT_FRAME = PECFrame() +DEFAULT_WAVE_PORT_INTERP_SPEC = ModeInterpSpec(num_points=15, method="cubic") class WavePort(AbstractTerminalPort, Box): @@ -108,6 +110,17 @@ class WavePort(AbstractTerminalPort, Box): description="Extrudes structures that intersect the wave port plane by a few grid cells when ``True``, improving mode injection accuracy.", ) + interp_spec: Optional[ModeInterpSpec] = pd.Field( + DEFAULT_WAVE_PORT_INTERP_SPEC, + title="Mode Interpolation Specification", + description="Parameters for frequency interpolation of mode solver results. " + "If provided, modes are computed at a reduced set of frequencies specified by " + "``interp_spec.num_points`` and interpolated to obtain results at all monitor " + "frequencies. This can significantly reduce computational cost for broadband " + "simulations where modes vary smoothly with frequency. Requires mode tracking to be " + "enabled via ``mode_spec.sort_spec.track_freq``.", + ) + def _mode_voltage_coefficients(self, mode_data: ModeData) -> FreqModeDataArray: """Calculates scaling coefficients to convert mode amplitudes to the total port voltage. @@ -185,6 +198,7 @@ def to_monitors( mode_spec=self.mode_spec, store_fields_direction=self.direction, conjugated_dot_product=self.conjugated_dot_product, + interp_spec=self.interp_spec, ) return [mode_mon] From 4369ece08aa2ab189aa644c9518db3f40630ba44 Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Tue, 21 Oct 2025 21:50:41 -0700 Subject: [PATCH 2/8] chebyshev interpolation option --- tests/test_components/test_mode_interp.py | 243 ++++++++++++++++++++++ tidy3d/components/data/monitor_data.py | 76 ++++++- tidy3d/components/mode/mode_solver.py | 8 +- tidy3d/components/mode_spec.py | 64 +++++- 4 files changed, 367 insertions(+), 24 deletions(-) diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py index cdb1427e27..819f78ec19 100644 --- a/tests/test_components/test_mode_interp.py +++ b/tests/test_components/test_mode_interp.py @@ -12,6 +12,8 @@ from ..test_data.test_data_arrays import FS, MODE_SPEC, SIZE_2D from ..utils import AssertLogLevel from tidy3d.plugins.mode import ModeSolver +from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC + # Shared test constants FREQS_DENSE = np.linspace(1e14, 2e14, 20) @@ -48,6 +50,52 @@ def test_interp_spec_cubic_needs_4_points(): td.ModeInterpSpec(num_points=3, method="cubic") +def test_interp_spec_valid_cheb(): + """Test creating valid ModeInterpSpec with Chebyshev interpolation.""" + spec = td.ModeInterpSpec(num_points=10, method="cheb") + assert spec.num_points == 10 + assert spec.method == "cheb" + + +def test_interp_spec_cheb_needs_3_points(): + """Test that Chebyshev interpolation requires at least 3 points.""" + with pytest.raises(pydantic.ValidationError, match="Chebyshev interpolation requires at least 3"): + td.ModeInterpSpec(num_points=2, method="cheb") + + +def test_interp_spec_sampling_points_linear(): + """Test sampling_points for linear interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="linear") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + assert np.isclose(sampling[0], 1e14) + assert np.isclose(sampling[-1], 2e14) + # Check uniform spacing + diffs = np.diff(sampling) + assert np.allclose(diffs, diffs[0]) + + +def test_interp_spec_sampling_points_cheb(): + """Test sampling_points for Chebyshev interpolation.""" + spec = td.ModeInterpSpec(num_points=5, method="cheb") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + # Chebyshev nodes should include endpoints + assert np.isclose(sampling.min(), 1e14) + assert np.isclose(sampling.max(), 2e14) + + # Verify they are Chebyshev nodes + f_min, f_max = 1e14, 2e14 + k = np.arange(5) + expected_normalized = np.cos(k * np.pi / 4) + expected = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * expected_normalized + assert np.allclose(np.sort(sampling), np.sort(expected)) + + def test_interp_spec_min_2_points(): """Test that at least 2 points are required.""" with pytest.raises(pydantic.ValidationError): @@ -380,6 +428,107 @@ def test_mode_solver_data_interp_cubic(): assert data_interp.n_complex.shape[0] == 20 +def test_mode_solver_data_interp_cheb(): + """Test Chebyshev interpolation on ModeSolverData.""" + # Create data with frequencies at Chebyshev nodes + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") + freqs_all = np.linspace(1e14, 2e14, 50) + freqs_cheb = interp_spec.sampling_points(freqs_all) + + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs_cheb, + mode_spec=mode_spec, + name="test_cheb", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + Ey=make_scalar_mode_field_data_array("Ey"), + Ez=make_scalar_mode_field_data_array("Ez"), + Hx=make_scalar_mode_field_data_array("Hx"), + Hy=make_scalar_mode_field_data_array("Hy"), + Hz=make_scalar_mode_field_data_array("Hz"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + # Interpolate to 50 frequencies + data_interp = mode_data.interp(freqs=freqs_all, method="cheb") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 50 + assert data_interp.n_complex.shape[0] == 50 + + +def test_mode_solver_data_interp_cheb_needs_3_source(): + """Test that Chebyshev interpolation fails with too few source frequencies.""" + # Create data with only 2 frequencies + freqs = np.linspace(1e14, 2e14, 2) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 3 source"): + mode_data.interp(freqs=freqs_dense, method="cheb") + + +def test_mode_solver_data_interp_cheb_validates_nodes(): + """Test that Chebyshev interpolation validates source frequencies are Chebyshev nodes.""" + # Create data with uniform (not Chebyshev) nodes + freqs_uniform = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs_uniform, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="must be at Chebyshev nodes"): + mode_data.interp(freqs=freqs_dense, method="cheb") + + def test_mode_solver_data_interp_preserves_modes(): """Test that interpolation preserves mode count.""" mode_data = get_mode_solver_data() @@ -573,6 +722,32 @@ def test_mode_solver_interp_cubic(): assert data.n_complex.shape[0] == 10 +def test_mode_solver_interp_cheb(): + """Test that ModeSolver works with Chebyshev interpolation.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 20) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central") + ) + + # Chebyshev interpolation requires at least 3 points + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + interp_spec=interp_spec, + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 20 + assert data.n_complex.shape[0] == 20 + + def test_mode_solver_without_interp_returns_full_data(): """Test that solver without interp_spec computes at all frequencies.""" sim = get_simple_sim() @@ -755,6 +930,74 @@ def test_mode_monitor_interp_spec_none(): assert monitor.interp_spec is None +# ============================================================================ +# WavePort interp_spec Tests +# ============================================================================ + +def make_wave_port(): + """Make a WavePort.""" + from tidy3d.plugins.smatrix.ports.wave import WavePort + from tidy3d.components.microwave.path_integrals.integrals.current import AxisAlignedCurrentIntegral + return WavePort( + center=(0, 0, 0), + size=(1, 1, 0), + direction="+", + name="port1", + current_integral=AxisAlignedCurrentIntegral( + center=(0, 0, 0), + size=(1, 1, 0), + sign="+", + extrapolate_to_endpoints=True, + snap_contour_to_grid=True, + ) + ) + + +def test_wave_port_to_monitors_propagates_default_interp_spec(): + """Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor.""" + + port = make_wave_port() + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == DEFAULT_WAVE_PORT_INTERP_SPEC.num_points + assert monitor.interp_spec.method == DEFAULT_WAVE_PORT_INTERP_SPEC.method + + +def test_wave_port_to_monitors_propagates_custom_interp_spec(): + """Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor.""" + custom_interp = td.ModeInterpSpec(num_points=8, method="cheb") + port = make_wave_port().updated_copy(interp_spec=custom_interp) + + freqs = np.linspace(1e14, 2e14, 50) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is not None + assert monitor.interp_spec.num_points == 8 + assert monitor.interp_spec.method == "cheb" + + +def test_wave_port_to_monitors_propagates_none_interp_spec(): + """Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor.""" + port = make_wave_port().updated_copy(interp_spec=None) + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.interp_spec is None + + # ============================================================================ # Placeholder tests for future phases # ============================================================================ diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index d5c24e3c1f..c67a2d1452 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -18,7 +18,7 @@ from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType -from tidy3d.components.mode_spec import ModeSortSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSortSpec from tidy3d.components.monitor import ( AuxFieldTimeMonitor, DiffractionMonitor, @@ -103,6 +103,7 @@ # At this sampling rate, the computed area of a sphere is within ~1% of the true value. MIN_ANGULAR_SAMPLES_SPHERE = 10 MODE_INTERP_EXTRAPOLATION_TOLERANCE = 1e-2 +CHEB_NODES_TOLERANCE = 1e-5 class MonitorData(AbstractMonitorData, ABC): @@ -2390,10 +2391,44 @@ def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolve """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() + @staticmethod + def _validate_cheb_nodes(freqs: np.ndarray) -> None: + """Validate that frequencies are approximately at Chebyshev nodes. + + Parameters + ---------- + freqs : np.ndarray + Frequency array to validate. + + Raises + ------ + DataError + If frequencies are not close to Chebyshev nodes. + """ + + mode_interp_spec = ModeInterpSpec(method="cheb", num_points=len(freqs)) + expected_freqs = mode_interp_spec.sampling_points(freqs) + + # Sort both arrays for comparison (Chebyshev nodes are naturally sorted in descending order) + freqs_sorted = np.sort(freqs) + expected_sorted = np.sort(expected_freqs) + + # Check relative error + freq_range = np.abs(expected_freqs[-1] - expected_freqs[0]) + max_error = np.max(np.abs(freqs_sorted - expected_sorted)) / freq_range + + if max_error > CHEB_NODES_TOLERANCE: + raise DataError( + f"For Chebyshev interpolation ('cheb'), source frequencies must be at " + f"Chebyshev nodes of the second kind. Maximum relative error: {max_error:.2e}, " + f"tolerance: {CHEB_NODES_TOLERANCE:.2e}. Use ModeInterpSpec.sampling_points() to generate " + f"appropriate frequencies." + ) + def interp( self, freqs: FreqArray, - method: Literal["linear", "cubic"] = "linear", + method: Literal["linear", "cubic", "cheb"] = "linear", ) -> ModeSolverData: """Interpolate mode data to new frequency points. @@ -2407,11 +2442,12 @@ def interp( freqs : FreqArray New frequency points to interpolate to. Should generally span a similar range as the original frequencies to avoid extrapolation. - method : Literal["linear", "cubic"] + method : Literal["linear", "cubic", "cheb"] Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source - frequencies). For complex-valued data, real and imaginary parts are interpolated - independently. + frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric + formula (requires 3+ source frequencies at Chebyshev nodes). + For complex-valued data, real and imaginary parts are interpolated independently. Returns ------- @@ -2422,13 +2458,16 @@ def interp( ------ DataError If interpolation parameters are invalid (e.g., too few source frequencies for the - chosen method). + chosen method, or source frequencies not at Chebyshev nodes for 'cheb' method). Note ---- Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate near mode crossings or regions of rapid mode variation. Use frequency tracking (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the + second kind within the frequency range. Example ------- @@ -2446,15 +2485,26 @@ def interp( if len(freqs) < 2: raise DataError("Cannot interpolate to fewer than 2 frequency points.") - source_freqs = self.monitor.freqs + source_freqs = np.array(self.monitor.freqs) + + # Validate method-specific requirements if method == "cubic" and len(source_freqs) < 4: raise DataError( f"Cubic interpolation requires at least 4 source frequency points. " f"Got {len(source_freqs)}. Use method='linear' instead." ) + + if method == "cheb": + if len(source_freqs) < 3: + raise DataError( + f"Chebyshev interpolation requires at least 3 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + # Validate that source frequencies are approximately Chebyshev nodes + self._validate_cheb_nodes(source_freqs) - if method not in ["linear", "cubic"]: - raise DataError(f"Invalid interpolation method '{method}'. Use 'linear' or 'cubic'.") + if method not in ["linear", "cubic", "cheb"]: + raise DataError(f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'cheb'.") # Build update dictionary update_dict = {} @@ -2514,16 +2564,20 @@ def _interp_dataarray( freqs : FreqArray New frequency points. method : str - Interpolation method (``"linear"`` or ``"cubic"``). + Interpolation method (``"linear"``, ``"cubic"``, or ``"cheb"``). + For ``"cheb"``, uses barycentric formula for Chebyshev interpolation. Returns ------- DataArray Interpolated data array with the same structure but new frequency points. """ + # Map 'cheb' to xarray's 'barycentric' method + xr_method = "barycentric" if method == "cheb" else method + # Use xarray's built-in interpolation # For complex data, this automatically interpolates real and imaginary parts - interp_kwargs = {"method": method} + interp_kwargs = {"method": xr_method} # Check if we're extrapolating significantly and warn freq_min, freq_max = float(data.coords["f"].min()), float(data.coords["f"].max()) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 099c911e8d..037b883e10 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -555,12 +555,8 @@ def _get_data_with_interp(self) -> ModeSolverData: specified by ``interp_spec.num_points`` and interpolated to the original frequencies. """ - # Create reduced frequency set uniformly spaced over the original range - freqs_reduced = np.linspace( - self.freqs[0], - self.freqs[-1], - self.interp_spec.num_points - ) + # Create reduced frequency set based on interpolation method + freqs_reduced = self.interp_spec.sampling_points(self.freqs) # Create a copy of the mode solver with reduced frequencies and no interp_spec # (to prevent recursion) diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 336393419e..91013da39c 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -7,6 +7,7 @@ from typing import Literal, Optional, Union import numpy as np +from numpy.typing import ArrayLike import pydantic.v1 as pd from tidy3d.constants import GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps @@ -121,32 +122,81 @@ class ModeInterpSpec(Tidy3dBaseModel): title="Number of Frequency Points", description="Number of frequency points at which to actually compute modes. " "Must be at least 2 and less than the total number of frequencies requested. " - "The mode solver will compute modes at this many uniformly-spaced frequencies " - "and interpolate to obtain results at all requested frequencies.", + "The mode solver will compute modes at these sampling frequencies " + "and interpolate to obtain results at all requested frequencies. " + "For 'linear' and 'cubic' methods, points are uniformly spaced. " + "For 'cheb' method, Chebyshev nodes are used.", ge=2, ) - method: Literal["linear", "cubic"] = pd.Field( + method: Literal["linear", "cubic", "cheb"] = pd.Field( "linear", title="Interpolation Method", description="Method for interpolating mode data between computed frequencies. " "'linear' uses linear interpolation (faster, requires 2+ points). " "'cubic' uses cubic spline interpolation (smoother, more accurate, requires 4+ points). " + "'cheb' uses Chebyshev polynomial interpolation with barycentric formula " + "(optimal for smooth functions, requires 3+ points, samples at Chebyshev nodes). " "For complex-valued data, real and imaginary parts are interpolated independently.", ) @pd.validator("method", always=True) @skip_if_fields_missing(["num_points"]) - def _validate_cubic_needs_points(cls, val, values): - """Cubic interpolation requires at least 4 points.""" - if val == "cubic" and values.get("num_points", 0) < 4: + def _validate_method_needs_points(cls, val, values): + """Validate that the method has enough points.""" + num_points = values.get("num_points", 0) + if val == "cubic" and num_points < 4: raise ValidationError( "Cubic interpolation requires at least 4 frequency points. " - f"Got num_points={values.get('num_points')}. " + f"Got num_points={num_points}. " + "Use method='linear' or increase num_points." + ) + if val == "cheb" and num_points < 3: + raise ValidationError( + "Chebyshev interpolation requires at least 3 frequency points. " + f"Got num_points={num_points}. " "Use method='linear' or increase num_points." ) return val + def sampling_points(self, freqs: ArrayLike) -> np.ndarray: + """Compute frequency sampling points based on the interpolation method. + + Parameters + ---------- + freqs : ArrayLike + Target frequency array. The sampling points will span from min(freqs) to max(freqs). + + Returns + ------- + np.ndarray + Array of ``num_points`` frequency sampling points. + For 'linear' and 'cubic' methods: uniformly spaced points. + For 'cheb' method: Chebyshev nodes of the second kind. + + Example + ------- + >>> import numpy as np + >>> freqs = np.linspace(1e14, 2e14, 100) + >>> interp_spec = ModeInterpSpec(num_points=10, method='cheb') + >>> sampling_freqs = interp_spec.sampling_points(freqs) + """ + freqs_array = np.asarray(freqs) + f_min, f_max = float(freqs_array.min()), float(freqs_array.max()) + + if self.method in ("linear", "cubic"): + # Uniformly spaced points + return np.linspace(f_min, f_max, self.num_points) + elif self.method == "cheb": + # Chebyshev nodes of the second kind: x_k = cos(k*pi/(n-1)) for k=0,...,n-1 + # Map from [-1, 1] to [f_min, f_max] + k = np.arange(self.num_points) + nodes_normalized = np.cos(k * np.pi / (self.num_points - 1)) + # Map from [-1, 1] to [f_min, f_max] + return 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * nodes_normalized + else: + raise ValueError(f"Unknown interpolation method: {self.method}") + class AbstractModeSpec(Tidy3dBaseModel, ABC): """ From 7e94e6cfd7103c1489ee1b3be707a6b0d4617293 Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Wed, 22 Oct 2025 14:03:55 -0700 Subject: [PATCH 3/8] polish --- tests/test_components/test_mode_interp.py | 197 +++++++++------------- tidy3d/__init__.py | 2 +- tidy3d/components/data/monitor_data.py | 44 ++--- tidy3d/components/mode/mode_solver.py | 4 +- tidy3d/components/mode_spec.py | 2 +- tidy3d/plugins/smatrix/ports/wave.py | 5 +- 6 files changed, 108 insertions(+), 146 deletions(-) diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py index 819f78ec19..0e4d22ea3c 100644 --- a/tests/test_components/test_mode_interp.py +++ b/tests/test_components/test_mode_interp.py @@ -7,13 +7,14 @@ import pytest import tidy3d as td + td.config.use_local_subpixel = False -from ..test_data.test_data_arrays import FS, MODE_SPEC, SIZE_2D -from ..utils import AssertLogLevel from tidy3d.plugins.mode import ModeSolver from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC +from ..test_data.test_data_arrays import MODE_SPEC, SIZE_2D +from ..utils import AssertLogLevel # Shared test constants FREQS_DENSE = np.linspace(1e14, 2e14, 20) @@ -59,7 +60,9 @@ def test_interp_spec_valid_cheb(): def test_interp_spec_cheb_needs_3_points(): """Test that Chebyshev interpolation requires at least 3 points.""" - with pytest.raises(pydantic.ValidationError, match="Chebyshev interpolation requires at least 3"): + with pytest.raises( + pydantic.ValidationError, match="Chebyshev interpolation requires at least 3" + ): td.ModeInterpSpec(num_points=2, method="cheb") @@ -68,7 +71,7 @@ def test_interp_spec_sampling_points_linear(): spec = td.ModeInterpSpec(num_points=5, method="linear") freqs = np.linspace(1e14, 2e14, 100) sampling = spec.sampling_points(freqs) - + assert len(sampling) == 5 assert np.isclose(sampling[0], 1e14) assert np.isclose(sampling[-1], 2e14) @@ -82,12 +85,12 @@ def test_interp_spec_sampling_points_cheb(): spec = td.ModeInterpSpec(num_points=5, method="cheb") freqs = np.linspace(1e14, 2e14, 100) sampling = spec.sampling_points(freqs) - + assert len(sampling) == 5 # Chebyshev nodes should include endpoints assert np.isclose(sampling.min(), 1e14) assert np.isclose(sampling.max(), 2e14) - + # Verify they are Chebyshev nodes f_min, f_max = 1e14, 2e14 k = np.arange(5) @@ -387,11 +390,7 @@ def test_mode_solver_data_interp_linear(): original_num_modes = mode_data.n_complex.shape[1] # Interpolate to 20 frequencies - freqs_dense = np.linspace( - mode_data.monitor.freqs[0], - mode_data.monitor.freqs[-1], - 20 - ) + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) data_interp = mode_data.interp(freqs=freqs_dense, method="linear") # Check frequency dimension @@ -416,11 +415,7 @@ def test_mode_solver_data_interp_cubic(): assert len(mode_data.monitor.freqs) >= 4 # Interpolate to 20 frequencies - freqs_dense = np.linspace( - mode_data.monitor.freqs[0], - mode_data.monitor.freqs[-1], - 20 - ) + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) data_interp = mode_data.interp(freqs=freqs_dense, method="cubic") # Check frequency dimension @@ -434,7 +429,7 @@ def test_mode_solver_data_interp_cheb(): interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") freqs_all = np.linspace(1e14, 2e14, 50) freqs_cheb = interp_spec.sampling_points(freqs_all) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) monitor = td.ModeSolverMonitor( center=(0, 0, 0), @@ -535,11 +530,7 @@ def test_mode_solver_data_interp_preserves_modes(): original_num_modes = mode_data.n_complex.shape[1] # Interpolate to different number of frequencies - freqs_dense = np.linspace( - mode_data.monitor.freqs[0], - mode_data.monitor.freqs[-1], - 20 - ) + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) data_interp = mode_data.interp(freqs=freqs_dense, method="linear") # Mode count should be unchanged @@ -612,17 +603,14 @@ def test_mode_solver_data_interp_extrapolation_warning(): def test_mode_solver_with_interp(): """Test that ModeSolver uses interpolation when interp_spec is provided.""" sim = get_simple_sim() - + # Create solver with 10 frequencies freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # Create solver with interpolation: compute at 3 frequencies, interpolate to 10 interp_spec = td.ModeInterpSpec(num_points=3, method="linear") - + solver_with_interp = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -630,10 +618,10 @@ def test_mode_solver_with_interp(): mode_spec=mode_spec, interp_spec=interp_spec, ) - + # The solver should have the original 10 frequencies assert len(solver_with_interp.freqs) == 10 - + # The returned data should have 10 frequencies data = solver_with_interp.data_raw assert len(data.monitor.freqs) == 10 @@ -643,17 +631,14 @@ def test_mode_solver_with_interp(): def test_mode_solver_creates_reduced_freqs(): """Test that solver creates correct reduced frequency set internally.""" sim = get_simple_sim() - + # Create solver with 20 frequencies freqs = np.linspace(1e14, 2e14, 20) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # Compute at 5 frequencies, interpolate to 20 interp_spec = td.ModeInterpSpec(num_points=5, method="linear") - + solver = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -661,11 +646,11 @@ def test_mode_solver_creates_reduced_freqs(): mode_spec=mode_spec, interp_spec=interp_spec, ) - + # The returned data should have all 20 frequencies data = solver.data_raw assert len(data.monitor.freqs) == 20 - + # The effective indices should be properly interpolated assert data.n_complex.shape[0] == 20 @@ -673,15 +658,12 @@ def test_mode_solver_creates_reduced_freqs(): def test_mode_solver_interp_preserves_num_modes(): """Test that interpolation preserves the number of modes.""" sim = get_simple_sim() - + freqs = np.linspace(1e14, 2e14, 15) - mode_spec = td.ModeSpec( - num_modes=3, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=3, sort_spec=td.ModeSortSpec(track_freq="central")) + interp_spec = td.ModeInterpSpec(num_points=4, method="linear") - + solver = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -689,9 +671,9 @@ def test_mode_solver_interp_preserves_num_modes(): mode_spec=mode_spec, interp_spec=interp_spec, ) - + data = solver.data_raw - + # Should have 3 modes at each of 15 frequencies assert data.n_complex.shape == (15, 3) @@ -699,16 +681,13 @@ def test_mode_solver_interp_preserves_num_modes(): def test_mode_solver_interp_cubic(): """Test that ModeSolver works with cubic interpolation.""" sim = get_simple_sim() - + freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # Cubic interpolation requires at least 4 points interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") - + solver = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -716,7 +695,7 @@ def test_mode_solver_interp_cubic(): mode_spec=mode_spec, interp_spec=interp_spec, ) - + data = solver.data_raw assert len(data.monitor.freqs) == 10 assert data.n_complex.shape[0] == 10 @@ -725,16 +704,13 @@ def test_mode_solver_interp_cubic(): def test_mode_solver_interp_cheb(): """Test that ModeSolver works with Chebyshev interpolation.""" sim = get_simple_sim() - + freqs = np.linspace(1e14, 2e14, 20) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # Chebyshev interpolation requires at least 3 points interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") - + solver = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -742,7 +718,7 @@ def test_mode_solver_interp_cheb(): mode_spec=mode_spec, interp_spec=interp_spec, ) - + data = solver.data_raw assert len(data.monitor.freqs) == 20 assert data.n_complex.shape[0] == 20 @@ -751,13 +727,10 @@ def test_mode_solver_interp_cheb(): def test_mode_solver_without_interp_returns_full_data(): """Test that solver without interp_spec computes at all frequencies.""" sim = get_simple_sim() - + freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + solver = ModeSolver( simulation=sim, plane=td.Box(center=(0, 0, 0), size=SIZE_2D), @@ -765,7 +738,7 @@ def test_mode_solver_without_interp_returns_full_data(): mode_spec=mode_spec, interp_spec=None, # No interpolation ) - + data = solver.data_raw assert len(data.monitor.freqs) == 10 assert data.n_complex.shape[0] == 10 @@ -779,12 +752,9 @@ def test_mode_solver_without_interp_returns_full_data(): def test_mode_monitor_with_interp_spec(): """Test that ModeMonitor can be created with interp_spec.""" freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) interp_spec = td.ModeInterpSpec(num_points=3, method="linear") - + monitor = td.ModeMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -793,7 +763,7 @@ def test_mode_monitor_with_interp_spec(): interp_spec=interp_spec, name="mode_monitor", ) - + assert monitor.interp_spec is not None assert monitor.interp_spec.num_points == 3 assert monitor.interp_spec.method == "linear" @@ -802,12 +772,9 @@ def test_mode_monitor_with_interp_spec(): def test_mode_solver_monitor_with_interp_spec(): """Test that ModeSolverMonitor can be created with interp_spec.""" freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) interp_spec = td.ModeInterpSpec(num_points=4, method="cubic") - + monitor = td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -816,7 +783,7 @@ def test_mode_solver_monitor_with_interp_spec(): interp_spec=interp_spec, name="mode_solver_monitor", ) - + assert monitor.interp_spec is not None assert monitor.interp_spec.num_points == 4 assert monitor.interp_spec.method == "cubic" @@ -825,14 +792,14 @@ def test_mode_solver_monitor_with_interp_spec(): def test_mode_monitor_interp_requires_tracking(): """Test that ModeMonitor with interp_spec requires frequency tracking.""" freqs = np.linspace(1e14, 2e14, 10) - + # Without tracking mode_spec_no_track = td.ModeSpec( num_modes=2, - sort_spec=td.ModeSortSpec(track_freq=None) # No tracking + sort_spec=td.ModeSortSpec(track_freq=None), # No tracking ) interp_spec = td.ModeInterpSpec(num_points=3, method="linear") - + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): td.ModeMonitor( center=(0, 0, 0), @@ -847,14 +814,14 @@ def test_mode_monitor_interp_requires_tracking(): def test_mode_solver_monitor_interp_requires_tracking(): """Test that ModeSolverMonitor with interp_spec requires frequency tracking.""" freqs = np.linspace(1e14, 2e14, 10) - + # Without tracking mode_spec_no_track = td.ModeSpec( num_modes=2, - sort_spec=td.ModeSortSpec(track_freq=None) # No tracking + sort_spec=td.ModeSortSpec(track_freq=None), # No tracking ) interp_spec = td.ModeInterpSpec(num_points=3, method="linear") - + with pytest.raises(pydantic.ValidationError, match="requires mode tracking to be enabled"): td.ModeSolverMonitor( center=(0, 0, 0), @@ -869,14 +836,11 @@ def test_mode_solver_monitor_interp_requires_tracking(): def test_mode_monitor_warns_redundant_num_points(): """Test warning when num_points >= number of frequencies in ModeMonitor.""" freqs = np.linspace(1e14, 2e14, 5) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # num_points >= len(freqs) should trigger warning interp_spec = td.ModeInterpSpec(num_points=5, method="linear") - + with AssertLogLevel("WARNING", contains_str="greater than or equal"): td.ModeMonitor( center=(0, 0, 0), @@ -891,14 +855,11 @@ def test_mode_monitor_warns_redundant_num_points(): def test_mode_solver_monitor_warns_redundant_num_points(): """Test warning when num_points >= number of frequencies in ModeSolverMonitor.""" freqs = np.linspace(1e14, 2e14, 5) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + # num_points >= len(freqs) should trigger warning interp_spec = td.ModeInterpSpec(num_points=6, method="linear") - + with AssertLogLevel("WARNING", contains_str="greater than or equal"): td.ModeSolverMonitor( center=(0, 0, 0), @@ -913,11 +874,8 @@ def test_mode_solver_monitor_warns_redundant_num_points(): def test_mode_monitor_interp_spec_none(): """Test that ModeMonitor works without interp_spec.""" freqs = np.linspace(1e14, 2e14, 10) - mode_spec = td.ModeSpec( - num_modes=2, - sort_spec=td.ModeSortSpec(track_freq="central") - ) - + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -926,7 +884,7 @@ def test_mode_monitor_interp_spec_none(): interp_spec=None, name="test", ) - + assert monitor.interp_spec is None @@ -934,10 +892,14 @@ def test_mode_monitor_interp_spec_none(): # WavePort interp_spec Tests # ============================================================================ + def make_wave_port(): """Make a WavePort.""" + from tidy3d.components.microwave.path_integrals.integrals.current import ( + AxisAlignedCurrentIntegral, + ) from tidy3d.plugins.smatrix.ports.wave import WavePort - from tidy3d.components.microwave.path_integrals.integrals.current import AxisAlignedCurrentIntegral + return WavePort( center=(0, 0, 0), size=(1, 1, 0), @@ -949,18 +911,18 @@ def make_wave_port(): sign="+", extrapolate_to_endpoints=True, snap_contour_to_grid=True, - ) + ), ) def test_wave_port_to_monitors_propagates_default_interp_spec(): """Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor.""" - + port = make_wave_port() - + freqs = np.linspace(1e14, 2e14, 20) monitors = port.to_monitors(freqs=freqs) - + assert len(monitors) == 1 monitor = monitors[0] assert isinstance(monitor, td.ModeMonitor) @@ -973,10 +935,10 @@ def test_wave_port_to_monitors_propagates_custom_interp_spec(): """Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor.""" custom_interp = td.ModeInterpSpec(num_points=8, method="cheb") port = make_wave_port().updated_copy(interp_spec=custom_interp) - + freqs = np.linspace(1e14, 2e14, 50) monitors = port.to_monitors(freqs=freqs) - + assert len(monitors) == 1 monitor = monitors[0] assert isinstance(monitor, td.ModeMonitor) @@ -988,10 +950,10 @@ def test_wave_port_to_monitors_propagates_custom_interp_spec(): def test_wave_port_to_monitors_propagates_none_interp_spec(): """Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor.""" port = make_wave_port().updated_copy(interp_spec=None) - + freqs = np.linspace(1e14, 2e14, 20) monitors = port.to_monitors(freqs=freqs) - + assert len(monitors) == 1 monitor = monitors[0] assert isinstance(monitor, td.ModeMonitor) @@ -1001,4 +963,3 @@ def test_wave_port_to_monitors_propagates_none_interp_spec(): # ============================================================================ # Placeholder tests for future phases # ============================================================================ - diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 8135f224ad..21071408d9 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -693,6 +693,7 @@ def set_logging_level(level: str) -> None: "ModeAmpsDataArray", "ModeData", "ModeIndexDataArray", + "ModeInterpSpec", "ModeMonitor", "ModeSimulation", "ModeSimulationData", @@ -702,7 +703,6 @@ def set_logging_level(level: str) -> None: "ModeSortSpec", "ModeSource", "ModeSpec", - "ModeInterpSpec", "ModulationSpec", "Monitor", "MultiPhysicsMedium", diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index c67a2d1452..b98481272b 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -2408,15 +2408,15 @@ def _validate_cheb_nodes(freqs: np.ndarray) -> None: mode_interp_spec = ModeInterpSpec(method="cheb", num_points=len(freqs)) expected_freqs = mode_interp_spec.sampling_points(freqs) - + # Sort both arrays for comparison (Chebyshev nodes are naturally sorted in descending order) freqs_sorted = np.sort(freqs) expected_sorted = np.sort(expected_freqs) - + # Check relative error freq_range = np.abs(expected_freqs[-1] - expected_freqs[0]) max_error = np.max(np.abs(freqs_sorted - expected_sorted)) / freq_range - + if max_error > CHEB_NODES_TOLERANCE: raise DataError( f"For Chebyshev interpolation ('cheb'), source frequencies must be at " @@ -2465,7 +2465,7 @@ def interp( Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate near mode crossings or regions of rapid mode variation. Use frequency tracking (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. - + For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the second kind within the frequency range. @@ -2486,14 +2486,14 @@ def interp( raise DataError("Cannot interpolate to fewer than 2 frequency points.") source_freqs = np.array(self.monitor.freqs) - + # Validate method-specific requirements if method == "cubic" and len(source_freqs) < 4: raise DataError( f"Cubic interpolation requires at least 4 source frequency points. " f"Got {len(source_freqs)}. Use method='linear' instead." ) - + if method == "cheb": if len(source_freqs) < 3: raise DataError( @@ -2504,28 +2504,24 @@ def interp( self._validate_cheb_nodes(source_freqs) if method not in ["linear", "cubic", "cheb"]: - raise DataError(f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'cheb'.") + raise DataError( + f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'cheb'." + ) # Build update dictionary update_dict = {} # Interpolate n_complex (required field) - update_dict["n_complex"] = self._interp_dataarray( - self.n_complex, freqs, method - ) + update_dict["n_complex"] = self._interp_dataarray(self.n_complex, freqs, method) # Interpolate field components if present for field_name, field_data in self.field_components.items(): if field_data is not None: - update_dict[field_name] = self._interp_dataarray( - field_data, freqs, method - ) + update_dict[field_name] = self._interp_dataarray(field_data, freqs, method) # Interpolate n_group_raw if present if self.n_group_raw is not None: - update_dict["n_group_raw"] = self._interp_dataarray( - self.n_group_raw, freqs, method - ) + update_dict["n_group_raw"] = self._interp_dataarray(self.n_group_raw, freqs, method) # Interpolate dispersion_raw if present if self.dispersion_raw is not None: @@ -2540,9 +2536,13 @@ def interp( # Handle eps_spec if present - use nearest neighbor interpolation if self.eps_spec is not None: - update_dict["eps_spec"] = list(self._interp_dataarray( - FreqDataArray(self.eps_spec, coords=dict(f=self.monitor.freqs)), freqs, "nearest" - ).data) + update_dict["eps_spec"] = list( + self._interp_dataarray( + FreqDataArray(self.eps_spec, coords={"f": self.monitor.freqs}), + freqs, + "nearest", + ).data + ) # Update monitor with new frequencies update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs)) @@ -2574,7 +2574,7 @@ def _interp_dataarray( """ # Map 'cheb' to xarray's 'barycentric' method xr_method = "barycentric" if method == "cheb" else method - + # Use xarray's built-in interpolation # For complex data, this automatically interpolates real and imaginary parts interp_kwargs = {"method": xr_method} @@ -2583,7 +2583,9 @@ def _interp_dataarray( freq_min, freq_max = float(data.coords["f"].min()), float(data.coords["f"].max()) new_freq_min, new_freq_max = float(freqs.min()), float(freqs.max()) - if new_freq_min < freq_min * (1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE): + if new_freq_min < freq_min * ( + 1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE + ) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE): log.warning( f"Interpolating to frequencies outside original range " f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: " diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 037b883e10..85188fa701 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -560,9 +560,7 @@ def _get_data_with_interp(self) -> ModeSolverData: # Create a copy of the mode solver with reduced frequencies and no interp_spec # (to prevent recursion) - mode_solver_reduced = self.copy( - update={"freqs": freqs_reduced, "interp_spec": None} - ) + mode_solver_reduced = self.copy(update={"freqs": freqs_reduced, "interp_spec": None}) # Get data at reduced frequencies data_reduced = mode_solver_reduced.data_raw diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 91013da39c..24137b14ff 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -7,8 +7,8 @@ from typing import Literal, Optional, Union import numpy as np -from numpy.typing import ArrayLike import pydantic.v1 as pd +from numpy.typing import ArrayLike from tidy3d.constants import GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps from tidy3d.exceptions import SetupError, ValidationError diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index ffaf187cb0..fef7d81091 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -15,12 +15,12 @@ from tidy3d.components.geometry.base import Box from tidy3d.components.geometry.bound_ops import bounds_contains from tidy3d.components.grid.grid import Grid -from tidy3d.components.mode_spec import ModeInterpSpec from tidy3d.components.microwave.impedance_calculator import ( CurrentIntegralType, ImpedanceCalculator, VoltageIntegralType, ) +from tidy3d.components.mode_spec import ModeInterpSpec from tidy3d.components.monitor import ModeMonitor from tidy3d.components.simulation import Simulation from tidy3d.components.source.field import ModeSource, ModeSpec @@ -37,7 +37,7 @@ DEFAULT_WAVE_PORT_NUM_CELLS = 5 MIN_WAVE_PORT_NUM_CELLS = 3 DEFAULT_WAVE_PORT_FRAME = PECFrame() -DEFAULT_WAVE_PORT_INTERP_SPEC = ModeInterpSpec(num_points=15, method="cubic") +DEFAULT_WAVE_PORT_INTERP_SPEC = ModeInterpSpec(num_points=21, method="cheb") class WavePort(AbstractTerminalPort, Box): @@ -211,6 +211,7 @@ def to_mode_solver(self, simulation: Simulation, freqs: FreqArray) -> ModeSolver freqs=freqs, direction=self.direction, colocate=False, + interp_spec=self.interp_spec, ) return mode_solver From bee845b9033814ca9caf86be9edc17b3cb1396b4 Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Wed, 22 Oct 2025 14:17:26 -0700 Subject: [PATCH 4/8] leftover --- tests/test_components/test_mode_interp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py index 0e4d22ea3c..60ff27043e 100644 --- a/tests/test_components/test_mode_interp.py +++ b/tests/test_components/test_mode_interp.py @@ -8,8 +8,6 @@ import tidy3d as td -td.config.use_local_subpixel = False - from tidy3d.plugins.mode import ModeSolver from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC From 19df1ca082bbfefa4cd70f7c0b81e8f242be43d0 Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Wed, 22 Oct 2025 23:21:01 -0700 Subject: [PATCH 5/8] greptile comments --- schemas/EMESimulation.json | 40 +++++++++ schemas/ModeSimulation.json | 47 ++++++++++ schemas/Simulation.json | 61 +++++++++++++ schemas/TerminalComponentModeler.json | 74 ++++++++++++++++ tests/test_components/test_mode_interp.py | 100 ++++++++++++++++++---- tidy3d/components/data/monitor_data.py | 3 - tidy3d/components/mode/mode_solver.py | 10 +-- tidy3d/components/monitor.py | 7 +- 8 files changed, 314 insertions(+), 28 deletions(-) diff --git a/schemas/EMESimulation.json b/schemas/EMESimulation.json index 176d2a501e..4053a94c18 100644 --- a/schemas/EMESimulation.json +++ b/schemas/EMESimulation.json @@ -8140,6 +8140,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeSolverMonitor": { "additionalProperties": false, "properties": { @@ -8261,6 +8294,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/ModeSimulation.json b/schemas/ModeSimulation.json index e43308481a..c0fa04be8b 100644 --- a/schemas/ModeSimulation.json +++ b/schemas/ModeSimulation.json @@ -7353,6 +7353,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -7444,6 +7477,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -7705,6 +7745,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/Simulation.json b/schemas/Simulation.json index c5ee4e818f..e406b1edef 100644 --- a/schemas/Simulation.json +++ b/schemas/Simulation.json @@ -10737,6 +10737,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -10968,6 +10975,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11319,6 +11333,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -11410,6 +11457,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11671,6 +11725,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, diff --git a/schemas/TerminalComponentModeler.json b/schemas/TerminalComponentModeler.json index 9adbed89cc..82388e488a 100644 --- a/schemas/TerminalComponentModeler.json +++ b/schemas/TerminalComponentModeler.json @@ -11393,6 +11393,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11624,6 +11631,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -11975,6 +11989,39 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cheb", + "cubic", + "linear" + ], + "type": "string" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -12066,6 +12113,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -12327,6 +12381,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "interval_space": { "default": [ 1, @@ -18047,6 +18108,19 @@ "type": "PECFrame" } }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ], + "default": { + "attrs": {}, + "method": "cheb", + "num_points": 21, + "type": "ModeInterpSpec" + } + }, "mode_index": { "default": 0, "minimum": 0, diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py index 60ff27043e..b2257bdc50 100644 --- a/tests/test_components/test_mode_interp.py +++ b/tests/test_components/test_mode_interp.py @@ -7,7 +7,6 @@ import pytest import tidy3d as td - from tidy3d.plugins.mode import ModeSolver from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC @@ -315,7 +314,7 @@ def test_mode_solver_warns_num_points(): interp_spec = td.ModeInterpSpec(num_points=25, method="linear") plane = td.Box(center=(0, 0, 0), size=SIZE_2D) - with AssertLogLevel("WARNING", contains_str="num_points"): + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): ModeSolver( simulation=sim, plane=plane, @@ -347,13 +346,17 @@ def test_mode_solver_interp_spec_none(): def get_mode_solver_data(): """Create a simple ModeSolverData object for testing.""" + from tidy3d.components.data.data_array import GroupIndexDataArray, ModeDispersionDataArray + from ..test_data.test_data_arrays import ( make_scalar_mode_field_data_array, ) from ..test_data.test_monitor_data import N_COMPLEX freqs = np.linspace(1e14, 2e14, 5) - mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + num_modes = len(N_COMPLEX.mode_index) + mode_indices = np.arange(num_modes) + mode_spec = td.ModeSpec(num_modes=num_modes, sort_spec=td.ModeSortSpec(track_freq="central")) monitor = td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -362,6 +365,17 @@ def get_mode_solver_data(): name="test_monitor", ) + # Create n_group_raw and dispersion_raw with same shape as n_complex + n_group_values = 1.5 + 0.1 * np.random.random((len(freqs), num_modes)) + n_group_raw = GroupIndexDataArray( + n_group_values, coords={"f": freqs, "mode_index": mode_indices} + ) + + dispersion_values = 10.0 + 2.0 * np.random.random((len(freqs), num_modes)) + dispersion_raw = ModeDispersionDataArray( + dispersion_values, coords={"f": freqs, "mode_index": mode_indices} + ) + # Create mode data with the right frequencies mode_data = td.ModeSolverData( monitor=monitor, @@ -372,6 +386,8 @@ def get_mode_solver_data(): Hy=make_scalar_mode_field_data_array("Hy"), Hz=make_scalar_mode_field_data_array("Hz"), n_complex=N_COMPLEX.copy(), + n_group_raw=n_group_raw, + dispersion_raw=dispersion_raw, symmetry=(0, 0, 0), symmetry_center=(0, 0, 0), grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), @@ -385,7 +401,7 @@ def test_mode_solver_data_interp_linear(): # Original has 5 frequencies assert len(mode_data.monitor.freqs) == 5 - original_num_modes = mode_data.n_complex.shape[1] + original_num_modes = len(mode_data.n_complex.mode_index) # Interpolate to 20 frequencies freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) @@ -396,7 +412,7 @@ def test_mode_solver_data_interp_linear(): assert data_interp.n_complex.shape[0] == 20 # Check mode dimension is preserved - assert data_interp.n_complex.shape[1] == original_num_modes + assert len(data_interp.n_complex.mode_index) == original_num_modes # Check field components are interpolated for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: @@ -418,7 +434,7 @@ def test_mode_solver_data_interp_cubic(): # Check frequency dimension assert len(data_interp.monitor.freqs) == 20 - assert data_interp.n_complex.shape[0] == 20 + assert len(data_interp.n_complex.mode_index) == len(mode_data.n_complex.mode_index) def test_mode_solver_data_interp_cheb(): @@ -459,7 +475,7 @@ def test_mode_solver_data_interp_cheb(): # Check frequency dimension assert len(data_interp.monitor.freqs) == 50 - assert data_interp.n_complex.shape[0] == 50 + assert len(data_interp.n_complex.mode_index) == len(N_COMPLEX.mode_index) def test_mode_solver_data_interp_cheb_needs_3_source(): @@ -525,22 +541,76 @@ def test_mode_solver_data_interp_cheb_validates_nodes(): def test_mode_solver_data_interp_preserves_modes(): """Test that interpolation preserves mode count.""" mode_data = get_mode_solver_data() - original_num_modes = mode_data.n_complex.shape[1] + original_num_modes = len(mode_data.n_complex.mode_index) # Interpolate to different number of frequencies freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) data_interp = mode_data.interp(freqs=freqs_dense, method="linear") # Mode count should be unchanged - assert data_interp.n_complex.shape[1] == original_num_modes + assert len(data_interp.n_complex.mode_index) == original_num_modes -def test_mode_solver_data_interp_too_few_target_freqs(): - """Test that interpolation fails with too few target frequencies.""" +def test_mode_solver_data_interp_includes_n_group_and_dispersion(): + """Test that interpolation includes n_group_raw and dispersion_raw.""" mode_data = get_mode_solver_data() - with pytest.raises(td.exceptions.DataError, match="fewer than 2"): - mode_data.interp(freqs=[1e14], method="linear") + # Verify source data has n_group_raw and dispersion_raw + assert mode_data.n_group_raw is not None + assert mode_data.dispersion_raw is not None + assert mode_data.n_group_raw.shape == (5, len(mode_data.n_complex.mode_index)) + assert mode_data.dispersion_raw.shape == (5, len(mode_data.n_complex.mode_index)) + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp(freqs=freqs_dense, method="linear") + + # Verify interpolated data has n_group_raw and dispersion_raw + assert data_interp.n_group_raw is not None + assert data_interp.dispersion_raw is not None + + # Check shapes are correct + assert data_interp.n_group_raw.shape == (20, len(data_interp.n_complex.mode_index)) + assert data_interp.dispersion_raw.shape == (20, len(data_interp.n_complex.mode_index)) + + # Check frequency coordinates are correct + assert len(data_interp.n_group_raw.coords["f"]) == 20 + assert len(data_interp.dispersion_raw.coords["f"]) == 20 + + +def test_mode_solver_data_interp_single_frequency(): + """Test that interpolation works with a single target frequency.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to a single frequency in the middle of the range + single_freq = np.array([1.5e14]) + data_interp = mode_data.interp(freqs=single_freq, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 1 + assert data_interp.n_complex.shape[0] == 1 + + # Check mode dimension is preserved + assert len(data_interp.n_complex.mode_index) == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 1 + assert float(field_data.coords["f"]) == 1.5e14 + + # Check n_group_raw and dispersion_raw if present + if data_interp.n_group_raw is not None: + print(data_interp.n_group_raw.shape) + print((1, original_num_modes)) + assert data_interp.n_group_raw.shape == (1, original_num_modes) + if data_interp.dispersion_raw is not None: + assert data_interp.dispersion_raw.shape == (1, original_num_modes) def test_mode_solver_data_interp_cubic_needs_4_source(): @@ -839,7 +909,7 @@ def test_mode_monitor_warns_redundant_num_points(): # num_points >= len(freqs) should trigger warning interp_spec = td.ModeInterpSpec(num_points=5, method="linear") - with AssertLogLevel("WARNING", contains_str="greater than or equal"): + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): td.ModeMonitor( center=(0, 0, 0), size=SIZE_2D, @@ -858,7 +928,7 @@ def test_mode_solver_monitor_warns_redundant_num_points(): # num_points >= len(freqs) should trigger warning interp_spec = td.ModeInterpSpec(num_points=6, method="linear") - with AssertLogLevel("WARNING", contains_str="greater than or equal"): + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): td.ModeSolverMonitor( center=(0, 0, 0), size=SIZE_2D, diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index b98481272b..5a506c476a 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -2482,9 +2482,6 @@ def interp( """ # Validate input freqs = np.array(freqs) - if len(freqs) < 2: - raise DataError("Cannot interpolate to fewer than 2 frequency points.") - source_freqs = np.array(self.monitor.freqs) # Validate method-specific requirements diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 85188fa701..55fc26f84c 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -298,7 +298,8 @@ def _warn_interp_num_points(cls, val, values): if val.num_points >= num_freqs: log.warning( f"interp_spec.num_points ({val.num_points}) is greater than or equal to " - f"the number of frequencies ({num_freqs}). No computational savings are achieved.", + f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " + f"modes will be computed at all {num_freqs} frequencies.", custom_loc=["interp_spec", "num_points"], ) @@ -601,7 +602,7 @@ def data_raw(self) -> ModeSolverDataType: if self.mode_spec.group_index_step > 0: return self._get_data_with_group_index() - if self.interp_spec is not None: + if self.interp_spec is not None and self.interp_spec.num_points < len(self.freqs): return self._get_data_with_interp() if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: @@ -609,23 +610,18 @@ def data_raw(self) -> ModeSolverDataType: # Compute data on the Yee grid mode_solver_data = self._data_on_yee_grid() - # print("mode_solver_data: ", mode_solver_data) if self._has_microwave_mode_spec: mode_solver_data = MicrowaveModeSolverData(**mode_solver_data.dict(exclude={"type"})) - # print("mode_solver_data: ", mode_solver_data) # Colocate to grid boundaries if requested if self.colocate: mode_solver_data = self._colocate_data(mode_solver_data=mode_solver_data) - # print("mode_solver_data: ", mode_solver_data) # normalize modes self._normalize_modes(mode_solver_data=mode_solver_data) # filter polarization if requested - # print("mode_solver_data: ", mode_solver_data) mode_solver_data = self._filter_polarization(mode_solver_data=mode_solver_data) - # print("mode_solver_data: ", mode_solver_data) # filter and sort modes if requested by sort_spec mode_solver_data = mode_solver_data.sort_modes( diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index b8d156f620..fcb332ba53 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -455,7 +455,7 @@ def _validate_interp_requires_tracking(cls, val, values): @pydantic.validator("interp_spec", always=True) @skip_if_fields_missing(["freqs"]) def _warn_interp_num_points(cls, val, values): - """Warn if num_points is less than total frequencies.""" + """Warn if num_points is greater than or equal to total frequencies.""" if val is None: return val @@ -464,8 +464,9 @@ def _warn_interp_num_points(cls, val, values): if val.num_points >= num_freqs: log.warning( - f"interp_spec.num_points ({val.num_points}) is greater than or equal to " - f"the number of frequencies ({num_freqs}). No savings are achieved.", + f"'interp_spec.num_points' ({val.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " + f"modes will be computed at all {num_freqs} frequencies.", custom_loc=["interp_spec", "num_points"], ) From eabec9664d09a2fc12780c6d7b0a85d29cf426f9 Mon Sep 17 00:00:00 2001 From: Casey Wojcik Date: Mon, 27 Oct 2025 16:12:56 -0700 Subject: [PATCH 6/8] Add interpolation to EME --- tidy3d/components/data/monitor_data.py | 17 ++++++++------- tidy3d/components/eme/grid.py | 29 +++++++++++++++++++------- tidy3d/components/eme/simulation.py | 5 ++++- tidy3d/components/mode_spec.py | 6 ++++++ 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 5a506c476a..3ead50bdbc 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -2429,6 +2429,7 @@ def interp( self, freqs: FreqArray, method: Literal["linear", "cubic", "cheb"] = "linear", + assume_constant_modes: bool = False ) -> ModeSolverData: """Interpolate mode data to new frequency points. @@ -2512,9 +2513,10 @@ def interp( update_dict["n_complex"] = self._interp_dataarray(self.n_complex, freqs, method) # Interpolate field components if present - for field_name, field_data in self.field_components.items(): - if field_data is not None: - update_dict[field_name] = self._interp_dataarray(field_data, freqs, method) + if not assume_constant_modes: + for field_name, field_data in self.field_components.items(): + if field_data is not None: + update_dict[field_name] = self._interp_dataarray(field_data, freqs, method) # Interpolate n_group_raw if present if self.n_group_raw is not None: @@ -2527,15 +2529,16 @@ def interp( ) # Interpolate grid correction data if present - for key, data in self._grid_correction_dict.items(): - if isinstance(data, DataArray) and "f" in data.coords: - update_dict[key] = self._interp_dataarray(data, freqs, method) + if not assume_constant_modes: + for key, data in self._grid_correction_dict.items(): + if isinstance(data, DataArray) and "f" in data.coords: + update_dict[key] = self._interp_dataarray(data, freqs, method) # Handle eps_spec if present - use nearest neighbor interpolation if self.eps_spec is not None: update_dict["eps_spec"] = list( self._interp_dataarray( - FreqDataArray(self.eps_spec, coords={"f": self.monitor.freqs}), + FreqDataArray(self.eps_spec, coords={"f": np.array(self.monitor.freqs)}), freqs, "nearest", ).data diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index d197c60677..5e4b1d33fa 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -11,7 +11,7 @@ from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords1D -from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.mode_spec import ModeSpec, ModeInterpSpec from tidy3d.components.structure import Structure from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq from tidy3d.constants import RADIAN, fp_eps, inf @@ -25,16 +25,28 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" + + interp_spec: ModeInterpSpec = pd.Field( + ModeInterpSpec(method="cheb", num_points=5), + title="interp spec", + description="interp spec" + ) - track_freq: Union[TrackFreq, None] = pd.Field( - None, - title="Mode Tracking Frequency", - description="Parameter that turns on/off mode tracking based on their similarity. " - "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " - "mode tracking based on the lowest, central, or highest frequency. " - "If ``None`` no mode tracking is performed, which is the default for best performance.", + assume_constant_modes: bool = pd.Field( + True, + title="assume constant modes", + description="assume constant modes" ) + # track_freq: Union[TrackFreq, None] = pd.Field( + # None, + # title="Mode Tracking Frequency", + # description="Parameter that turns on/off mode tracking based on their similarity. " + # "Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to " + # "mode tracking based on the lowest, central, or highest frequency. " + # "If ``None`` no mode tracking is performed, which is the default for best performance.", + # ) + angle_theta: Literal[0.0] = pd.Field( 0.0, title="Polar Angle", @@ -85,6 +97,7 @@ def _to_mode_spec(self) -> ModeSpec: """Convert to ordinary :class:`.ModeSpec`.""" ms_dict = self.dict() ms_dict.pop("type") + ms_dict.pop("interp_spec") return ModeSpec.parse_obj(ms_dict) diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index b01b928e4f..a38f2eb183 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -554,11 +554,14 @@ def mode_solver_monitors(self) -> list[ModeSolverMonitor]: mode_planes = self.eme_grid.mode_planes mode_specs = [eme_mode_spec._to_mode_spec() for eme_mode_spec in self.eme_grid.mode_specs] for i in range(self.eme_grid.num_cells): + freqs_curr = freqs + if self.eme_grid.mode_specs[i].interp_spec is not None: + freqs_curr = np.array(self.eme_grid.mode_specs[i].interp_spec.sampling_points(freqs)) monitor = ModeSolverMonitor( center=mode_planes[i].center, size=mode_planes[i].size, name=f"_eme_mode_solver_monitor_{i}", - freqs=freqs, + freqs=freqs_curr, mode_spec=mode_specs[i], colocate=False, ) diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 24137b14ff..560285ab66 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -474,3 +474,9 @@ class ModeSpec(AbstractModeSpec): * `Waveguide to ring coupling <../../notebooks/WaveguideToRingCoupling.html>`_ """ + + assume_constant_modes: bool = pd.Field( + False, + title="assume constant modes", + description="assume constant modes" + ) From 0065aaa16e73c530df6e79a1ac830052239a9d82 Mon Sep 17 00:00:00 2001 From: Casey Wojcik Date: Mon, 27 Oct 2025 16:21:08 -0700 Subject: [PATCH 7/8] revise --- tidy3d/components/eme/grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index 5e4b1d33fa..c204ef5968 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -27,7 +27,7 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" interp_spec: ModeInterpSpec = pd.Field( - ModeInterpSpec(method="cheb", num_points=5), + ModeInterpSpec(method="cheb", num_points=4), title="interp spec", description="interp spec" ) From f4565b6a0110013df53e12065cc8138418a7b842 Mon Sep 17 00:00:00 2001 From: Casey Wojcik Date: Tue, 28 Oct 2025 14:10:39 -0700 Subject: [PATCH 8/8] working state --- tidy3d/components/data/monitor_data.py | 5 ++++- tidy3d/components/eme/data/sim_data.py | 14 ++++++++++++-- tidy3d/components/eme/grid.py | 17 ++++------------- tidy3d/components/eme/simulation.py | 18 ++++++++++++++++-- tidy3d/components/mode/mode_solver.py | 1 + tidy3d/components/mode_spec.py | 6 ++---- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 3ead50bdbc..d7f5e0f407 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -1660,6 +1660,9 @@ def eps_spec_match_mode_spec(cls, val, values): """Raise validation error if frequencies in eps_spec does not match frequency list""" if val: mode_data_freqs = values["monitor"].freqs + interp_spec = values["monitor"].mode_spec.interp_spec + if interp_spec is not None: + return val if len(val) != len(mode_data_freqs): raise ValidationError( "eps_spec must be provided at the same frequencies as mode solver data." @@ -2429,7 +2432,7 @@ def interp( self, freqs: FreqArray, method: Literal["linear", "cubic", "cheb"] = "linear", - assume_constant_modes: bool = False + assume_constant_modes: bool = False, ) -> ModeSolverData: """Interpolate mode data to new frequency points. diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index 03d16e65d9..1296457d02 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -197,8 +197,8 @@ def smatrix_in_basis( modes1 = port_modes1 if not modes2_provided: modes2 = port_modes2 - f1 = list(modes1.field_components.values())[0].f.values - f2 = list(modes2.field_components.values())[0].f.values + f1 = list(modes1.monitor.freqs) + f2 = list(modes2.monitor.freqs) f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs))) @@ -259,6 +259,11 @@ def smatrix_in_basis( overlaps1 = modes1.outer_dot(port_modes1, conjugate=False) if not modes_in_1: overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1) + interp_spec1 = modes1.monitor.mode_spec.interp_spec + if interp_spec1 is not None: + overlaps1 = modes1._interp_dataarray( + overlaps1, freqs=f, method=interp_spec1.method + ) O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1) O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") @@ -288,6 +293,11 @@ def smatrix_in_basis( overlaps2 = modes2.outer_dot(port_modes2, conjugate=False) if not modes_in_2: overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1) + interp_spec2 = modes2.monitor.mode_spec.interp_spec + if interp_spec2: + overlaps2 = modes2._interp_dataarray( + overlaps2, freqs=f, method=interp_spec2.method + ) O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2) O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old") diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index c204ef5968..abba122cef 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -11,9 +11,9 @@ from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing from tidy3d.components.geometry.base import Box from tidy3d.components.grid.grid import Coords1D -from tidy3d.components.mode_spec import ModeSpec, ModeInterpSpec +from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec from tidy3d.components.structure import Structure -from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq +from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size from tidy3d.constants import RADIAN, fp_eps, inf from tidy3d.exceptions import SetupError, ValidationError @@ -25,17 +25,9 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" - - interp_spec: ModeInterpSpec = pd.Field( - ModeInterpSpec(method="cheb", num_points=4), - title="interp spec", - description="interp spec" - ) - assume_constant_modes: bool = pd.Field( - True, - title="assume constant modes", - description="assume constant modes" + interp_spec: ModeInterpSpec = pd.Field( + ModeInterpSpec(method="cheb", num_points=4), title="interp spec", description="interp spec" ) # track_freq: Union[TrackFreq, None] = pd.Field( @@ -97,7 +89,6 @@ def _to_mode_spec(self) -> ModeSpec: """Convert to ordinary :class:`.ModeSpec`.""" ms_dict = self.dict() ms_dict.pop("type") - ms_dict.pop("interp_spec") return ModeSpec.parse_obj(ms_dict) diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index a38f2eb183..d8fc0e9d4c 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -555,8 +555,6 @@ def mode_solver_monitors(self) -> list[ModeSolverMonitor]: mode_specs = [eme_mode_spec._to_mode_spec() for eme_mode_spec in self.eme_grid.mode_specs] for i in range(self.eme_grid.num_cells): freqs_curr = freqs - if self.eme_grid.mode_specs[i].interp_spec is not None: - freqs_curr = np.array(self.eme_grid.mode_specs[i].interp_spec.sampling_points(freqs)) monitor = ModeSolverMonitor( center=mode_planes[i].center, size=mode_planes[i].size, @@ -582,6 +580,7 @@ def port_modes_monitor(self) -> EMEModeSolverMonitor: num_modes=self.max_port_modes, num_sweep=None, normalize=self.normalize, + freqs=self.freqs, ) def _post_init_validators(self) -> None: @@ -748,6 +747,9 @@ def _validate_sweep_spec(self): def _validate_monitor_setup(self): """Check monitor setup.""" + for i in range(len(self.monitors)): + if self.monitors[i].freqs is None: + self.monitors[i] = self.monitors[i].updated_copy(freqs=self.freqs) for i, monitor in enumerate(self.monitors): if isinstance(monitor, EMEMonitor): _ = self._monitor_eme_cell_indices(monitor=monitor) @@ -1012,6 +1014,18 @@ def _monitor_freqs(self, monitor: Monitor) -> list[pd.NonNegativeFloat]: return list(self.freqs) return list(monitor.freqs) + def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[pd.NonNegativeFloat]: + """Monitor frequencies.""" + freqs = set() + cell_inds = self._monitor_eme_cell_indices(monitor=monitor) + for cell_ind in cell_inds: + interp_spec = self.eme_grid.mode_specs[cell_ind].interp_spec + if interp_spec is None: + freqs |= set(self.freqs) + else: + freqs |= set(interp_spec.sampling_points(self.freqs)) + return list(freqs) + def _monitor_num_freqs(self, monitor: Monitor) -> int: """Total number of freqs included in monitor.""" return len(self._monitor_freqs(monitor=monitor)) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 55fc26f84c..e862b85fcd 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -565,6 +565,7 @@ def _get_data_with_interp(self) -> ModeSolverData: # Get data at reduced frequencies data_reduced = mode_solver_reduced.data_raw + return data_reduced # Interpolate back to original frequencies return data_reduced.interp(freqs=self.freqs, method=self.interp_spec.method) diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 560285ab66..b491b09cc9 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -475,8 +475,6 @@ class ModeSpec(AbstractModeSpec): """ - assume_constant_modes: bool = pd.Field( - False, - title="assume constant modes", - description="assume constant modes" + interp_spec: ModeInterpSpec = pd.Field( + ModeInterpSpec(method="cheb", num_points=4), title="interp spec", description="interp spec" )