From 73e8ac0066d94eaba83a51446463cb8497a98be6 Mon Sep 17 00:00:00 2001 From: dbochkov-flexcompute Date: Wed, 12 Nov 2025 09:30:22 -0800 Subject: [PATCH] feat: downsample frequencies in mode monitors (fxc-3351) --- CHANGELOG.md | 1 + docs/api/mode.rst | 4 +- schemas/EMESimulation.json | 158 +++ schemas/ModeSimulation.json | 153 +++ schemas/Simulation.json | 153 +++ schemas/TerminalComponentModeler.json | 153 +++ tests/test_components/test_microwave.py | 334 +++++ tests/test_components/test_mode_interp.py | 1154 +++++++++++++++++ tidy3d/__init__.py | 13 +- tidy3d/components/data/dataset.py | 139 +- tidy3d/components/data/monitor_data.py | 251 +++- tidy3d/components/microwave/data/dataset.py | 4 +- .../components/microwave/data/monitor_data.py | 101 +- tidy3d/components/mode/mode_solver.py | 225 ++-- tidy3d/components/mode/simulation.py | 3 + tidy3d/components/mode_spec.py | 432 +++++- tidy3d/components/monitor.py | 30 +- tidy3d/components/validators.py | 32 + 18 files changed, 3232 insertions(+), 108 deletions(-) create mode 100644 tests/test_components/test_mode_interp.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c986979303..b6cfda5301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for `nonlinear_spec` in `CustomMedium` and `CustomDispersiveMedium`. - `tidy3d.plugins.design.DesignSpace.run(..., fn_post=...)` now accepts a `priority` keyword to propagate vGPU queue priority to all automatically batched simulations. - Introduced `BroadbandPulse` for exciting simulations across a wide frequency spectrum. +- Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. diff --git a/docs/api/mode.rst b/docs/api/mode.rst index 7237cb4b2f..84b21c2af1 100644 --- a/docs/api/mode.rst +++ b/docs/api/mode.rst @@ -7,4 +7,6 @@ Mode Specifications :toctree: _autosummary/ :template: module.rst - tidy3d.ModeSpec \ No newline at end of file + tidy3d.ModeSpec + tidy3d.ModeSortSpec + tidy3d.ModeInterpSpec \ No newline at end of file diff --git a/schemas/EMESimulation.json b/schemas/EMESimulation.json index a43d923d8c..c70989201a 100644 --- a/schemas/EMESimulation.json +++ b/schemas/EMESimulation.json @@ -1748,6 +1748,30 @@ }, "type": "object" }, + "ChebSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 3, + "type": "integer" + }, + "type": { + "default": "ChebSampling", + "enum": [ + "ChebSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ClipOperation": { "additionalProperties": false, "properties": { @@ -3907,6 +3931,39 @@ ], "type": "object" }, + "CustomSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "freqs": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "ArrayLike" + } + ] + }, + "type": { + "default": "CustomSampling", + "enum": [ + "CustomSampling" + ], + "type": "string" + } + }, + "required": [ + "freqs" + ], + "type": "object" + }, "CustomSellmeier": { "additionalProperties": false, "properties": { @@ -5232,6 +5289,13 @@ ], "default": false }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -7967,6 +8031,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -8073,6 +8144,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -8133,6 +8205,60 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cubic", + "linear", + "poly" + ], + "type": "string" + }, + "reduce_data": { + "default": false, + "type": "boolean" + }, + "sampling_spec": { + "discriminator": { + "mapping": { + "ChebSampling": "#/definitions/ChebSampling", + "CustomSampling": "#/definitions/CustomSampling", + "UniformSampling": "#/definitions/UniformSampling" + }, + "propertyName": "type" + }, + "oneOf": [ + { + "$ref": "#/definitions/ChebSampling" + }, + { + "$ref": "#/definitions/CustomSampling" + }, + { + "$ref": "#/definitions/UniformSampling" + } + ] + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "sampling_spec" + ], + "type": "object" + }, "ModeSolverMonitor": { "additionalProperties": false, "properties": { @@ -8299,6 +8425,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -8515,6 +8642,13 @@ ], "default": false }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -12052,6 +12186,30 @@ ], "type": "object" }, + "UniformSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "UniformSampling", + "enum": [ + "UniformSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "VarshniEnergyBandGap": { "additionalProperties": false, "properties": { diff --git a/schemas/ModeSimulation.json b/schemas/ModeSimulation.json index a6a28125d6..15f8e42e94 100644 --- a/schemas/ModeSimulation.json +++ b/schemas/ModeSimulation.json @@ -1800,6 +1800,30 @@ }, "type": "object" }, + "ChebSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 3, + "type": "integer" + }, + "type": { + "default": "ChebSampling", + "enum": [ + "ChebSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ClipOperation": { "additionalProperties": false, "properties": { @@ -4002,6 +4026,39 @@ ], "type": "object" }, + "CustomSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "freqs": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "ArrayLike" + } + ] + }, + "type": { + "default": "CustomSampling", + "enum": [ + "CustomSampling" + ], + "type": "string" + } + }, + "required": [ + "freqs" + ], + "type": "object" + }, "CustomSellmeier": { "additionalProperties": false, "properties": { @@ -7232,6 +7289,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -7338,6 +7402,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -7398,6 +7463,60 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cubic", + "linear", + "poly" + ], + "type": "string" + }, + "reduce_data": { + "default": false, + "type": "boolean" + }, + "sampling_spec": { + "discriminator": { + "mapping": { + "ChebSampling": "#/definitions/ChebSampling", + "CustomSampling": "#/definitions/CustomSampling", + "UniformSampling": "#/definitions/UniformSampling" + }, + "propertyName": "type" + }, + "oneOf": [ + { + "$ref": "#/definitions/ChebSampling" + }, + { + "$ref": "#/definitions/CustomSampling" + }, + { + "$ref": "#/definitions/UniformSampling" + } + ] + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "sampling_spec" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -7534,6 +7653,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -7795,6 +7915,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -8048,6 +8169,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -8229,6 +8351,13 @@ ], "default": false }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -11819,6 +11948,30 @@ ], "type": "object" }, + "UniformSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "UniformSampling", + "enum": [ + "UniformSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "VarshniEnergyBandGap": { "additionalProperties": false, "properties": { diff --git a/schemas/Simulation.json b/schemas/Simulation.json index de54588e30..97ab08ff52 100644 --- a/schemas/Simulation.json +++ b/schemas/Simulation.json @@ -2181,6 +2181,30 @@ }, "type": "object" }, + "ChebSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 3, + "type": "integer" + }, + "type": { + "default": "ChebSampling", + "enum": [ + "ChebSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ClipOperation": { "additionalProperties": false, "properties": { @@ -4697,6 +4721,39 @@ ], "type": "object" }, + "CustomSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "freqs": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "ArrayLike" + } + ] + }, + "type": { + "default": "CustomSampling", + "enum": [ + "CustomSampling" + ], + "type": "string" + } + }, + "required": [ + "freqs" + ], + "type": "object" + }, "CustomSellmeier": { "additionalProperties": false, "properties": { @@ -11214,6 +11271,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -11320,6 +11384,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -11380,6 +11445,60 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cubic", + "linear", + "poly" + ], + "type": "string" + }, + "reduce_data": { + "default": false, + "type": "boolean" + }, + "sampling_spec": { + "discriminator": { + "mapping": { + "ChebSampling": "#/definitions/ChebSampling", + "CustomSampling": "#/definitions/CustomSampling", + "UniformSampling": "#/definitions/UniformSampling" + }, + "propertyName": "type" + }, + "oneOf": [ + { + "$ref": "#/definitions/ChebSampling" + }, + { + "$ref": "#/definitions/CustomSampling" + }, + { + "$ref": "#/definitions/UniformSampling" + } + ] + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "sampling_spec" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -11516,6 +11635,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -11777,6 +11897,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -12030,6 +12151,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -12211,6 +12333,13 @@ ], "default": false }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -16514,6 +16643,30 @@ ], "type": "object" }, + "UniformSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "UniformSampling", + "enum": [ + "UniformSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "VarshniEnergyBandGap": { "additionalProperties": false, "properties": { diff --git a/schemas/TerminalComponentModeler.json b/schemas/TerminalComponentModeler.json index d3fb5bb01a..317815af3e 100644 --- a/schemas/TerminalComponentModeler.json +++ b/schemas/TerminalComponentModeler.json @@ -2181,6 +2181,30 @@ }, "type": "object" }, + "ChebSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 3, + "type": "integer" + }, + "type": { + "default": "ChebSampling", + "enum": [ + "ChebSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "ClipOperation": { "additionalProperties": false, "properties": { @@ -4801,6 +4825,39 @@ ], "type": "object" }, + "CustomSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "freqs": { + "anyOf": [ + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "type": "ArrayLike" + } + ] + }, + "type": { + "default": "CustomSampling", + "enum": [ + "CustomSampling" + ], + "type": "string" + } + }, + "required": [ + "freqs" + ], + "type": "object" + }, "CustomSellmeier": { "additionalProperties": false, "properties": { @@ -11556,6 +11613,13 @@ } ] }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -11662,6 +11726,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -11722,6 +11787,60 @@ ], "type": "object" }, + "ModeInterpSpec": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "method": { + "default": "linear", + "enum": [ + "cubic", + "linear", + "poly" + ], + "type": "string" + }, + "reduce_data": { + "default": false, + "type": "boolean" + }, + "sampling_spec": { + "discriminator": { + "mapping": { + "ChebSampling": "#/definitions/ChebSampling", + "CustomSampling": "#/definitions/CustomSampling", + "UniformSampling": "#/definitions/UniformSampling" + }, + "propertyName": "type" + }, + "oneOf": [ + { + "$ref": "#/definitions/ChebSampling" + }, + { + "$ref": "#/definitions/CustomSampling" + }, + { + "$ref": "#/definitions/UniformSampling" + } + ] + }, + "type": { + "default": "ModeInterpSpec", + "enum": [ + "ModeInterpSpec" + ], + "type": "string" + } + }, + "required": [ + "sampling_spec" + ], + "type": "object" + }, "ModeMonitor": { "additionalProperties": false, "properties": { @@ -11858,6 +11977,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -12119,6 +12239,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -12372,6 +12493,7 @@ "bend_radius": null, "filter_pol": null, "group_index_step": false, + "interp_spec": null, "num_modes": 1, "num_pml": [ 0, @@ -12553,6 +12675,13 @@ ], "default": false }, + "interp_spec": { + "allOf": [ + { + "$ref": "#/definitions/ModeInterpSpec" + } + ] + }, "num_modes": { "default": 1, "exclusiveMinimum": 0, @@ -17636,6 +17765,30 @@ ], "type": "object" }, + "UniformSampling": { + "additionalProperties": false, + "properties": { + "attrs": { + "default": {}, + "type": "object" + }, + "num_points": { + "minimum": 2, + "type": "integer" + }, + "type": { + "default": "UniformSampling", + "enum": [ + "UniformSampling" + ], + "type": "string" + } + }, + "required": [ + "num_points" + ], + "type": "object" + }, "VarshniEnergyBandGap": { "additionalProperties": false, "properties": { diff --git a/tests/test_components/test_microwave.py b/tests/test_components/test_microwave.py index 22ca41f79b..6da4bb00f0 100644 --- a/tests/test_components/test_microwave.py +++ b/tests/test_components/test_microwave.py @@ -1876,3 +1876,337 @@ def test_RF_license_suppression(): with AssertLogLevel(None): mode_spec = td.MicrowaveModeSpec._default_without_license_warning() td.config.microwave.suppress_rf_license_warning = original_setting + + +def test_microwave_mode_data_reordering_with_transmission_line_data(): + """Test that transmission_line_data is correctly reordered when modes are reordered.""" + from tidy3d.components.data.data_array import ( + CurrentFreqModeDataArray, + ImpedanceFreqModeDataArray, + ModeIndexDataArray, + ScalarModeFieldDataArray, + VoltageFreqModeDataArray, + ) + from tidy3d.components.microwave.data.dataset import TransmissionLineDataset + + # Setup coordinates + x = [-1, 1, 3] + y = [-2, 0] + z = [-3, -1, 1, 3, 5] + f = [2e14, 3e14] + mode_index = np.arange(3) + + grid = td.Grid(boundaries=td.Coords(x=x, y=y, z=z)) + field_coords = {"x": x[:-1], "y": y[:-1], "z": z[:-1], "f": f, "mode_index": mode_index} + index_coords = {"f": f, "mode_index": mode_index} + + # Create field data with distinct values for each mode + field_values = np.zeros((2, 1, 4, 2, 3), dtype=complex) + for mode_idx in range(3): + # Each mode gets a unique value to track reordering + field_values[:, :, :, :, mode_idx] = (mode_idx + 1) * (1 + 1j) + + field = ScalarModeFieldDataArray(field_values, coords=field_coords) + + # Create mode index data with distinct values for each mode + index_values = np.zeros((2, 3), dtype=complex) + for mode_idx in range(3): + index_values[:, mode_idx] = (mode_idx + 1) * 1.5 + 0.1j + index_data = ModeIndexDataArray(index_values, coords=index_coords) + + # Create transmission line data with distinct values for each mode + impedance_values = np.zeros((2, 3)) + voltage_values = np.zeros((2, 3), dtype=complex) + current_values = np.zeros((2, 3), dtype=complex) + + for mode_idx in range(3): + # Each mode gets unique impedance, voltage, and current values + impedance_values[:, mode_idx] = 50 * (mode_idx + 1) + voltage_values[:, mode_idx] = (mode_idx + 1) * (10 + 5j) + current_values[:, mode_idx] = (mode_idx + 1) * (0.2 + 0.1j) + + impedance_data = ImpedanceFreqModeDataArray(impedance_values, coords=index_coords) + voltage_data = VoltageFreqModeDataArray(voltage_values, coords=index_coords) + current_data = CurrentFreqModeDataArray(current_values, coords=index_coords) + + tl_data = TransmissionLineDataset( + Z0=impedance_data, voltage_coeffs=voltage_data, current_coeffs=current_data + ) + + # Create monitor + monitor = td.MicrowaveModeSolverMonitor( + center=(0, 0, 0), + size=(2, 0, 6), + freqs=[2e14, 3e14], + mode_spec=td.MicrowaveModeSpec(num_modes=3, impedance_specs=td.AutoImpedanceSpec()), + name="microwave_mode_solver", + ) + + # Create MicrowaveModeSolverData + data = td.MicrowaveModeSolverData( + monitor=monitor, + Ex=field, + Ey=field, + Ez=field, + Hx=field, + Hy=field, + Hz=field, + n_complex=index_data, + grid_expanded=grid, + transmission_line_data=tl_data, + ) + + # Define a reordering: reverse the mode order for each frequency + # Shape: (num_freqs, num_modes) = (2, 3) + # Original order: [0, 1, 2] -> New order: [2, 1, 0] + sort_inds_2d = np.array([[2, 1, 0], [2, 1, 0]]) + + # Apply mode reordering + reordered_data = data._apply_mode_reorder(sort_inds_2d) + + # Verify that the main mode data is reordered correctly + # Original mode 2 should now be at index 0 + original_mode_2_value = (2 + 1) * (1 + 1j) # Mode 2 had value 3*(1+1j) + assert np.allclose( + reordered_data.Ex.isel(mode_index=0, x=0, y=0, z=0).values, original_mode_2_value + ), "Main field data not reordered correctly" + + # Original mode 0 should now be at index 2 + original_mode_0_value = (0 + 1) * (1 + 1j) # Mode 0 had value 1*(1+1j) + assert np.allclose( + reordered_data.Ex.isel(mode_index=2, x=0, y=0, z=0).values, original_mode_0_value + ), "Main field data not reordered correctly" + + # Verify that transmission_line_data is also reordered correctly + assert reordered_data.transmission_line_data is not None, ( + "transmission_line_data should not be None" + ) + + # Check Z0 reordering + # Original mode 2 had Z0 = 50 * 3 = 150 + assert np.allclose(reordered_data.transmission_line_data.Z0.isel(mode_index=0).values, 150.0), ( + "transmission_line_data.Z0 not reordered correctly" + ) + + # Original mode 0 had Z0 = 50 * 1 = 50 + assert np.allclose(reordered_data.transmission_line_data.Z0.isel(mode_index=2).values, 50.0), ( + "transmission_line_data.Z0 not reordered correctly" + ) + + # Check voltage_coeffs reordering + # Original mode 2 had voltage = 3 * (10 + 5j) + assert np.allclose( + reordered_data.transmission_line_data.voltage_coeffs.isel(mode_index=0).values, + 3 * (10 + 5j), + ), "transmission_line_data.voltage_coeffs not reordered correctly" + + # Original mode 0 had voltage = 1 * (10 + 5j) + assert np.allclose( + reordered_data.transmission_line_data.voltage_coeffs.isel(mode_index=2).values, + 1 * (10 + 5j), + ), "transmission_line_data.voltage_coeffs not reordered correctly" + + # Check current_coeffs reordering + # Original mode 2 had current = 3 * (0.2 + 0.1j) + assert np.allclose( + reordered_data.transmission_line_data.current_coeffs.isel(mode_index=0).values, + 3 * (0.2 + 0.1j), + ), "transmission_line_data.current_coeffs not reordered correctly" + + # Original mode 0 had current = 1 * (0.2 + 0.1j) + assert np.allclose( + reordered_data.transmission_line_data.current_coeffs.isel(mode_index=2).values, + 1 * (0.2 + 0.1j), + ), "transmission_line_data.current_coeffs not reordered correctly" + + # Verify mode index data is also reordered + # Original mode 2 had n_complex = 3 * 1.5 + 0.1j + assert np.allclose(reordered_data.n_complex.isel(mode_index=0).values, 3 * 1.5 + 0.1j), ( + "n_complex not reordered correctly" + ) + + +def test_microwave_mode_data_interpolation(): + """Test that MicrowaveModeSolverData interpolation correctly handles transmission_line_data.""" + from tidy3d.components.data.data_array import ( + CurrentFreqModeDataArray, + ImpedanceFreqModeDataArray, + ModeIndexDataArray, + ScalarModeFieldDataArray, + VoltageFreqModeDataArray, + ) + from tidy3d.components.microwave.data.dataset import TransmissionLineDataset + + # Setup coordinates with sparse frequencies + x = [-1, 1, 3] + y = [-2, 0] + z = [-3, -1, 1, 3, 5] + f_sparse = np.array([1e14, 1.5e14, 2e14]) # 3 source frequencies + mode_index = np.arange(2) + + grid = td.Grid(boundaries=td.Coords(x=x, y=y, z=z)) + field_coords = {"x": x[:-1], "y": y[:-1], "z": z[:-1], "f": f_sparse, "mode_index": mode_index} + index_coords = {"f": f_sparse, "mode_index": mode_index} + + # Create field data with frequency-dependent values + field_values = np.zeros( + (len(x) - 1, len(y) - 1, len(z) - 1, len(f_sparse), len(mode_index)), dtype=complex + ) + for f_idx, freq in enumerate(f_sparse): + for mode_idx in range(len(mode_index)): + # Value depends on both frequency and mode: (freq/1e14) * (mode+1) * (1+1j) + field_values[:, :, :, f_idx, mode_idx] = (freq / 1e14) * (mode_idx + 1) * (1 + 1j) + + field = ScalarModeFieldDataArray(field_values, coords=field_coords) + + # Create mode index data with frequency dependence + index_values = np.zeros((len(f_sparse), len(mode_index)), dtype=complex) + for f_idx, freq in enumerate(f_sparse): + for mode_idx in range(len(mode_index)): + # n_eff increases with frequency: 1.5 + (freq/1e14)*0.1 + mode_idx*0.2 + index_values[f_idx, mode_idx] = 1.5 + (freq / 1e14) * 0.1 + mode_idx * 0.2 + 0.01j + index_data = ModeIndexDataArray(index_values, coords=index_coords) + + # Create transmission line data with frequency dependence + impedance_values = np.zeros((len(f_sparse), len(mode_index))) + voltage_values = np.zeros((len(f_sparse), len(mode_index)), dtype=complex) + current_values = np.zeros((len(f_sparse), len(mode_index)), dtype=complex) + + for f_idx, freq in enumerate(f_sparse): + for mode_idx in range(len(mode_index)): + # Impedance varies with frequency: 50 + (freq/1e14)*10 + mode_idx*20 + impedance_values[f_idx, mode_idx] = 50 + (freq / 1e14) * 10 + mode_idx * 20 + # Voltage varies with frequency + voltage_values[f_idx, mode_idx] = ((freq / 1e14) + mode_idx) * (10 + 5j) + # Current varies with frequency + current_values[f_idx, mode_idx] = ((freq / 1e14) + mode_idx) * (0.2 + 0.1j) + + impedance_data = ImpedanceFreqModeDataArray(impedance_values, coords=index_coords) + voltage_data = VoltageFreqModeDataArray(voltage_values, coords=index_coords) + current_data = CurrentFreqModeDataArray(current_values, coords=index_coords) + + tl_data = TransmissionLineDataset( + Z0=impedance_data, voltage_coeffs=voltage_data, current_coeffs=current_data + ) + + # Create monitor + monitor = td.MicrowaveModeSolverMonitor( + center=(0, 0, 0), + size=(2, 0, 6), + freqs=f_sparse, + mode_spec=td.MicrowaveModeSpec(num_modes=2, impedance_specs=td.AutoImpedanceSpec()), + name="microwave_mode_solver", + ) + + # Create MicrowaveModeSolverData + data = td.MicrowaveModeSolverData( + monitor=monitor, + Ex=field, + Ey=field, + Ez=field, + Hx=field, + Hy=field, + Hz=field, + n_complex=index_data, + grid_expanded=grid, + transmission_line_data=tl_data, + ) + + # Interpolate to denser frequency grid + f_dense = np.linspace(1e14, 2e14, 11) + + # Test linear interpolation + data_interp_linear = data.interp_in_freq(freqs=f_dense, method="linear", renormalize=False) + + # Verify that interpolated data has correct shape + assert len(data_interp_linear.monitor.freqs) == len(f_dense), ( + "Interpolated data should have new frequency count" + ) + assert data_interp_linear.Ex.shape[-1] == len(mode_index), "Mode count should be preserved" + assert data_interp_linear.Ex.shape[-2] == len(f_dense), ( + "Frequency dimension should match target" + ) + + # Verify that transmission_line_data is also interpolated + assert data_interp_linear.transmission_line_data is not None, ( + "transmission_line_data should be interpolated" + ) + assert len(data_interp_linear.transmission_line_data.Z0.coords["f"]) == len(f_dense), ( + "transmission_line_data.Z0 should be interpolated to new frequencies" + ) + assert len(data_interp_linear.transmission_line_data.voltage_coeffs.coords["f"]) == len( + f_dense + ), "transmission_line_data.voltage_coeffs should be interpolated to new frequencies" + assert len(data_interp_linear.transmission_line_data.current_coeffs.coords["f"]) == len( + f_dense + ), "transmission_line_data.current_coeffs should be interpolated to new frequencies" + + # Test interpolation accuracy at midpoint + f_mid = 1.5e14 + + # Check that field interpolation is reasonable + # At f=1.5e14, mode 0 should have value approximately (1.5) * 1 * (1+1j) = 1.5*(1+1j) + field_at_mid_mode0 = data_interp_linear.Ex.sel( + f=f_mid, mode_index=0, x=0, y=0, z=0, method="nearest" + ).values + expected_field_mode0 = 1.5 * 1 * (1 + 1j) + assert np.allclose(field_at_mid_mode0, expected_field_mode0, rtol=0.01), ( + f"Field interpolation for mode 0: expected {expected_field_mode0}, got {field_at_mid_mode0}" + ) + + # Check that n_complex interpolation is reasonable + # At f=1.5e14, mode 0 should have n_eff approximately 1.5 + 1.5*0.1 + 0*0.2 = 1.65 + n_complex_at_mid_mode0 = data_interp_linear.n_complex.sel( + f=f_mid, mode_index=0, method="nearest" + ).values + expected_n_complex_mode0 = 1.5 + 1.5 * 0.1 + 0 * 0.2 + 0.01j + assert np.allclose(n_complex_at_mid_mode0, expected_n_complex_mode0, rtol=0.01), ( + f"n_complex interpolation for mode 0: expected {expected_n_complex_mode0}, got {n_complex_at_mid_mode0}" + ) + + # Check that transmission line data interpolation is reasonable + # At f=1.5e14, mode 0 should have Z0 approximately 50 + 1.5*10 + 0*20 = 65 + Z0_at_mid_mode0 = data_interp_linear.transmission_line_data.Z0.sel( + f=f_mid, mode_index=0, method="nearest" + ).values + expected_Z0_mode0 = 50 + 1.5 * 10 + 0 * 20 + assert np.allclose(Z0_at_mid_mode0, expected_Z0_mode0, rtol=0.01), ( + f"Z0 interpolation for mode 0: expected {expected_Z0_mode0}, got {Z0_at_mid_mode0}" + ) + + # At f=1.5e14, mode 0 should have voltage approximately (1.5 + 0) * (10 + 5j) = 1.5*(10+5j) + voltage_at_mid_mode0 = data_interp_linear.transmission_line_data.voltage_coeffs.sel( + f=f_mid, mode_index=0, method="nearest" + ).values + expected_voltage_mode0 = 1.5 * (10 + 5j) + assert np.allclose(voltage_at_mid_mode0, expected_voltage_mode0, rtol=0.01), ( + f"voltage_coeffs interpolation for mode 0: expected {expected_voltage_mode0}, got {voltage_at_mid_mode0}" + ) + + # At f=1.5e14, mode 0 should have current approximately (1.5 + 0) * (0.2 + 0.1j) = 1.5*(0.2+0.1j) + current_at_mid_mode0 = data_interp_linear.transmission_line_data.current_coeffs.sel( + f=f_mid, mode_index=0, method="nearest" + ).values + expected_current_mode0 = 1.5 * (0.2 + 0.1j) + assert np.allclose(current_at_mid_mode0, expected_current_mode0, rtol=0.01), ( + f"current_coeffs interpolation for mode 0: expected {expected_current_mode0}, got {current_at_mid_mode0}" + ) + + # Test at endpoints to ensure they match original values + # At f=1e14, mode 1 should have Z0 = 50 + 1*10 + 1*20 = 80 + Z0_at_start_mode1 = data_interp_linear.transmission_line_data.Z0.sel( + f=1e14, mode_index=1, method="nearest" + ).values + expected_Z0_start_mode1 = 50 + 1 * 10 + 1 * 20 + assert np.allclose(Z0_at_start_mode1, expected_Z0_start_mode1, rtol=1e-6), ( + f"Z0 at endpoint should match original: expected {expected_Z0_start_mode1}, got {Z0_at_start_mode1}" + ) + + # At f=2e14, mode 1 should have Z0 = 50 + 2*10 + 1*20 = 90 + Z0_at_end_mode1 = data_interp_linear.transmission_line_data.Z0.sel( + f=2e14, mode_index=1, method="nearest" + ).values + expected_Z0_end_mode1 = 50 + 2 * 10 + 1 * 20 + assert np.allclose(Z0_at_end_mode1, expected_Z0_end_mode1, rtol=1e-6), ( + f"Z0 at endpoint should match original: expected {expected_Z0_end_mode1}, got {Z0_at_end_mode1}" + ) diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py new file mode 100644 index 0000000000..07b3e2efc1 --- /dev/null +++ b/tests/test_components/test_mode_interp.py @@ -0,0 +1,1154 @@ +"""Tests for mode frequency interpolation.""" + +from __future__ import annotations + +import numpy as np +import pydantic.v1 as pydantic +import pytest + +import tidy3d as td +from tidy3d.plugins.mode import ModeSolver + +# from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_MODE_SPEC +from ..test_data.test_data_arrays import MODE_SPEC, SIZE_2D +from ..utils import AssertLogLevel + +# Shared test constants +FREQS_DENSE = np.linspace(1e14, 2e14, 20) + + +# ============================================================================ +# ModeInterpSpec Tests +# ============================================================================ + + +def test_interp_spec_valid_linear(): + """Test creating valid ModeInterpSpec with linear interpolation.""" + spec = td.ModeInterpSpec.uniform(num_points=5, method="linear") + assert spec.num_points == 5 + assert spec.method == "linear" + assert not spec.reduce_data # default value + + +def test_interp_spec_valid_cubic(): + """Test creating valid ModeInterpSpec with cubic interpolation.""" + spec = td.ModeInterpSpec.uniform(num_points=10, method="cubic") + assert spec.num_points == 10 + assert spec.method == "cubic" + assert not spec.reduce_data # default value + + +def test_interp_spec_default_method(): + """Test that default method is 'linear' and reduce_data is False.""" + spec = td.ModeInterpSpec.uniform(num_points=5) + assert spec.method == "linear" + assert not spec.reduce_data + + +def test_interp_spec_cubic_needs_4_points(): + """Test that cubic interpolation requires at least 4 points.""" + with pytest.raises(pydantic.ValidationError, match="Cubic interpolation requires at least 4"): + td.ModeInterpSpec.uniform(num_points=3, method="cubic") + + +def test_interp_spec_valid_poly(): + """Test creating valid ModeInterpSpec with polynomial interpolation.""" + spec = td.ModeInterpSpec.uniform(num_points=10, method="poly") + assert spec.num_points == 10 + assert spec.method == "poly" + + +def test_interp_spec_poly_needs_3_points(): + """Test that polynomial interpolation requires at least 3 points.""" + with pytest.raises( + pydantic.ValidationError, match="Polynomial interpolation requires at least 3" + ): + td.ModeInterpSpec.uniform(num_points=2, method="poly") + + +def test_interp_spec_cheb_convenience(): + """Test the convenience method for Chebyshev sampling with polynomial interpolation.""" + spec = td.ModeInterpSpec.cheb(num_points=10) + assert spec.num_points == 10 + assert spec.method == "poly" + assert isinstance(spec.sampling_spec, td.ChebSampling) + + +def test_interp_spec_uniform_convenience(): + """Test the convenience method for uniform sampling.""" + spec = td.ModeInterpSpec.uniform(num_points=5, method="linear") + assert spec.num_points == 5 + assert spec.method == "linear" + assert isinstance(spec.sampling_spec, td.UniformSampling) + + +def test_interp_spec_custom_convenience(): + """Test the convenience method for custom sampling.""" + custom_freqs = [1e14, 1.5e14, 1.8e14, 2e14] + spec = td.ModeInterpSpec.custom(freqs=custom_freqs, method="cubic") + assert spec.num_points == 4 + assert spec.method == "cubic" + assert isinstance(spec.sampling_spec, td.CustomSampling) + + +def test_interp_spec_sampling_points_uniform(): + """Test sampling_points for uniform sampling.""" + spec = td.ModeInterpSpec.uniform(num_points=5, method="linear") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + assert np.isclose(sampling[0], 1e14) + assert np.isclose(sampling[-1], 2e14) + # Check uniform spacing + diffs = np.diff(sampling) + assert np.allclose(diffs, diffs[0]) + # Check ascending order + assert np.all(np.diff(sampling) > 0), "Uniform frequencies should be in ascending order" + + +def test_interp_spec_sampling_points_cheb(): + """Test sampling_points for Chebyshev sampling.""" + spec = td.ModeInterpSpec.cheb(num_points=5) + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 5 + # Chebyshev nodes should include endpoints + assert np.isclose(sampling.min(), 1e14) + assert np.isclose(sampling.max(), 2e14) + + # Verify they are Chebyshev nodes + f_min, f_max = 1e14, 2e14 + k = np.arange(5) + expected_normalized = np.cos(k * np.pi / 4) + expected = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * expected_normalized + assert np.allclose(np.sort(sampling), np.sort(expected)) + + +def test_cheb_sampling_ascending_order(): + """Test that Chebyshev sampling returns frequencies in ascending order.""" + spec = td.ModeInterpSpec.cheb(num_points=10) + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + # Verify the frequencies are in strictly ascending order + assert len(sampling) == 10 + assert np.all(np.diff(sampling) > 0), "Chebyshev frequencies should be in ascending order" + + # Verify endpoints are correct + assert np.isclose(sampling[0], 1e14), "First frequency should be f_min" + assert np.isclose(sampling[-1], 2e14), "Last frequency should be f_max" + + +def test_interp_spec_sampling_points_custom(): + """Test sampling_points for custom sampling.""" + custom_freqs = np.array([1e14, 1.3e14, 1.7e14, 2e14]) + spec = td.ModeInterpSpec.custom(freqs=custom_freqs, method="linear") + freqs = np.linspace(1e14, 2e14, 100) + sampling = spec.sampling_points(freqs) + + assert len(sampling) == 4 + assert np.allclose(sampling, custom_freqs) + + +def test_interp_spec_min_2_points(): + """Test that at least 2 points are required.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec.uniform(num_points=1, method="linear") + + +def test_interp_spec_positive_points(): + """Test that num_points must be positive.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec.uniform(num_points=0, method="linear") + + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec.uniform(num_points=-5, method="linear") + + +def test_interp_spec_invalid_method(): + """Test that invalid interpolation method is rejected.""" + with pytest.raises(pydantic.ValidationError): + td.ModeInterpSpec.uniform(num_points=5, method="quadratic") + + +def test_interp_spec_reduce_data_true(): + """Test creating ModeInterpSpec with reduce_data=True.""" + spec = td.ModeInterpSpec.uniform(num_points=5, method="linear", reduce_data=True) + assert spec.num_points == 5 + assert spec.method == "linear" + assert spec.reduce_data + + +def test_interp_spec_reduce_data_false(): + """Test creating ModeInterpSpec with reduce_data=False (explicit).""" + spec = td.ModeInterpSpec.uniform(num_points=5, method="linear", reduce_data=False) + assert spec.num_points == 5 + assert spec.method == "linear" + assert not spec.reduce_data + + +def test_interp_spec_requires_tracking(): + """Test that ModeMonitor with interp_spec requires track_freq.""" + + with pytest.raises(pydantic.ValidationError, match="tracking"): + mode_spec_no_track = td.ModeSpec( + num_modes=2, + track_freq=None, + sort_spec=td.ModeSortSpec(track_freq=None), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + + +# ============================================================================ +# Monitor with interp_spec Tests +# ============================================================================ + + +def test_mode_monitor_valid_with_tracking(): + """Test that ModeMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + name="test", + ) + assert monitor.mode_spec.interp_spec.num_points == 5 + assert monitor.mode_spec.interp_spec.method == "linear" + + +def test_mode_solver_monitor_valid_with_tracking(): + """Test that ModeSolverMonitor validates with tracking enabled.""" + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + name="test", + ) + assert monitor.mode_spec.interp_spec.num_points == 5 + + +def test_interp_num_points_less_than_freqs(): + """Test that num_points must be less than total freqs.""" + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=25, method="linear"), + ) + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + name="test", + ) + + +def test_interp_num_points_equal_to_freqs(): + """Test that num_points equal to freqs is rejected.""" + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=20, method="linear"), + ) + + with AssertLogLevel("WARNING", contains_str="num_points"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + name="test", + ) + + +def test_interp_spec_none_allowed(): + """Test that interp_spec=None is allowed (no interpolation).""" + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + name="test", + ) + assert monitor.mode_spec.interp_spec is None + + +def test_interp_deprecated_track_freq_still_works(): + """Test that deprecated track_freq on ModeSpec still enables interpolation.""" + # Using deprecated track_freq instead of sort_spec.track_freq + with AssertLogLevel("WARNING", contains_str="deprecated"): + mode_spec = td.ModeSpec( + num_modes=2, + track_freq="central", + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + + # Should still work since _track_freq property resolves it + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + name="test", + ) + assert monitor.mode_spec.interp_spec.num_points == 5 + + +# ============================================================================ +# ModeSolver with interp_spec Tests +# ============================================================================ + + +def get_simple_sim(): + """Create a simple simulation for ModeSolver tests.""" + return td.Simulation( + size=(10, 10, 10), + grid_spec=td.GridSpec(wavelength=1.0), + structures=[ + td.Structure( + geometry=td.Box(size=(1, 1, 10)), + medium=td.Medium(permittivity=4.0), + ) + ], + run_time=1e-12, + ) + + +def test_mode_solver_valid_with_tracking(): + """Test that ModeSolver validates with tracking enabled.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + ) + assert solver.mode_spec.interp_spec.num_points == 5 + assert solver.mode_spec.interp_spec.method == "linear" + + +@td.packaging.disable_local_subpixel +def test_mode_solver_warns_num_points(): + """Test that ModeSolver warns when num_points >= num_freqs.""" + sim = get_simple_sim() + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=25, method="linear"), + ) + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + ms = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=mode_spec, + ) + _ = ms.data_raw + + +def test_mode_solver_interp_spec_none(): + """Test that ModeSolver accepts interp_spec=None.""" + sim = get_simple_sim() + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + solver = ModeSolver( + simulation=sim, + plane=plane, + freqs=FREQS_DENSE, + mode_spec=MODE_SPEC, + ) + assert solver.mode_spec.interp_spec is None + + +# ============================================================================ +# ModeSolverData.interp() Tests +# ============================================================================ + + +def get_mode_solver_data(): + """Create a simple ModeSolverData object for testing.""" + from tidy3d.components.data.data_array import GroupIndexDataArray, ModeDispersionDataArray + + from ..test_data.test_data_arrays import SIM, make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + freqs = N_COMPLEX.f.data + num_modes = len(N_COMPLEX.mode_index) + mode_indices = np.arange(num_modes) + mode_spec = td.ModeSpec(num_modes=num_modes, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test_monitor", + colocate=False, + ) + + # Create n_group_raw and dispersion_raw with same shape as n_complex + n_group_values = 1.5 + 0.1 * np.random.random((len(freqs), num_modes)) + n_group_raw = GroupIndexDataArray( + n_group_values, coords={"f": freqs, "mode_index": mode_indices} + ) + + dispersion_values = 10.0 + 2.0 * np.random.random((len(freqs), num_modes)) + dispersion_raw = ModeDispersionDataArray( + dispersion_values, coords={"f": freqs, "mode_index": mode_indices} + ) + + # Create mode data with the right frequencies + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex", symmetry=False), + Ey=make_scalar_mode_field_data_array("Ey", symmetry=False), + Ez=make_scalar_mode_field_data_array("Ez", symmetry=False), + Hx=make_scalar_mode_field_data_array("Hx", symmetry=False), + Hy=make_scalar_mode_field_data_array("Hy", symmetry=False), + Hz=make_scalar_mode_field_data_array("Hz", symmetry=False), + n_complex=N_COMPLEX, + n_group_raw=n_group_raw, + dispersion_raw=dispersion_raw, + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=SIM.discretize_monitor(monitor), + ) + return mode_data + + +def test_mode_solver_data_interp_linear(): + """Test linear interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp_in_freq(freqs=freqs_dense, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert data_interp.n_complex.shape[0] == 20 + + # Check mode dimension is preserved + assert len(data_interp.n_complex.mode_index) == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 20 + + +def test_mode_solver_data_interp_cubic(): + """Test cubic interpolation on ModeSolverData.""" + mode_data = get_mode_solver_data() + + # Need at least 4 frequencies for cubic + assert len(mode_data.monitor.freqs) >= 4 + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp_in_freq(freqs=freqs_dense, method="cubic") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 20 + assert len(data_interp.n_complex.mode_index) == len(mode_data.n_complex.mode_index) + + +def test_mode_solver_data_interp_poly(): + """Test polynomial interpolation on ModeSolverData.""" + # Create data with frequencies at Chebyshev nodes for optimal polynomial interpolation + interp_spec = td.ModeInterpSpec.cheb(num_points=5) + freqs_all = np.linspace(1e14, 2e14, 50) + freqs_cheb = interp_spec.sampling_points(freqs_all) + + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs_cheb, + mode_spec=mode_spec, + name="test_poly", + colocate=False, + ) + + from ..test_data.test_data_arrays import SIM, make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex", symmetry=False), + Ey=make_scalar_mode_field_data_array("Ey", symmetry=False), + Ez=make_scalar_mode_field_data_array("Ez", symmetry=False), + Hx=make_scalar_mode_field_data_array("Hx", symmetry=False), + Hy=make_scalar_mode_field_data_array("Hy", symmetry=False), + Hz=make_scalar_mode_field_data_array("Hz", symmetry=False), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=SIM.discretize_monitor(monitor), + ) + + # Interpolate to 50 frequencies using polynomial method + data_interp = mode_data.interp_in_freq(freqs=freqs_all, method="poly") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 50 + assert len(data_interp.n_complex.mode_index) == len(N_COMPLEX.mode_index) + + +def test_mode_solver_data_interp_poly_needs_3_source(): + """Test that polynomial interpolation fails with too few source frequencies.""" + # Create data with only 2 frequencies + freqs = np.linspace(1e14, 2e14, 2) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 3 source"): + mode_data.interp_in_freq(freqs=freqs_dense, method="poly") + + +def test_mode_solver_data_interp_preserves_modes(): + """Test that interpolation preserves mode count.""" + mode_data = get_mode_solver_data() + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to different number of frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp_in_freq(freqs=freqs_dense, method="linear") + + # Mode count should be unchanged + assert len(data_interp.n_complex.mode_index) == original_num_modes + + +def test_mode_solver_data_interp_includes_n_group_and_dispersion(): + """Test that interpolation includes n_group_raw and dispersion_raw.""" + mode_data = get_mode_solver_data() + + # Verify source data has n_group_raw and dispersion_raw + assert mode_data.n_group_raw is not None + assert mode_data.dispersion_raw is not None + assert mode_data.n_group_raw.shape == (5, len(mode_data.n_complex.mode_index)) + assert mode_data.dispersion_raw.shape == (5, len(mode_data.n_complex.mode_index)) + + # Interpolate to 20 frequencies + freqs_dense = np.linspace(mode_data.monitor.freqs[0], mode_data.monitor.freqs[-1], 20) + data_interp = mode_data.interp_in_freq(freqs=freqs_dense, method="linear") + + # Verify interpolated data has n_group_raw and dispersion_raw + assert data_interp.n_group_raw is not None + assert data_interp.dispersion_raw is not None + + # Check shapes are correct + assert data_interp.n_group_raw.shape == (20, len(data_interp.n_complex.mode_index)) + assert data_interp.dispersion_raw.shape == (20, len(data_interp.n_complex.mode_index)) + + # Check frequency coordinates are correct + assert len(data_interp.n_group_raw.coords["f"]) == 20 + assert len(data_interp.dispersion_raw.coords["f"]) == 20 + + +def test_mode_solver_data_interp_single_frequency(): + """Test that interpolation works with a single target frequency.""" + mode_data = get_mode_solver_data() + + # Original has 5 frequencies + assert len(mode_data.monitor.freqs) == 5 + original_num_modes = len(mode_data.n_complex.mode_index) + + # Interpolate to a single frequency in the middle of the range + single_freq = np.array([1.5e14]) + data_interp = mode_data.interp_in_freq(freqs=single_freq, method="linear") + + # Check frequency dimension + assert len(data_interp.monitor.freqs) == 1 + assert data_interp.n_complex.shape[0] == 1 + + # Check mode dimension is preserved + assert len(data_interp.n_complex.mode_index) == original_num_modes + + # Check field components are interpolated + for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: + field_data = getattr(data_interp, field_name) + assert field_data is not None + assert field_data.coords["f"].size == 1 + assert float(field_data.coords["f"]) == 1.5e14 + + # Check n_group_raw and dispersion_raw if present + if data_interp.n_group_raw is not None: + print(data_interp.n_group_raw.shape) + print((1, original_num_modes)) + assert data_interp.n_group_raw.shape == (1, original_num_modes) + if data_interp.dispersion_raw is not None: + assert data_interp.dispersion_raw.shape == (1, original_num_modes) + + +def test_mode_solver_data_interp_cubic_needs_4_source(): + """Test that cubic interpolation fails with too few source frequencies.""" + # Create data with only 3 frequencies + freqs = np.linspace(1e14, 2e14, 3) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array + from ..test_data.test_monitor_data import N_COMPLEX + + mode_data = td.ModeSolverData( + monitor=monitor, + Ex=make_scalar_mode_field_data_array("Ex"), + n_complex=N_COMPLEX.copy(), + symmetry=(0, 0, 0), + symmetry_center=(0, 0, 0), + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), + ) + + freqs_dense = np.linspace(1e14, 2e14, 10) + with pytest.raises(td.exceptions.DataError, match="at least 4 source"): + mode_data.interp_in_freq(freqs=freqs_dense, method="cubic") + + +def test_mode_solver_data_interp_invalid_method(): + """Test that invalid interpolation method raises error.""" + mode_data = get_mode_solver_data() + freqs_dense = np.linspace(1e14, 2e14, 10) + + with pytest.raises(td.exceptions.DataError, match="Invalid interpolation method"): + mode_data.interp_in_freq(freqs=freqs_dense, method="quadratic") + + +def test_mode_solver_data_interp_extrapolation_warning(): + """Test that extrapolation triggers a warning.""" + mode_data = get_mode_solver_data() + + # Interpolate beyond original range + freqs_extrap = np.linspace(0.5e14, 2.5e14, 10) + + with AssertLogLevel("WARNING", contains_str="outside original range"): + mode_data.interp_in_freq(freqs=freqs_extrap, method="linear") + + +# ============================================================================ +# assume_sorted Tests (for source frequencies) +# ============================================================================ + + +@pytest.mark.parametrize("method", ["linear", "cubic", "poly"]) +def test_interp_assume_sorted_source_frequencies(method): + """Test that assume_sorted correctly handles sorted vs unsorted source frequencies. + + This test verifies that: + - assume_sorted=True works correctly when source frequencies are sorted + - assume_sorted=False correctly handles unsorted source frequencies + - Both approaches produce identical interpolation results + + Parameters + ---------- + method : str + Interpolation method to test: "linear", "cubic", or "poly" + """ + from tidy3d.components.data.data_array import ModeIndexDataArray + from tidy3d.components.data.dataset import FreqDataset + + # Define number of source frequencies based on method requirements + num_source_points = {"linear": 5, "cubic": 5, "poly": 5}[method] + + # Generate sorted source frequencies using ModeInterpSpec.sampling_points + # Use Chebyshev sampling for polynomial interpolation for best accuracy + if method == "poly": + interp_spec = td.ModeInterpSpec.cheb(num_points=num_source_points) + else: + interp_spec = td.ModeInterpSpec.uniform(num_points=num_source_points, method=method) + freqs_full_range = np.linspace(1e14, 2e14, 100) # Full frequency range + freqs_source_sorted = interp_spec.sampling_points(freqs_full_range) + + # Create unsorted version by shuffling + rng = np.random.RandomState(42) # Fixed seed for reproducibility + shuffle_indices = np.arange(len(freqs_source_sorted)) + rng.shuffle(shuffle_indices) + freqs_source_unsorted = freqs_source_sorted[shuffle_indices] + + # Create test data values + mode_indices = np.arange(2) + data_values = (1.5 + 0.1j) * np.random.random((num_source_points, 2)) + + # Create dataset with SORTED source frequencies + data_sorted = ModeIndexDataArray( + data_values.copy(), coords={"f": freqs_source_sorted, "mode_index": mode_indices} + ) + + # Create dataset with UNSORTED source frequencies + # Need to reorder data values to match the shuffled frequencies + data_unsorted = ModeIndexDataArray( + data_values[shuffle_indices], + coords={"f": freqs_source_unsorted, "mode_index": mode_indices}, + ) + + # Define target frequencies for interpolation + freqs_target = np.linspace(1e14, 2e14, 50) + + # Interpolate with sorted source frequencies and assume_sorted=True + result_sorted = FreqDataset._interp_dataarray_in_freq( + data_sorted, freqs_target, method=method, assume_sorted=True + ) + + # Interpolate with unsorted source frequencies and assume_sorted=False + # This should internally sort the data before interpolating + result_unsorted = FreqDataset._interp_dataarray_in_freq( + data_unsorted, freqs_target, method=method, assume_sorted=False + ) + + # Both approaches should produce identical results + assert np.allclose(result_sorted.values, result_unsorted.values, rtol=1e-10), ( + f"{method} interpolation: sorted and unsorted source frequencies " + "should produce identical results" + ) + + # Verify the interpolation actually happened correctly + assert result_sorted.shape == (50, 2) + assert len(result_sorted.coords["f"]) == 50 + assert np.allclose(result_sorted.coords["f"].values, freqs_target) + + +# ============================================================================ +# ModeSolver Integration Tests (Phase 5) +# ============================================================================ + + +@pytest.mark.parametrize("method", ["linear", "cubic", "poly"]) +def test_mode_solver_with_interp(method): + """Test that ModeSolver uses interpolation when interp_spec is provided. + + With reduce_data=False (default), data_raw returns automatically interpolated data + with all requested frequencies. + """ + sim = get_simple_sim() + + # Create solver with 10 frequencies, reduce_data=False (default) + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method=method, reduce_data=False), + ) + + solver_with_interp = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + ) + + # The solver should have the original 10 frequencies + assert len(solver_with_interp.freqs) == 10 + + # With reduce_data=False, data_raw automatically interpolates and returns all 10 frequencies + data = solver_with_interp.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape == (10, 2) + assert data.monitor.mode_spec.interp_spec is None + + +@pytest.mark.parametrize("method", ["linear", "cubic", "poly"]) +def test_mode_solver_creates_reduced_freqs(method): + """Test that solver creates correct reduced frequency set internally.""" + sim = get_simple_sim() + + # Create solver with 20 frequencies + freqs = np.linspace(1e14, 2e14, 20) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method=method, reduce_data=True), + ) + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + ) + + # The returned data should have 5 frequencies + data = solver.data_raw + assert len(data.monitor.freqs) == 20 + assert len(data.monitor._stored_freqs) == 5 + assert data.n_complex.shape == (5, 2) + + interpolated_data = data.interpolated_copy + assert len(interpolated_data.monitor.freqs) == 20 + assert len(interpolated_data.monitor._stored_freqs) == 20 + assert interpolated_data.n_complex.shape == (20, 2) + + +def test_mode_solver_without_interp_returns_full_data(): + """Test that solver without interp_spec computes at all frequencies.""" + sim = get_simple_sim() + + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + # No interpolation + ) + + data = solver.data_raw + assert len(data.monitor.freqs) == 10 + assert data.n_complex.shape == (10, 2) + assert data.monitor.mode_spec.interp_spec is None + + +def test_mode_solver_with_reduce_data_but_small_num_freqs(): + """Test that ModeSolver with reduce_data=True but small number of frequencies + returns data at the original frequencies instead of the reduced frequencies. + """ + sim = get_simple_sim() + + # Create solver with 2 fsrequencies but reduce_data=True + freqs = np.linspace(1e14, 2e14, 2) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear", reduce_data=True), + ) + + solver = ModeSolver( + simulation=sim, + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), + freqs=freqs, + mode_spec=mode_spec, + ) + + data = solver.data_raw + assert len(data.monitor._stored_freqs) == 2 + assert len(data.monitor.freqs) == 2 + assert data.n_complex.shape == (2, 2) + assert data.monitor.mode_spec.interp_spec is None + + assert data.interpolated_copy == data + + +# ============================================================================ +# Monitor Integration Tests (Phase 6) +# ============================================================================ + + +def test_mode_monitor_with_interp_spec(): + """Test that ModeMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=3, method="linear"), + ) + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="mode_monitor", + ) + + assert monitor.mode_spec.interp_spec is not None + assert monitor.mode_spec.interp_spec.num_points == 3 + assert monitor.mode_spec.interp_spec.method == "linear" + + +def test_mode_solver_monitor_with_interp_spec(): + """Test that ModeSolverMonitor can be created with interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=4, method="cubic"), + ) + + monitor = td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="mode_solver_monitor", + ) + + assert monitor.mode_spec.interp_spec is not None + assert monitor.mode_spec.interp_spec.num_points == 4 + assert monitor.mode_spec.interp_spec.method == "cubic" + + +def test_mode_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"), + ) + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + +def test_mode_solver_monitor_warns_redundant_num_points(): + """Test warning when num_points >= number of frequencies in ModeSolverMonitor.""" + freqs = np.linspace(1e14, 2e14, 5) + mode_spec = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=6, method="linear"), + ) + + with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"): + td.ModeSolverMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + +def test_mode_monitor_interp_spec_none(): + """Test that ModeMonitor works without interp_spec.""" + freqs = np.linspace(1e14, 2e14, 10) + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + + monitor = td.ModeMonitor( + center=(0, 0, 0), + size=SIZE_2D, + freqs=freqs, + mode_spec=mode_spec, + name="test", + ) + + assert monitor.mode_spec.interp_spec is None + + +def test_mode_solver_sampling_freqs(): + """Test that ModeSolver computes modes at the correct frequencies.""" + sim = get_simple_sim() + freqs = np.linspace(1e14, 2e14, 10) + plane = td.Box(center=(0, 0, 0), size=SIZE_2D) + + # With interp_spec.num_points < len(freqs) + mode_spec_true = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=3, method="linear"), + ) + solver_true = ModeSolver( + simulation=sim, + plane=plane, + freqs=freqs, + mode_spec=mode_spec_true, + ) + assert len(solver_true._sampling_freqs) == 3 + + # With interp_spec.num_points < len(freqs) + group_index_step + mode_spec_true = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=3, method="linear"), + group_index_step=True, + ) + solver_true = ModeSolver( + simulation=sim, + plane=plane, + freqs=freqs, + mode_spec=mode_spec_true, + ) + assert len(solver_true._sampling_freqs) == 3 * 3 + + # With interp_spec.num_points > len(freqs) + mode_spec_true = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=3, method="linear"), + ) + solver_true = ModeSolver( + simulation=sim, + plane=plane, + freqs=[1e14], + mode_spec=mode_spec_true, + ) + assert len(solver_true._sampling_freqs) == 1 + + # With interp_spec.num_points > len(freqs) + group_index_step + mode_spec_true = td.ModeSpec( + num_modes=2, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=3, method="linear"), + group_index_step=True, + ) + solver_true = ModeSolver( + simulation=sim, + plane=plane, + freqs=[1e14, 2e14], + mode_spec=mode_spec_true, + ) + assert len(solver_true._sampling_freqs) == 2 * 3 + + # Without interp_spec + mode_spec_none = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) + solver_none = ModeSolver( + simulation=sim, + plane=plane, + freqs=freqs, + mode_spec=mode_spec_none, + ) + assert len(solver_none._sampling_freqs) == 10 + + # Without interp_spec, with group_index_step + mode_spec_none = td.ModeSpec( + num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central"), group_index_step=True + ) + solver_none = ModeSolver( + simulation=sim, + plane=plane, + freqs=freqs, + mode_spec=mode_spec_none, + ) + assert len(solver_none._sampling_freqs) == 10 * 3 + + +# ============================================================================ +# WavePort interp_spec Tests +# ============================================================================ + + +def make_wave_port(num_interp_points=3, method="linear"): + """Make a WavePort.""" + from tidy3d.plugins.smatrix.ports.wave import WavePort + + return WavePort( + center=(0, 0, 0), + size=(1, 1, 0), + direction="+", + name="port1", + mode_spec=td.MicrowaveModeSpec( + num_modes=1, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.uniform(num_points=num_interp_points, method=method), + ), + ) + + +def test_wave_port_to_monitors_propagates_default_interp_spec(): + """Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor.""" + + port = make_wave_port(num_interp_points=3, method="linear") + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.mode_spec.interp_spec is not None + assert monitor.mode_spec.interp_spec.num_points == 3 + assert monitor.mode_spec.interp_spec.method == "linear" + + +def test_wave_port_to_monitors_propagates_custom_interp_spec(): + """Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor.""" + custom_mode_spec = td.MicrowaveModeSpec( + num_modes=1, + sort_spec=td.ModeSortSpec(track_freq="central"), + interp_spec=td.ModeInterpSpec.cheb(num_points=8), + ) + port = make_wave_port().updated_copy(mode_spec=custom_mode_spec) + + freqs = np.linspace(1e14, 2e14, 50) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.mode_spec.interp_spec is not None + assert monitor.mode_spec.interp_spec.num_points == 8 + assert monitor.mode_spec.interp_spec.method == "poly" + + +def test_wave_port_to_monitors_propagates_none_interp_spec(): + """Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor.""" + mode_spec_no_interp = td.MicrowaveModeSpec( + num_modes=1, sort_spec=td.ModeSortSpec(track_freq="central"), interp_spec=None + ) + port = make_wave_port().updated_copy(mode_spec=mode_spec_no_interp) + + freqs = np.linspace(1e14, 2e14, 20) + monitors = port.to_monitors(freqs=freqs) + + assert len(monitors) == 1 + monitor = monitors[0] + assert isinstance(monitor, td.ModeMonitor) + assert monitor.mode_spec.interp_spec is None + + +# ============================================================================ +# Placeholder tests for future phases +# ============================================================================ diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 9c544876da..4acc881996 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -339,7 +339,14 @@ from .components.mode.simulation import ModeSimulation # modes -from .components.mode_spec import ModeSortSpec, ModeSpec +from .components.mode_spec import ( + ChebSampling, + CustomSampling, + ModeInterpSpec, + ModeSortSpec, + ModeSpec, + UniformSampling, +) # monitors from .components.monitor import ( @@ -524,6 +531,7 @@ def set_logging_level(level: str) -> None: "ChargeDataArray", "ChargeInsulatorMedium", "ChargeToleranceSpec", + "ChebSampling", "ClipOperation", "CoaxialLumpedResistor", "CompositeCurrentIntegral", @@ -558,6 +566,7 @@ def set_logging_level(level: str) -> None: "CustomLorentz", "CustomMedium", "CustomPoleResidue", + "CustomSampling", "CustomSellmeier", "CustomSourceTime", "Cylinder", @@ -702,6 +711,7 @@ def set_logging_level(level: str) -> None: "ModeAmpsDataArray", "ModeData", "ModeIndexDataArray", + "ModeInterpSpec", "ModeMonitor", "ModeSimulation", "ModeSimulationData", @@ -809,6 +819,7 @@ def set_logging_level(level: str) -> None: "UniformCurrentSource", "UniformGrid", "UniformHeatSource", + "UniformSampling", "UniformUnstructuredGrid", "UnsteadyHeatAnalysis", "UnsteadySpec", diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index fda1c82a7f..e3e4307fc2 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -3,14 +3,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union, get_args +from typing import Any, Callable, Literal, Optional, Union, get_args import numpy as np import pydantic.v1 as pd import xarray as xr from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import Axis, xyz +from tidy3d.components.types import Axis, FreqArray, xyz from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling from tidy3d.exceptions import DataError from tidy3d.log import log @@ -50,6 +50,139 @@ def data_arrs(self) -> dict: return data_arrs +class FreqDataset(Dataset, ABC): + """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" + + def _interp_in_freq_update_dict( + self, + freqs: FreqArray, + method: Literal["linear", "cubic", "poly"] = "linear", + assume_sorted: bool = False, + ) -> dict[str, DataArray]: + """Interpolate mode data to new frequency points. + + Interpolates all stored mode data (effective indices, field components, group indices, + and dispersion) from the current frequency grid to a new set of frequencies. This is + useful for obtaining mode data at many frequencies from computations at fewer frequencies, + when modes vary smoothly with frequency. + + Parameters + ---------- + freqs : FreqArray + New frequency points to interpolate to. Should generally span a similar range + as the original frequencies to avoid extrapolation. + method : Literal["linear", "cubic", "poly"] + Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source + frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source + frequencies), ``"poly"`` for polynomial interpolation using barycentric + formula (requires 3+ source frequencies). + For complex-valued data, real and imaginary parts are interpolated independently. + assume_sorted: bool = False, + Whether to assume the frequency points are sorted. + + Returns + ------- + ModeSolverData + New :class:`ModeSolverData` object with data interpolated to the requested frequencies. + + Note + ---- + Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate + near mode crossings or regions of rapid mode variation. Use frequency tracking + (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + For polynomial interpolation, source frequencies at Chebyshev nodes provide + optimal accuracy within the frequency range. + + Example + ------- + >>> # Compute modes at 5 frequencies + >>> import numpy as np + >>> freqs_sparse = np.linspace(1e14, 2e14, 5) + >>> # ... create mode_solver and compute modes ... + >>> # mode_data = mode_solver.solve() + >>> # Interpolate to 50 frequencies + >>> freqs_dense = np.linspace(1e14, 2e14, 50) + >>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear') + """ + freqs = np.array(freqs) + + modify_data = {} + for key, data in self.data_arrs.items(): + modify_data[key] = self._interp_dataarray_in_freq(data, freqs, method, assume_sorted) + + return modify_data + + @staticmethod + def _interp_dataarray_in_freq( + data: DataArray, + freqs: FreqArray, + method: Literal["linear", "cubic", "poly", "nearest"], + assume_sorted: bool = False, + ) -> DataArray: + """Interpolate a DataArray along the frequency coordinate. + + Parameters + ---------- + data : DataArray + Data array to interpolate. Must have a frequency coordinate ``"f"``. + freqs : FreqArray + New frequency points. + method : Literal["linear", "cubic", "poly", "nearest"] + Interpolation method (``"linear"``, ``"cubic"``, ``"poly"``, or ``"nearest"``). + For ``"poly"``, uses barycentric formula for polynomial interpolation. + assume_sorted: bool = False, + Whether to assume the frequency points are sorted. + + Returns + ------- + DataArray + Interpolated data array with the same structure but new frequency points. + """ + # Map 'poly' to xarray's 'barycentric' method + xr_method = "barycentric" if method == "poly" else method + + # Use xarray's built-in interpolation + # For complex data, this automatically interpolates real and imaginary parts + interp_kwargs = {"method": xr_method} + + if method == "nearest": + return data.sel(f=freqs, method="nearest") + else: + if method != "poly": + interp_kwargs["kwargs"] = {"fill_value": "extrapolate"} + return data.interp(f=freqs, assume_sorted=assume_sorted, **interp_kwargs) + + +class ModeFreqDataset(FreqDataset, ABC): + """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" + + def _apply_mode_reorder(self, sort_inds_2d): + """Apply a mode reordering along mode_index for all frequency indices. + + Parameters + ---------- + sort_inds_2d : np.ndarray + Array of shape (num_freqs, num_modes) where each row is the + permutation to apply to the mode_index for that frequency. + """ + num_freqs, num_modes = sort_inds_2d.shape + modify_data = {} + for key, data in self.data_arrs.items(): + if "mode_index" not in data.dims or "f" not in data.dims: + continue + dims_orig = data.dims + f_coord = data.coords["f"] + slices = [] + for ifreq in range(num_freqs): + sl = data.isel(f=ifreq, mode_index=sort_inds_2d[ifreq]) + slices.append(sl.assign_coords(mode_index=np.arange(num_modes))) + # Concatenate along the 'f' dimension name and then restore original frequency coordinates + data = xr.concat(slices, dim="f").assign_coords(f=f_coord).transpose(*dims_orig) + modify_data[key] = data + return self.updated_copy(**modify_data) + + class AbstractFieldDataset(Dataset, ABC): """Collection of scalar fields with some symmetry properties.""" @@ -492,7 +625,7 @@ class AuxFieldTimeDataset(AuxFieldDataset): ) -class ModeSolverDataset(ElectromagneticFieldDataset): +class ModeSolverDataset(ElectromagneticFieldDataset, ModeFreqDataset): """Dataset storing scalar components of E and H fields as a function of freq. and mode_index. Example diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 081e8ce5c6..79cd62b655 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -18,7 +18,7 @@ from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType -from tidy3d.components.mode_spec import ModeSortSpec +from tidy3d.components.mode_spec import ModeSortSpec, ModeSpec from tidy3d.components.monitor import ( AuxFieldTimeMonitor, DiffractionMonitor, @@ -45,8 +45,10 @@ ArrayFloat1D, ArrayFloat2D, Coordinate, + Direction, EMField, EpsSpecType, + FreqArray, Numpy, PolarizationBasis, Size, @@ -78,6 +80,7 @@ MixedModeDataArray, ModeAmpsDataArray, ModeDispersionDataArray, + ModeIndexDataArray, ScalarFieldDataArray, ScalarFieldTimeDataArray, TimeDataArray, @@ -103,6 +106,7 @@ MIN_ANGULAR_SAMPLES_SPHERE = 10 # Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles COS_THETA_THRESH = 1e-5 +MODE_INTERP_EXTRAPOLATION_TOLERANCE = 1e-2 class MonitorData(AbstractMonitorData, ABC): @@ -1672,11 +1676,11 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData): ) @pd.validator("eps_spec", always=True) - @skip_if_fields_missing(["monitor"]) + @skip_if_fields_missing(["n_complex"]) def eps_spec_match_mode_spec(cls, val, values): """Raise validation error if frequencies in eps_spec does not match frequency list""" if val: - mode_data_freqs = values["monitor"].freqs + mode_data_freqs = values["n_complex"].coords["f"].values if len(val) != len(mode_data_freqs): raise ValidationError( "eps_spec must be provided at the same frequencies as mode solver data." @@ -1719,7 +1723,8 @@ def overlap_sort( """ if len(self.field_components) == 0: return self.copy() - num_freqs = len(self.monitor.freqs) + + num_freqs = len(self.monitor._stored_freqs) num_modes = self.monitor.mode_spec.num_modes if track_freq == "lowest": @@ -1759,7 +1764,7 @@ def overlap_sort( # Get next frequency to sort data_to_sort = data_expanded._isel(f=[freq_id]) # Assign to the base frequency so that outer_dot will compare them - data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]]) + data_to_sort = data_to_sort._assign_coords(f=[self.monitor._stored_freqs[f0_ind]]) # Compute "sorting w.r.t. to neighbor" and overlap values sorting_one_mode, amps_one_mode = data_template._find_ordering_one_freq( @@ -1777,8 +1782,8 @@ def overlap_sort( for mode_ind in list(np.nonzero(overlap[freq_id, :] < overlap_thresh)[0]): log.warning( f"Mode '{mode_ind}' appears to undergo a discontinuous change " - f"between frequencies '{data_expanded.monitor.freqs[freq_id]}' " - f"and '{data_expanded.monitor.freqs[freq_id - step]}' " + f"between frequencies '{self.monitor._stored_freqs[freq_id]}' " + f"and '{self.monitor._stored_freqs[freq_id - step]}' " f"(overlap: '{overlap[freq_id, mode_ind]:.2f}')." ) @@ -2455,10 +2460,242 @@ class ModeSolverData(ModeData): None, title="Amplitudes", description="Unused for ModeSolverData." ) + grid_distances_primal: Union[tuple[float], tuple[float, float]] = pd.Field( + (0.0,), + title="Distances to the Primal Grid", + description="Relative distances to the primal grid locations along the normal direction in " + "the original simulation grid. Needed to recalculate grid corrections after " + "interpolating in frequency.", + ) + + grid_distances_dual: Union[tuple[float], tuple[float, float]] = pd.Field( + (0.0,), + title="Distances to the Dual Grid", + description="Relative distances to the dual grid locations along the normal direction in " + "the original simulation grid. Needed to recalculate grid corrections after " + "interpolating in frequency.", + ) + def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData: """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() + def _normalize_modes(self): + """Normalize modes. Note: this modifies ``self`` in-place.""" + scaling = np.sqrt(np.abs(self.flux)) + for field in self.field_components.values(): + field /= scaling + + @staticmethod + def _grid_correction_factors( + primal_distances: tuple[float, ...], + dual_distances: tuple[float, ...], + mode_spec: ModeSpec, + n_complex: ModeIndexDataArray, + direction: Direction, + normal_dim: str, + ) -> tuple[FreqModeDataArray, FreqModeDataArray]: + """Calculate the grid correction factors for the primal and dual grid. + + Parameters + ---------- + primal_distances : tuple[float, ...] + Relative distances to the primal grid locations along the normal direction in the original simulation grid. + dual_distances : tuple[float, ...] + Relative distances to the dual grid locations along the normal direction in the original simulation grid. + mode_spec : ModeSpec + Mode specification. + n_complex : ModeIndexDataArray + Effective indices of the modes. + direction : Direction + Direction of the propagation. + normal_dim : str + Name of the normal dimension. + + Returns + ------- + tuple[FreqModeDataArray, FreqModeDataArray] + Grid correction factors for the primal and dual grid. + """ + + distances_primal = xr.DataArray(primal_distances, coords={normal_dim: primal_distances}) + distances_dual = xr.DataArray(dual_distances, coords={normal_dim: dual_distances}) + + # Propagation phase at the primal and dual locations. The k-vector is along the propagation + # direction, so angle_theta has to be taken into account. The distance along the propagation + # direction is the distance along the normal direction over cosine(theta). + cos_theta = np.cos(mode_spec.angle_theta) + k_vec = cos_theta * 2 * np.pi * n_complex * n_complex.f / C_0 + if direction == "-": + k_vec *= -1 + phase_primal = np.exp(1j * k_vec * distances_primal) + phase_dual = np.exp(1j * k_vec * distances_dual) + + # Fields are modified by a linear interpolation to the exact monitor position + if distances_primal.size > 1: + phase_primal = phase_primal.interp(**{normal_dim: 0}) + else: + phase_primal = phase_primal.squeeze(dim=normal_dim) + if distances_dual.size > 1: + phase_dual = phase_dual.interp(**{normal_dim: 0}) + else: + phase_dual = phase_dual.squeeze(dim=normal_dim) + + return FreqModeDataArray(phase_primal), FreqModeDataArray(phase_dual) + + def interp_in_freq( + self, + freqs: FreqArray, + method: Literal["linear", "cubic", "poly"] = "linear", + renormalize: bool = True, + recalculate_grid_correction: bool = True, + assume_sorted: bool = False, + ) -> ModeSolverData: + """Interpolate mode data to new frequency points. + + Interpolates all stored mode data (effective indices, field components, group indices, + and dispersion) from the current frequency grid to a new set of frequencies. This is + useful for obtaining mode data at many frequencies from computations at fewer frequencies, + when modes vary smoothly with frequency. + + Parameters + ---------- + freqs : FreqArray + New frequency points to interpolate to. Should generally span a similar range + as the original frequencies to avoid extrapolation. + method : Literal["linear", "cubic", "poly"] + Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source + frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source + frequencies), ``"poly"`` for polynomial interpolation using barycentric + formula (requires 3+ source frequencies). + For complex-valued data, real and imaginary parts are interpolated independently. + renormalize : bool = True + Whether to renormalize the mode profiles to unity power after interpolation. + recalculate_grid_correction : bool = True + Whether to recalculate the grid correction factors after interpolation or use interpolated + grid corrections. + assume_sorted: bool = False, + Whether to assume the frequency points are sorted. + + Returns + ------- + ModeSolverData + New :class:`ModeSolverData` object with data interpolated to the requested frequencies. + + Note + ---- + Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate + near mode crossings or regions of rapid mode variation. Use frequency tracking + (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + Example + ------- + >>> # Compute modes at 5 frequencies + >>> import numpy as np + >>> freqs_sparse = np.linspace(1e14, 2e14, 5) + >>> # ... create mode_solver and compute modes ... + >>> # mode_data = mode_solver.solve() + >>> # Interpolate to 50 frequencies + >>> freqs_dense = np.linspace(1e14, 2e14, 50) + >>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear') + """ + # Validate input + freqs = np.array(freqs) + + source_freqs = self.monitor._stored_freqs + + # Validate method-specific requirements + if method == "cubic" and len(source_freqs) < 4: + raise DataError( + f"Cubic interpolation requires at least 4 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + + if method == "poly": + if len(source_freqs) < 3: + raise DataError( + f"Polynomial interpolation requires at least 3 source frequency points. " + f"Got {len(source_freqs)}. Use method='linear' instead." + ) + + if method not in ["linear", "cubic", "poly"]: + raise DataError( + f"Invalid interpolation method '{method}'. Use 'linear', 'cubic', or 'poly'." + ) + + # Check if we're extrapolating significantly and warn + freq_min, freq_max = np.min(source_freqs), np.max(source_freqs) + new_freq_min, new_freq_max = np.min(freqs), np.max(freqs) + + if new_freq_min < freq_min * ( + 1 - MODE_INTERP_EXTRAPOLATION_TOLERANCE + ) or new_freq_max > freq_max * (1 + MODE_INTERP_EXTRAPOLATION_TOLERANCE): + log.warning( + f"Interpolating to frequencies outside original range " + f"[{freq_min:.3e}, {freq_max:.3e}] Hz. New range: " + f"[{new_freq_min:.3e}, {new_freq_max:.3e}] Hz. " + "Results may be inaccurate due to extrapolation." + ) + + # Build update dictionary + update_dict = self._interp_in_freq_update_dict(freqs, method, assume_sorted) + + # Handle eps_spec if present - use nearest neighbor interpolation + if self.eps_spec is not None: + update_dict["eps_spec"] = list( + self._interp_dataarray_in_freq( + FreqDataArray(self.eps_spec, coords={"f": source_freqs}), + freqs, + "nearest", + ).data + ) + + # Update monitor with new frequencies, remove interp_spece + update_dict["monitor"] = self.monitor.updated_copy( + freqs=list(freqs), + mode_spec=self.monitor.mode_spec.updated_copy(interp_spec=None), + ) + + if recalculate_grid_correction: + update_dict["grid_primal_correction"], update_dict["grid_dual_correction"] = ( + self._grid_correction_factors( + list(self.grid_distances_primal), + list(self.grid_distances_dual), + self.monitor.mode_spec, + update_dict["n_complex"], + self.monitor.direction, + "xyz"[self.monitor._normal_axis], + ) + ) + + updated_data = self.updated_copy(**update_dict) + if renormalize: + updated_data._normalize_modes() + + return updated_data + + @property + def _reduced_data(self) -> bool: + """Whether data will be stored at fewer frequencies than the original number of frequencies.""" + return ( + self.monitor.mode_spec._is_interp_spec_applied(self.monitor.freqs) + and self.monitor.mode_spec.interp_spec.reduce_data + ) + + @property + def interpolated_copy(self) -> ModeSolverData: + """Return a copy of the data with interpolated fields.""" + if not self._reduced_data: + return self + interpolated_data = self.interp_in_freq( + freqs=self.monitor.freqs, + method=self.monitor.mode_spec.interp_spec.method, + renormalize=True, + recalculate_grid_correction=True, + assume_sorted=True, + ) + return interpolated_data + @property def time_reversed_copy(self) -> FieldData: """Make a copy of the data with direction-reversed fields. In lossy or gyrotropic systems, diff --git a/tidy3d/components/microwave/data/dataset.py b/tidy3d/components/microwave/data/dataset.py index 8810df1cc5..669e32b004 100644 --- a/tidy3d/components/microwave/data/dataset.py +++ b/tidy3d/components/microwave/data/dataset.py @@ -9,10 +9,10 @@ ImpedanceFreqModeDataArray, VoltageFreqModeDataArray, ) -from tidy3d.components.data.dataset import Dataset +from tidy3d.components.data.dataset import ModeFreqDataset -class TransmissionLineDataset(Dataset): +class TransmissionLineDataset(ModeFreqDataset): """Holds mode data that is specific to transmission lines in microwave and RF applications, like characteristic impedance. diff --git a/tidy3d/components/microwave/data/monitor_data.py b/tidy3d/components/microwave/data/monitor_data.py index 64051ca5fb..5803571a30 100644 --- a/tidy3d/components/microwave/data/monitor_data.py +++ b/tidy3d/components/microwave/data/monitor_data.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Optional +from typing import Literal, Optional import pydantic.v1 as pd import xarray as xr @@ -14,7 +14,7 @@ from tidy3d.components.microwave.base import MicrowaveBaseModel from tidy3d.components.microwave.data.dataset import TransmissionLineDataset from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor -from tidy3d.components.types import PolarizationBasis +from tidy3d.components.types import FreqArray, PolarizationBasis class AntennaMetricsData(DirectivityData, MicrowaveBaseModel): @@ -305,6 +305,25 @@ def _group_index_post_process(self, frequency_step: float) -> ModeData: super_data = super_data.updated_copy(**update_dict, path="transmission_line_data") return super_data + def _apply_mode_reorder(self, sort_inds_2d): + """Apply a mode reordering along mode_index for all frequency indices. + + Parameters + ---------- + sort_inds_2d : np.ndarray + Array of shape (num_freqs, num_modes) where each row is the + permutation to apply to the mode_index for that frequency. + """ + main_data_reordered = super()._apply_mode_reorder(sort_inds_2d) + if self.transmission_line_data is not None: + transmission_line_data_reordered = self.transmission_line_data._apply_mode_reorder( + sort_inds_2d + ) + main_data_reordered = main_data_reordered.updated_copy( + transmission_line_data=transmission_line_data_reordered + ) + return main_data_reordered + class MicrowaveModeSolverData(ModeSolverData, MicrowaveModeData): """ @@ -378,3 +397,81 @@ class MicrowaveModeSolverData(ModeSolverData, MicrowaveModeData): monitor: MicrowaveModeSolverMonitor = pd.Field( ..., title="Monitor", description="Mode monitor associated with the data." ) + + def interp_in_freq( + self, + freqs: FreqArray, + method: Literal["linear", "cubic", "poly"] = "linear", + renormalize: bool = True, + recalculate_grid_correction: bool = True, + assume_sorted: bool = False, + ) -> MicrowaveModeData: + """Interpolate mode data to new frequency points. + + Interpolates all stored mode data (effective indices, field components, group indices, + and dispersion) from the current frequency grid to a new set of frequencies. This is + useful for obtaining mode data at many frequencies from computations at fewer frequencies, + when modes vary smoothly with frequency. + + Parameters + ---------- + freqs : FreqArray + New frequency points to interpolate to. Should generally span a similar range + as the original frequencies to avoid extrapolation. + method : Literal["linear", "cubic", "cheb"] + Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source + frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source + frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric + formula (requires 3+ source frequencies at Chebyshev nodes). + For complex-valued data, real and imaginary parts are interpolated independently. + renormalize : Optional[bool] = True + Whether to renormalize the mode profiles to unity power after interpolation. + recalculate_grid_correction : bool = True + Whether to recalculate the grid correction factors after interpolation or use interpolated + grid corrections. + assume_sorted: bool = False, + Whether to assume the frequency points are sorted. + + Returns + ------- + ModeSolverData + New :class:`ModeSolverData` object with data interpolated to the requested frequencies. + + Raises + ------ + DataError + If interpolation parameters are invalid (e.g., too few source frequencies for the + chosen method, or source frequencies not at Chebyshev nodes for 'cheb' method). + + Note + ---- + Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate + near mode crossings or regions of rapid mode variation. Use frequency tracking + (``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency. + + For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the + second kind within the frequency range. + + Example + ------- + >>> # Compute modes at 5 frequencies + >>> import numpy as np + >>> freqs_sparse = np.linspace(1e14, 2e14, 5) + >>> # ... create mode_solver and compute modes ... + >>> # mode_data = mode_solver.solve() + >>> # Interpolate to 50 frequencies + >>> freqs_dense = np.linspace(1e14, 2e14, 50) + >>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear') + """ + main_data_interp = super().interp_in_freq( + freqs, method, renormalize, recalculate_grid_correction, assume_sorted + ) + if self.transmission_line_data is not None: + update_dict = self.transmission_line_data._interp_in_freq_update_dict( + freqs, method, assume_sorted + ) + transmission_line_data_interp = self.transmission_line_data.updated_copy(**update_dict) + main_data_interp = main_data_interp.updated_copy( + transmission_line_data=transmission_line_data_interp + ) + return main_data_interp diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 17a861dcbb..74218e96b8 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -12,7 +12,11 @@ import pydantic.v1 as pydantic import xarray as xr -from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from tidy3d.components.base import ( + Tidy3dBaseModel, + cached_property, + skip_if_fields_missing, +) from tidy3d.components.boundary import PML, Absorber, Boundary, BoundarySpec, PECBoundary, StablePML from tidy3d.components.data.data_array import ( FreqModeDataArray, @@ -71,6 +75,7 @@ from tidy3d.components.types.mode_spec import ModeSpecType from tidy3d.components.types.monitor_data import ModeSolverDataType from tidy3d.components.validators import ( + _warn_interp_num_points, validate_freqs_min, validate_freqs_not_empty, ) @@ -447,7 +452,7 @@ def _num_cells_freqs_modes(self) -> tuple[int, int, int]: """Get the number of spatial points, number of freqs, and number of modes requested.""" num_cells = np.prod(self._solver_grid.num_cells) num_modes = self.mode_spec.num_modes - num_freqs = len(self.freqs) + num_freqs = len(self._sampling_freqs) return num_cells, num_freqs, num_modes @property @@ -471,39 +476,14 @@ def solve(self) -> ModeSolverData: ) return self.data - def _freqs_for_group_index(self) -> FreqArray: + def _freqs_for_group_index(self, freqs: FreqArray) -> FreqArray: """Get frequencies used to compute group index.""" - f_step = self.mode_spec.group_index_step - fractional_steps = (1 - f_step, 1, 1 + f_step) - return np.outer(self.freqs, fractional_steps).flatten() - - def _remove_freqs_for_group_index(self) -> FreqArray: - """Remove frequencies used to compute group index. - - Returns - ------- - FreqArray - Filtered frequency array with only original values. - """ - return np.array(self.freqs[1 : len(self.freqs) : 3]) - - def _get_data_with_group_index(self) -> ModeSolverData: - """:class:`.ModeSolverData` with fields, effective and group indices on unexpanded grid. + return self.mode_spec._freqs_for_group_index(freqs=self.freqs) - Returns - ------- - ModeSolverData - :class:`.ModeSolverData` object containing the effective and group indices, and mode - fields. - """ - - # create a copy with the required frequencies for numerical differentiation - mode_spec = self.mode_spec.copy(update={"group_index_step": False}) - mode_solver = self.copy( - update={"freqs": self._freqs_for_group_index(), "mode_spec": mode_spec} - ) - - return mode_solver.data_raw._group_index_post_process(self.mode_spec.group_index_step) + @cached_property + def _sampling_freqs(self) -> FreqArray: + """Get frequencies used to compute group index and interpolation.""" + return self.mode_spec._sampling_freqs_mode_solver(freqs=self.freqs) @cached_property def grid_snapped(self) -> Grid: @@ -535,8 +515,8 @@ def data_raw(self) -> ModeSolverDataType: A mode solver data type object containing the effective index and mode fields. """ - if self.mode_spec.group_index_step > 0: - return self._get_data_with_group_index() + if self.mode_spec.interp_spec is not None: + _warn_interp_num_points(self.mode_spec.interp_spec, self.freqs) if self.mode_spec.angle_rotation and np.abs(self.mode_spec.angle_theta) > 0: return self.rotated_mode_solver_data @@ -563,8 +543,26 @@ def data_raw(self) -> ModeSolverDataType: ) self._field_decay_warning(mode_solver_data.symmetry_expanded) - mode_solver_data = self._filter_components(mode_solver_data) + + if self.mode_spec.group_index_step > 0: + mode_solver_data = mode_solver_data._group_index_post_process( + self.mode_spec.group_index_step + ) + + if self.mode_spec._is_interp_spec_applied(self.freqs): + # set interp_spec back + interp_spec = self.mode_spec.interp_spec.updated_copy(reduce_data=True) + mode_solver_data = mode_solver_data.updated_copy( + monitor=mode_solver_data.monitor.updated_copy( + freqs=self.freqs, + mode_spec=self.mode_spec.updated_copy(interp_spec=interp_spec), + ) + ) + + if not self.mode_spec.interp_spec.reduce_data: + mode_solver_data = mode_solver_data.interpolated_copy + # Calculate and add the characteristic impedance if self._has_microwave_mode_spec: mode_solver_data = self._add_microwave_data(mode_solver_data) @@ -621,10 +619,10 @@ def rotated_mode_solver_data(self) -> ModeSolverData: # to compute the backward propagation mode solution using a mode solver # with direction "-". eps_spec = [] - for _ in self.freqs: + for _ in solver.freqs: eps_spec.append("tensorial_complex") # finite grid corrections - grid_factors = solver._grid_correction( + grid_factors, relative_grid_distances = solver._grid_correction( simulation=solver.simulation, plane=solver.plane, mode_spec=solver.mode_spec, @@ -643,6 +641,8 @@ def rotated_mode_solver_data(self) -> ModeSolverData: grid_primal_correction=grid_factors[0], grid_dual_correction=grid_factors[1], eps_spec=eps_spec, + grid_distances_primal=relative_grid_distances[0], + grid_distances_dual=relative_grid_distances[1], **rotated_mode_fields, ) @@ -986,7 +986,7 @@ def _mode_rotation( xyz_coords = solver.grid_snapped[field_name].to_list x, y, z = (coord.copy() for coord in xyz_coords) - f = np.atleast_1d(self.freqs) + f = np.atleast_1d(self._sampling_freqs) mode_index = np.arange(self.mode_spec.num_modes) # Initialize output arrays @@ -1176,6 +1176,12 @@ def _data_on_yee_grid(self) -> ModeSolverData: """Solve for all modes, and construct data with fields on the Yee grid.""" solver = self._reduced_simulation_copy_with_fallback + # set freqs to the sampling frequencies + # temporary remove interp_spec + solver = solver.updated_copy( + freqs=self._sampling_freqs, mode_spec=self.mode_spec.updated_copy(interp_spec=None) + ) + _, _solver_coords = solver.plane.pop_axis( solver._solver_grid.boundaries.to_list, axis=solver.normal_axis ) @@ -1211,7 +1217,7 @@ def _data_on_yee_grid(self) -> ModeSolverData: data_dict[field_name] = scalar_field_data # finite grid corrections - grid_factors = solver._grid_correction( + grid_factors, relative_grid_distances = solver._grid_correction( simulation=solver.simulation, plane=solver.plane, mode_spec=solver.mode_spec, @@ -1229,6 +1235,8 @@ def _data_on_yee_grid(self) -> ModeSolverData: grid_expanded=grid_expanded, grid_primal_correction=grid_factors[0], grid_dual_correction=grid_factors[1], + grid_distances_primal=relative_grid_distances[0], + grid_distances_dual=relative_grid_distances[1], eps_spec=eps_spec, **data_dict, ) @@ -1283,7 +1291,7 @@ def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData: data_dict[field_name] = scalar_field_data # finite grid corrections - grid_factors = self._grid_correction( + grid_factors, relative_grid_distances = self._grid_correction( simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec, @@ -1301,6 +1309,8 @@ def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData: grid_expanded=grid_expanded, grid_primal_correction=grid_factors[0], grid_dual_correction=grid_factors[1], + grid_distances_primal=relative_grid_distances[0], + grid_distances_dual=relative_grid_distances[1], eps_spec=eps_spec, **data_dict, ) @@ -1338,7 +1348,11 @@ def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData: data_dict_colocated[key] = field.interp(**colocate_coords).astype(field.dtype) # Update data - mode_solver_monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME) + mode_solver_monitor = self.to_mode_solver_monitor( + name=MODE_MONITOR_NAME, + mode_spec=mode_solver_data.monitor.mode_spec, + freqs=mode_solver_data.monitor.freqs, + ) grid_expanded = self.simulation.discretize_monitor(mode_solver_monitor) data_dict_colocated.update({"monitor": mode_solver_monitor, "grid_expanded": grid_expanded}) mode_solver_data = mode_solver_data.updated_copy(**data_dict_colocated, deep=False) @@ -1346,9 +1360,7 @@ def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData: def _normalize_modes(self, mode_solver_data: ModeSolverData) -> None: """Normalize modes. Note: this modifies ``mode_solver_data`` in-place.""" - scaling = np.sqrt(np.abs(mode_solver_data.flux)) - for field in mode_solver_data.field_components.values(): - field /= scaling + mode_solver_data._normalize_modes() def _filter_components(self, mode_solver_data: ModeSolverData): skip_components = { @@ -1364,7 +1376,7 @@ def _filter_polarization(self, mode_solver_data: ModeSolverData): if filter_pol is None: return mode_solver_data - num_freqs = len(self.freqs) + num_freqs = len(self._sampling_freqs) num_modes = self.mode_spec.num_modes identity = np.arange(num_modes) sort_inds_2d = np.tile(identity, (num_freqs, 1)) @@ -1848,24 +1860,38 @@ def _grid_correction( mode_spec: ModeSpec, n_complex: ModeIndexDataArray, direction: Direction, - ) -> [FreqModeDataArray, FreqModeDataArray]: - """Correct the fields due to propagation on the grid. + ) -> tuple[ + tuple[FreqModeDataArray, FreqModeDataArray], tuple[tuple[float, ...], tuple[float, ...]] + ]: + """ + Compute grid correction factors for the mode fields. - Return a copy of the :class:`.ModeSolverData` with the fields renormalized to account - for propagation on a finite grid along the propagation direction. The fields are assumed to - have ``E exp(1j k r)`` dependence on the finite grid and are then resampled using linear - interpolation to the exact position of the mode plane. This is needed to correctly compute - overlap with fields that come from a :class:`.FieldMonitor` placed in the same grid. + This method calculates the phase correction factors necessary to account for propagation + on a finite numerical grid along the propagation direction (normal to the mode plane). + The correction is based on the assumed ``E * exp(1j k r)`` field dependence, where the + fields are resampled using linear interpolation to precisely match the mode plane position. + This is needed to correctly compute overlap with fields that come from + a :class:`.FieldMonitor` placed in the same grid. Parameters ---------- - grid : :class:`.Grid` - Numerical grid on which the modes are assumed to propagate. + simulation : MODE_SIMULATION_TYPE + Simulation object, which provides the grid structure. + plane : Box + The mode plane (its normal and center define the propagation direction and position). + mode_spec : ModeSpec + Mode specification with relevant propagation angle and properties. + n_complex : ModeIndexDataArray + Complex effective index array for the modes. + direction : Direction + Direction of propagation; "+" for forward or "-" for backward. Returns ------- - :class:`.ModeSolverData` - Copy of the data with renormalized fields. + tuple of FreqModeDataArray + A tuple of two FreqModeDataArray objects: + (phase_primal, phase_dual), containing the correction phase factors for the primal + (tangential E field) and dual (tangential H field) grid locations, respectively. """ normal_axis = plane.size.index(0.0) normal_pos = float(plane.center[normal_axis]) @@ -1879,27 +1905,43 @@ def _grid_correction( normal_dual = grid.centers.to_list[normal_axis] normal_dual = xr.DataArray(normal_dual, coords={normal_dim: normal_dual}) - # Propagation phase at the primal and dual locations. The k-vector is along the propagation - # direction, so angle_theta has to be taken into account. The distance along the propagation - # direction is the distance along the normal direction over cosine(theta). - cos_theta = np.cos(mode_spec.angle_theta) - k_vec = cos_theta * 2 * np.pi * n_complex * n_complex.f / C_0 - if direction == "-": - k_vec *= -1 - phase_primal = np.exp(1j * k_vec * (normal_primal - normal_pos)) - phase_dual = np.exp(1j * k_vec * (normal_dual - normal_pos)) - - # Fields are modified by a linear interpolation to the exact monitor position - if normal_primal.size > 1: - phase_primal = phase_primal.interp(**{normal_dim: normal_pos}) - else: - phase_primal = phase_primal.squeeze(dim=normal_dim) - if normal_dual.size > 1: - phase_dual = phase_dual.interp(**{normal_dim: normal_pos}) - else: - phase_dual = phase_dual.squeeze(dim=normal_dim) + def find_closest_distances_to_grid_points( + normal_pos: float, grid_coords: ArrayFloat1D + ) -> tuple[float, float]: + """Find the closest points to the normal position in the grid coordinates.""" + + if grid_coords.size == 1: + return [float(grid_coords.data[0] - normal_pos)] + + distances = grid_coords.data - normal_pos + # First, find the signed distance to the closest grid point + closest_distance_ind = np.argmin(np.abs(distances)) + closest_distance = distances[closest_distance_ind] + + # Then, if the closest distance is positive, take the previous point, otherwise take the next point + if closest_distance > 0: + first_dist = distances[closest_distance_ind - 1] + second_dist = distances[closest_distance_ind] + else: + first_dist = distances[closest_distance_ind] + second_dist = distances[closest_distance_ind + 1] + + # Return the two closest points + return [first_dist, second_dist] + + primal_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_primal) + dual_closest_distances = find_closest_distances_to_grid_points(normal_pos, normal_dual) + + grid_correction_factors = ModeSolverData._grid_correction_factors( + primal_closest_distances, + dual_closest_distances, + mode_spec, + n_complex, + direction, + normal_dim, + ) - return FreqModeDataArray(phase_primal), FreqModeDataArray(phase_dual) + return grid_correction_factors, (primal_closest_distances, dual_closest_distances) @property def _is_tensorial(self) -> bool: @@ -1934,7 +1976,13 @@ def _has_complex_eps(self) -> bool: A separate check is done inside the solver, which looks at the actual eps and mu and uses a tolerance to determine whether to use real or complex fields, so the actual behavior may differ from what's predicted by this property.""" - check_freqs = np.unique([np.amin(self.freqs), np.amax(self.freqs), np.mean(self.freqs)]) + check_freqs = np.unique( + [ + np.amin(self._sampling_freqs), + np.amax(self._sampling_freqs), + np.mean(self._sampling_freqs), + ] + ) for int_mat in self._intersecting_media: for freq in check_freqs: max_imag_eps = np.amax(np.abs(np.imag(int_mat.eps_model(freq)))) @@ -2055,7 +2103,11 @@ def to_monitor( ) def to_mode_solver_monitor( - self, name: str, colocate: Optional[bool] = None + self, + name: str, + colocate: Optional[bool] = None, + mode_spec: Optional[ModeSpec] = None, + freqs: Optional[list[float]] = None, ) -> ModeSolverMonitor: """Creates :class:`ModeSolverMonitor` from a :class:`.ModeSolver` instance. @@ -2066,12 +2118,23 @@ def to_mode_solver_monitor( colocate : bool Whether to colocate fields or compute on the Yee grid. If not provided, the value set in the :class:`.ModeSolver` instance is used. + mode_spec : ModeSpec + Mode specification to use for the monitor. + If not specified, uses the mode specification from the mode solver. + freqs : list[float] + Frequencies to include in Monitor (Hz). + If not specified, uses the frequencies from the mode solver. Returns ------- :class:`.ModeSolverMonitor` Mode monitor with specifications taken from the ModeSolver instance and ``name``. """ + if mode_spec is None: + mode_spec = self.mode_spec + + if freqs is None: + freqs = self.freqs if colocate is None: colocate = self.colocate @@ -2083,8 +2146,8 @@ def to_mode_solver_monitor( return mode_solver_monitor_type( size=self.plane.size, center=self.plane.center, - mode_spec=self.mode_spec, - freqs=self.freqs, + mode_spec=mode_spec, + freqs=freqs, direction=self.direction, colocate=colocate, conjugated_dot_product=self.conjugated_dot_product, diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index 2e836c55ec..e964ce0d87 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -26,6 +26,7 @@ from tidy3d.components.source.field import ModeSource from tidy3d.components.types import TYPE_TAG_STR, Ax, Direction, EMField, FreqArray from tidy3d.components.types.mode_spec import ModeSpecType +from tidy3d.components.validators import validate_interp_num_points from tidy3d.constants import C_0 from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log @@ -234,6 +235,8 @@ def plane_in_sim_bounds(cls, val, values): raise SetupError("'ModeSimulation.plane' must intersect 'ModeSimulation.geometry.") return val + _warn_interp_num_points = validate_interp_num_points() + def _post_init_validators(self) -> None: """Call validators taking `self` that get run after init.""" _ = self._mode_solver diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index c0d37a0bb1..0ae0256c80 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -2,7 +2,7 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from math import isclose from typing import Literal, Optional, Union @@ -14,7 +14,7 @@ from tidy3d.log import log from .base import Tidy3dBaseModel, skip_if_fields_missing -from .types import Axis2D, TrackFreq +from .types import Axis2D, FreqArray, TrackFreq GROUP_INDEX_STEP = 0.005 MODE_DATA_KEYS = Literal[ @@ -86,6 +86,358 @@ class ModeSortSpec(Tidy3dBaseModel): ) +class FrequencySamplingSpec(Tidy3dBaseModel, ABC): + """Abstract base class for frequency sampling specifications.""" + + @abstractmethod + def sampling_points(self, freqs: FreqArray) -> FreqArray: + """Compute frequency sampling points. + + Parameters + ---------- + freqs : FreqArray + Target frequency array. + + Returns + ------- + FreqArray + Array of sampling frequency points. + """ + + @property + @abstractmethod + def _num_points(self) -> int: + """Number of sampling points (internal property).""" + + +class UniformSampling(FrequencySamplingSpec): + """Uniform frequency sampling specification.""" + + num_points: int = pd.Field( + ..., + title="Number of Points", + description="Number of uniformly spaced frequency sampling points.", + ge=2, + ) + + @property + def _num_points(self) -> int: + """Number of sampling points (internal property).""" + return self.num_points + + def sampling_points(self, freqs: FreqArray) -> FreqArray: + """Compute uniformly spaced frequency sampling points. + + Parameters + ---------- + freqs : FreqArray + Target frequency array. Sampling points will span from min(freqs) to max(freqs). + + Returns + ------- + FreqArray + Array of uniformly spaced frequency points. + """ + freqs_array = np.asarray(freqs) + f_min, f_max = float(freqs_array.min()), float(freqs_array.max()) + return np.linspace(f_min, f_max, self.num_points) + + +class ChebSampling(FrequencySamplingSpec): + """Chebyshev node frequency sampling specification.""" + + num_points: int = pd.Field( + ..., + title="Number of Points", + description="Number of Chebyshev nodes for frequency sampling.", + ge=3, + ) + + @property + def _num_points(self) -> int: + """Number of sampling points (internal property).""" + return self.num_points + + def sampling_points(self, freqs: FreqArray) -> FreqArray: + """Compute Chebyshev node frequency sampling points. + + Parameters + ---------- + freqs : FreqArray + Target frequency array. Sampling points will span from min(freqs) to max(freqs). + + Returns + ------- + FreqArray + Array of Chebyshev node frequency points (second kind) in ascending order. + """ + freqs_array = np.asarray(freqs) + f_min, f_max = float(freqs_array.min()), float(freqs_array.max()) + + # Chebyshev nodes of the second kind: x_k = cos(k*pi/(n-1)) for k=0,...,n-1 + # This generates nodes from +1 (f_max) to -1 (f_min), descending order + k = np.arange(self.num_points) + nodes_normalized = np.cos(k * np.pi / (self.num_points - 1)) + # Map from [-1, 1] to [f_min, f_max] + cheb_freqs = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * nodes_normalized + # Sort to return ascending order + return cheb_freqs[::-1] + + +class CustomSampling(FrequencySamplingSpec): + """Custom frequency sampling specification.""" + + freqs: FreqArray = pd.Field( + ..., + title="Frequencies", + description="Custom array of frequency sampling points.", + ) + + @pd.validator("freqs", always=True) + def _validate_freqs(cls, val): + """Validate custom frequencies.""" + freqs_array = np.asarray(val) + if freqs_array.size < 2: + raise ValidationError("Custom sampling requires at least 2 frequency points.") + return val + + def sampling_points(self, freqs: FreqArray) -> FreqArray: + """Return the custom frequency sampling points. + + Parameters + ---------- + freqs : FreqArray + Target frequency array (not used, custom frequencies are returned as-is). + + Returns + ------- + FreqArray + Array of custom frequency points. + """ + return np.asarray(self.freqs) + + @property + def _num_points(self) -> int: + """Number of custom sampling points (internal property).""" + return len(np.asarray(self.freqs)) + + +class ModeInterpSpec(Tidy3dBaseModel): + """Specification for mode frequency interpolation. + + Allows computing modes at a reduced set of frequencies and interpolating + to obtain results at all requested frequencies. This can significantly + reduce computational cost for broadband simulations where modes vary + smoothly with frequency. + + Note + ---- + Requires frequency tracking to be enabled (``mode_spec.sort_spec.track_freq`` + must not be ``None``) to ensure mode ordering is consistent across frequencies. + + Example + ------- + >>> # Uniform sampling with linear interpolation + >>> interp_spec = ModeInterpSpec( + ... method='linear', + ... sampling_spec=UniformSampling(num_points=10) + ... ) + >>> # Chebyshev sampling with polynomial interpolation + >>> interp_spec = ModeInterpSpec.cheb(num_points=10) + >>> # Custom sampling with cubic interpolation + >>> custom_freqs = [1e14, 1.5e14, 2e14, 2.5e14] + >>> interp_spec = ModeInterpSpec.custom(method='cubic', freqs=custom_freqs) + + See Also + -------- + + :class:`ModeSolver`: + Mode solver that can use this specification for efficient broadband computation. + + :class:`ModeSolverMonitor`: + Monitor that can use this specification to reduce mode computation cost. + + :class:`ModeMonitor`: + Monitor that can use this specification to reduce mode computation cost. + """ + + sampling_spec: Union[UniformSampling, ChebSampling, CustomSampling] = pd.Field( + ..., + title="Sampling Specification", + description="Specification for frequency sampling points.", + discriminator="type", + ) + + method: Literal["linear", "cubic", "poly"] = pd.Field( + "linear", + title="Interpolation Method", + description="Method for interpolating mode data between computed frequencies. " + "'linear' uses linear interpolation (faster, requires 2+ points). " + "'cubic' uses cubic spline interpolation (smoother, more accurate, requires 4+ points). " + "'poly' uses polynomial interpolation with barycentric formula " + "(optimal for Chebyshev nodes, requires 3+ points). " + "For complex-valued data, real and imaginary parts are interpolated independently.", + ) + + reduce_data: bool = pd.Field( + False, + title="Reduce Data", + description="Applies only to :class:`ModeSolverData`. If ``True``, fields and quantities " + "are only recorded at interpolation source frequency points. " + "The data at requested frequencies can be obtained through interpolation. " + "This can significantly reduce storage and computational costs for broadband simulations. " + "Does not apply if the number of sampling points is greater than the number of monitor frequencies.", + ) + + @pd.validator("method", always=True) + @skip_if_fields_missing(["sampling_spec"]) + def _validate_method_needs_points(cls, val, values): + """Validate that the method has enough points.""" + sampling_spec = values.get("sampling_spec") + if sampling_spec is None: + return val + + num_points = sampling_spec._num_points + if val == "cubic" and num_points < 4: + raise ValidationError( + "Cubic interpolation requires at least 4 frequency points. " + f"Got {num_points} points. " + "Use method='linear' or increase num_points." + ) + if val == "poly" and num_points < 3: + raise ValidationError( + "Polynomial interpolation requires at least 3 frequency points. " + f"Got {num_points} points. " + "Use method='linear' or increase num_points." + ) + return val + + @classmethod + def uniform( + cls, + num_points: int, + method: Literal["linear", "cubic", "poly"] = "linear", + reduce_data: bool = False, + ) -> ModeInterpSpec: + """Create a ModeInterpSpec with uniform frequency sampling. + + Parameters + ---------- + num_points : int + Number of uniformly spaced sampling points. + method : Literal["linear", "cubic", "poly"] + Interpolation method. Default is 'linear'. + reduce_data : bool + Whether to reduce data storage. Default is False. + + Returns + ------- + ModeInterpSpec + Interpolation specification with uniform sampling. + + Example + ------- + >>> interp_spec = ModeInterpSpec.uniform(num_points=10, method='cubic') + """ + return cls( + method=method, + sampling_spec=UniformSampling(num_points=num_points), + reduce_data=reduce_data, + ) + + @classmethod + def cheb(cls, num_points: int, reduce_data: bool = False) -> ModeInterpSpec: + """Create a ModeInterpSpec with Chebyshev node sampling and polynomial interpolation. + + Chebyshev nodes provide optimal sampling for polynomial interpolation, + minimizing interpolation error for smooth functions. + + Parameters + ---------- + num_points : int + Number of Chebyshev nodes (minimum 3). + reduce_data : bool + Whether to reduce data storage. Default is False. + + Returns + ------- + ModeInterpSpec + Interpolation specification with Chebyshev sampling and polynomial interpolation. + + Example + ------- + >>> interp_spec = ModeInterpSpec.cheb(num_points=10) + """ + return cls( + method="poly", + sampling_spec=ChebSampling(num_points=num_points), + reduce_data=reduce_data, + ) + + @classmethod + def custom( + cls, + freqs: FreqArray, + method: Literal["linear", "cubic", "poly"] = "linear", + reduce_data: bool = False, + ) -> ModeInterpSpec: + """Create a ModeInterpSpec with custom frequency sampling. + + Parameters + ---------- + freqs : FreqArray + Custom array of frequency sampling points. + method : Literal["linear", "cubic", "poly"] + Interpolation method. Default is 'linear'. + reduce_data : bool + Whether to reduce data storage. Default is False. + + Returns + ------- + ModeInterpSpec + Interpolation specification with custom sampling. + + Example + ------- + >>> custom_freqs = [1e14, 1.5e14, 1.8e14, 2e14] + >>> interp_spec = ModeInterpSpec.custom(freqs=custom_freqs, method='cubic') + """ + return cls( + method=method, + sampling_spec=CustomSampling(freqs=freqs), + reduce_data=reduce_data, + ) + + @property + def num_points(self) -> int: + """Number of sampling points.""" + return self.sampling_spec._num_points + + def sampling_points(self, freqs: FreqArray) -> FreqArray: + """Compute frequency sampling points. + + Parameters + ---------- + freqs : FreqArray + Target frequency array. + + Returns + ------- + FreqArray + Array of frequency sampling points. + + Example + ------- + >>> import numpy as np + >>> freqs = np.linspace(1e14, 2e14, 100) + >>> interp_spec = ModeInterpSpec.cheb(num_points=10) + >>> sampling_freqs = interp_spec.sampling_points(freqs) + """ + if self.num_points > len(freqs): + return freqs + return self.sampling_spec.sampling_points(freqs) + + class AbstractModeSpec(Tidy3dBaseModel, ABC): """ Abstract base for mode specification data. @@ -199,6 +551,16 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "frequencies it can change depending on the mode tracking.", ) + interp_spec: Optional[ModeInterpSpec] = pd.Field( + None, + title="Mode frequency interpolation specification", + description="Specification for computing modes at a reduced set of frequencies and " + "interpolating to obtain results at all requested frequencies. This can significantly " + "reduce computational cost for broadband simulations where modes vary smoothly with " + "frequency. Requires frequency tracking to be enabled (``sort_spec.track_freq`` must " + "not be ``None``) to ensure consistent mode ordering across frequencies.", + ) + @pd.validator("bend_axis", always=True) @skip_if_fields_missing(["bend_radius"]) def bend_axis_given(cls, val, values): @@ -310,15 +672,71 @@ def _track_freq_deprecated(cls, val): ) return val + @classmethod + def _track_freq_from_specs( + cls, track_freq: Optional[TrackFreq], sort_spec: Optional[ModeSortSpec] + ) -> Optional[TrackFreq]: + """Resolver for tracking frequency: prefers track_freq if set, + otherwise falls back to sort_spec.track_freq.""" + if track_freq is not None: + return track_freq + if sort_spec is not None: + return sort_spec.track_freq + return None + + @pd.validator("interp_spec", always=True) + @skip_if_fields_missing(["sort_spec", "track_freq"]) + def _interp_spec_needs_tracking(cls, val, values): + """Ensure frequency tracking is enabled when using interpolation.""" + if val is None: + return val + + # Check if track_freq is enabled (prefer ModeSpec.track_freq, else sort_spec.track_freq) + track_freq = values.get("track_freq") + sort_spec = values.get("sort_spec") + if cls._track_freq_from_specs(track_freq, sort_spec) is None: + raise ValidationError( + "Mode frequency interpolation requires frequency tracking to be enabled. " + "Please set 'sort_spec.track_freq' to 'central', 'lowest', or 'highest'." + ) + + return val + @property def _track_freq(self) -> Optional[TrackFreq]: """Private resolver for tracking frequency: prefers ModeSpec.track_freq if set, otherwise falls back to ModeSortSpec.track_freq.""" - if self.track_freq is not None: - return self.track_freq - if self.sort_spec is not None: - return self.sort_spec.track_freq - return None + return self._track_freq_from_specs(self.track_freq, self.sort_spec) + + def _freqs_for_group_index(self, freqs: list[float]) -> list[float]: + """Get frequencies used to compute group index.""" + fractional_steps = (1 - self.group_index_step, 1, 1 + self.group_index_step) + return np.outer(freqs, fractional_steps).flatten() + + def _sampling_freqs_mode_solver_data(self, freqs: list[float]) -> list[float]: + """Frequencies that will be stored in ModeSolverData after group index calculation and, possibly, interpolation is applied.""" + if self.interp_spec is not None and self.interp_spec.reduce_data: + # note that if len(freqs) < interp_spec.num_points, the result will be freqs itself + freqs = self.interp_spec.sampling_points(freqs) + return freqs + + def _sampling_freqs_mode_solver( + self, + freqs: list[float], + ) -> list[float]: + """Frequencies that mode solver needs to compute modes at.""" + if self.interp_spec is not None: + # note that if len(freqs) < interp_spec.num_points, the result will be freqs itself + freqs = self.interp_spec.sampling_points(freqs) + + if self.group_index_step > 0: + freqs = self._freqs_for_group_index(freqs=freqs) + + return freqs + + def _is_interp_spec_applied(self, freqs: FreqArray) -> bool: + """Whether interp_spec is used to compute modes at the given frequencies.""" + return self.interp_spec is not None and self.interp_spec.num_points < len(freqs) class ModeSpec(AbstractModeSpec): diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index ba0c37e916..1739dbe40f 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -33,7 +33,12 @@ ObsGridArray, Size, ) -from .validators import assert_plane, validate_freqs_min, validate_freqs_not_empty +from .validators import ( + assert_plane, + validate_freqs_min, + validate_freqs_not_empty, + validate_interp_num_points, +) from .viz import ARROW_ALPHA, ARROW_COLOR_MONITOR BYTES_REAL = 4 @@ -432,6 +437,8 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: return 2 * bytes_single return bytes_single + _warn_interp_num_points = validate_interp_num_points() + class FieldMonitor(AbstractFieldMonitor, FreqMonitor): """:class:`Monitor` that records electromagnetic fields in the frequency domain. @@ -805,12 +812,20 @@ def _to_solver_monitor(self): stepping.""" return self.updated_copy(colocate=False) + @property + def _stored_freqs(self) -> list[float]: + """Return actually stored frequencies of the data.""" + # always stored at original frequencies, no matter whether interp_spec is used + return self.freqs + def storage_size(self, num_cells: int, tmesh: int) -> int: """Size of monitor storage given the number of points after discretization.""" - amps_size = 3 * BYTES_COMPLEX * len(self.freqs) * self.mode_spec.num_modes + amps_size = 3 * BYTES_COMPLEX * len(self._stored_freqs) * self.mode_spec.num_modes fields_size = 0 if self.store_fields_direction is not None: - fields_size = 6 * BYTES_COMPLEX * num_cells * len(self.freqs) * self.mode_spec.num_modes + fields_size = ( + 6 * BYTES_COMPLEX * num_cells * len(self._stored_freqs) * self.mode_spec.num_modes + ) if self.mode_spec.precision == "double": fields_size *= 2 return amps_size + fields_size @@ -846,6 +861,11 @@ class ModeSolverMonitor(AbstractModeMonitor): "like ``mode_area`` require all E-field components.", ) + @property + def _stored_freqs(self) -> list[float]: + """Return actually stored frequencies of the data.""" + return self.mode_spec._sampling_freqs_mode_solver_data(freqs=self.freqs) + @pydantic.root_validator(skip_on_failure=True) def set_store_fields(cls, values): """Ensure 'store_fields_direction' is compatible with 'direction'.""" @@ -862,7 +882,9 @@ def set_store_fields(cls, values): def storage_size(self, num_cells: int, tmesh: int) -> int: """Size of monitor storage given the number of points after discretization.""" - bytes_single = 6 * BYTES_COMPLEX * num_cells * len(self.freqs) * self.mode_spec.num_modes + bytes_single = ( + 6 * BYTES_COMPLEX * num_cells * len(self._stored_freqs) * self.mode_spec.num_modes + ) if self.mode_spec.precision == "double": return 2 * bytes_single return bytes_single diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index c6a9d9c640..d1efc03018 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -494,3 +494,35 @@ def _warn_traced_arg(cls, val, values): return val return _warn_traced_arg + + +def _warn_interp_num_points(interp_spec, freqs) -> None: + """Warn if the number of sampling points for interpolation is greater than or equal to the number of target frequencies.""" + + num_freqs = len(freqs) + + if interp_spec.num_points >= num_freqs: + log.warning( + f"'interp_spec.num_points' ({interp_spec.num_points}) is greater than or equal to " + f"the number of frequencies ({num_freqs}). Interpolation will be skipped and " + f"modes will be computed at all {num_freqs} frequencies.", + custom_loc=["mode_spec", "interp_spec", "num_points"], + ) + + +def validate_interp_num_points(): + @pydantic.root_validator(allow_reuse=True) + @skip_if_fields_missing(["freqs", "mode_spec"], root=True) + def _validate_warn_interp_num_points(cls, values): + """Warn if the number of sampling points for interpolation is greater than or equal to the number of target frequencies.""" + + interp_spec = values.get("mode_spec").interp_spec + if interp_spec is None: + return values + + freqs = values.get("freqs") + _warn_interp_num_points(interp_spec, freqs) + + return values + + return _validate_warn_interp_num_points