Skip to content

Commit d58af9b

Browse files
prototype interpolated_data property
1 parent df5727b commit d58af9b

File tree

8 files changed

+184
-74
lines changed

8 files changed

+184
-74
lines changed

tests/test_components/test_mode_interp.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import tidy3d as td
1010
from tidy3d.plugins.mode import ModeSolver
11-
from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_MODE_SPEC
11+
# from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_MODE_SPEC
1212

1313
from ..test_data.test_data_arrays import MODE_SPEC, SIZE_2D
1414
from ..utils import AssertLogLevel
@@ -323,11 +323,11 @@ def get_mode_solver_data():
323323
from tidy3d.components.data.data_array import GroupIndexDataArray, ModeDispersionDataArray
324324

325325
from ..test_data.test_data_arrays import (
326-
make_scalar_mode_field_data_array,
326+
make_scalar_mode_field_data_array, SIM
327327
)
328328
from ..test_data.test_monitor_data import N_COMPLEX
329329

330-
freqs = np.linspace(1e14, 2e14, 5)
330+
freqs = N_COMPLEX.f.data
331331
num_modes = len(N_COMPLEX.mode_index)
332332
mode_indices = np.arange(num_modes)
333333
mode_spec = td.ModeSpec(num_modes=num_modes, sort_spec=td.ModeSortSpec(track_freq="central"))
@@ -337,6 +337,7 @@ def get_mode_solver_data():
337337
freqs=freqs,
338338
mode_spec=mode_spec,
339339
name="test_monitor",
340+
colocate=False,
340341
)
341342

342343
# Create n_group_raw and dispersion_raw with same shape as n_complex
@@ -353,18 +354,18 @@ def get_mode_solver_data():
353354
# Create mode data with the right frequencies
354355
mode_data = td.ModeSolverData(
355356
monitor=monitor,
356-
Ex=make_scalar_mode_field_data_array("Ex"),
357-
Ey=make_scalar_mode_field_data_array("Ey"),
358-
Ez=make_scalar_mode_field_data_array("Ez"),
359-
Hx=make_scalar_mode_field_data_array("Hx"),
360-
Hy=make_scalar_mode_field_data_array("Hy"),
361-
Hz=make_scalar_mode_field_data_array("Hz"),
362-
n_complex=N_COMPLEX.copy(),
357+
Ex=make_scalar_mode_field_data_array("Ex", symmetry=False),
358+
Ey=make_scalar_mode_field_data_array("Ey", symmetry=False),
359+
Ez=make_scalar_mode_field_data_array("Ez", symmetry=False),
360+
Hx=make_scalar_mode_field_data_array("Hx", symmetry=False),
361+
Hy=make_scalar_mode_field_data_array("Hy", symmetry=False),
362+
Hz=make_scalar_mode_field_data_array("Hz", symmetry=False),
363+
n_complex=N_COMPLEX,
363364
n_group_raw=n_group_raw,
364365
dispersion_raw=dispersion_raw,
365366
symmetry=(0, 0, 0),
366367
symmetry_center=(0, 0, 0),
367-
grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])),
368+
grid_expanded=SIM.discretize_monitor(monitor),
368369
)
369370
return mode_data
370371

@@ -425,23 +426,24 @@ def test_mode_solver_data_interp_cheb():
425426
freqs=freqs_cheb,
426427
mode_spec=mode_spec,
427428
name="test_cheb",
429+
colocate=False,
428430
)
429431

430-
from ..test_data.test_data_arrays import make_scalar_mode_field_data_array
432+
from ..test_data.test_data_arrays import make_scalar_mode_field_data_array, SIM
431433
from ..test_data.test_monitor_data import N_COMPLEX
432434

433435
mode_data = td.ModeSolverData(
434436
monitor=monitor,
435-
Ex=make_scalar_mode_field_data_array("Ex"),
436-
Ey=make_scalar_mode_field_data_array("Ey"),
437-
Ez=make_scalar_mode_field_data_array("Ez"),
438-
Hx=make_scalar_mode_field_data_array("Hx"),
439-
Hy=make_scalar_mode_field_data_array("Hy"),
440-
Hz=make_scalar_mode_field_data_array("Hz"),
437+
Ex=make_scalar_mode_field_data_array("Ex", symmetry=False),
438+
Ey=make_scalar_mode_field_data_array("Ey", symmetry=False),
439+
Ez=make_scalar_mode_field_data_array("Ez", symmetry=False),
440+
Hx=make_scalar_mode_field_data_array("Hx", symmetry=False),
441+
Hy=make_scalar_mode_field_data_array("Hy", symmetry=False),
442+
Hz=make_scalar_mode_field_data_array("Hz", symmetry=False),
441443
n_complex=N_COMPLEX.copy(),
442444
symmetry=(0, 0, 0),
443445
symmetry_center=(0, 0, 0),
444-
grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])),
446+
grid_expanded=SIM.discretize_monitor(monitor),
445447
)
446448

447449
# Interpolate to 50 frequencies
@@ -895,7 +897,7 @@ def test_mode_monitor_interp_spec_none():
895897
# ============================================================================
896898

897899

898-
def make_wave_port():
900+
def make_wave_port(num_interp_points=3, method="linear"):
899901
"""Make a WavePort."""
900902
from tidy3d.components.microwave.path_integrals.integrals.current import (
901903
AxisAlignedCurrentIntegral,
@@ -907,20 +909,18 @@ def make_wave_port():
907909
size=(1, 1, 0),
908910
direction="+",
909911
name="port1",
910-
current_integral=AxisAlignedCurrentIntegral(
911-
center=(0, 0, 0),
912-
size=(1, 1, 0),
913-
sign="+",
914-
extrapolate_to_endpoints=True,
915-
snap_contour_to_grid=True,
912+
mode_spec=td.MicrowaveModeSpec(
913+
num_modes=1,
914+
sort_spec=td.ModeSortSpec(track_freq="central"),
915+
interp_spec=td.ModeInterpSpec(num_points=num_interp_points, method=method),
916916
),
917917
)
918918

919919

920920
def test_wave_port_to_monitors_propagates_default_interp_spec():
921921
"""Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor."""
922922

923-
port = make_wave_port()
923+
port = make_wave_port(num_interp_points=3, method="linear")
924924

925925
freqs = np.linspace(1e14, 2e14, 20)
926926
monitors = port.to_monitors(freqs=freqs)
@@ -931,14 +931,14 @@ def test_wave_port_to_monitors_propagates_default_interp_spec():
931931
assert monitor.mode_spec.interp_spec is not None
932932
assert (
933933
monitor.mode_spec.interp_spec.num_points
934-
== DEFAULT_WAVE_PORT_MODE_SPEC.interp_spec.num_points
934+
== 3
935935
)
936-
assert monitor.mode_spec.interp_spec.method == DEFAULT_WAVE_PORT_MODE_SPEC.interp_spec.method
936+
assert monitor.mode_spec.interp_spec.method == "linear"
937937

938938

939939
def test_wave_port_to_monitors_propagates_custom_interp_spec():
940940
"""Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor."""
941-
custom_mode_spec = td.ModeSpec(
941+
custom_mode_spec = td.MicrowaveModeSpec(
942942
num_modes=1,
943943
sort_spec=td.ModeSortSpec(track_freq="central"),
944944
interp_spec=td.ModeInterpSpec(num_points=8, method="cheb"),
@@ -958,7 +958,7 @@ def test_wave_port_to_monitors_propagates_custom_interp_spec():
958958

959959
def test_wave_port_to_monitors_propagates_none_interp_spec():
960960
"""Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor."""
961-
mode_spec_no_interp = td.ModeSpec(
961+
mode_spec_no_interp = td.MicrowaveModeSpec(
962962
num_modes=1, sort_spec=td.ModeSortSpec(track_freq="central"), interp_spec=None
963963
)
964964
port = make_wave_port().updated_copy(mode_spec=mode_spec_no_interp)

tidy3d/components/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,3 +1397,16 @@ def __getattribute__(self, name: str):
13971397

13981398
_LazyProxy.__name__ = proxy_name
13991399
return _LazyProxy
1400+
1401+
1402+
class InterpolatableMixin(Tidy3dBaseModel):
1403+
"""Mixin to add a `reduce_data` field to a model."""
1404+
1405+
reduce_data: bool = pydantic.Field(
1406+
False,
1407+
title="Reduce Data",
1408+
description="If `mode_spec.interp_spec` is defined, one can use this flag "
1409+
"to record fields and quatities only at interpolation source frequency points. "
1410+
"The :class:`.ModeMonitorData` at requested frequencies can be obtain through "
1411+
"the :attr:`.ModeMonitorData.interpolated_copy` property.",
1412+
)

tidy3d/components/data/monitor_data.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2504,7 +2504,7 @@ def interp_in_freq(
25042504
self,
25052505
freqs: FreqArray,
25062506
method: Literal["linear", "cubic", "cheb"] = "linear",
2507-
renormalize: Optional[bool] = False,
2507+
renormalize: Optional[bool] = True,
25082508
) -> ModeSolverData:
25092509
"""Interpolate mode data to new frequency points.
25102510
@@ -2524,8 +2524,11 @@ def interp_in_freq(
25242524
frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
25252525
formula (requires 3+ source frequencies at Chebyshev nodes).
25262526
For complex-valued data, real and imaginary parts are interpolated independently.
2527-
renormalize : Optional[bool] = False
2527+
renormalize : Optional[bool] = True
25282528
Whether to renormalize the mode profiles to unity power after interpolation.
2529+
recalculate_grid_correction : Optional[bool] = True
2530+
Whether to recalculate the grid correction after interpolation or use interpolated
2531+
grid corrections.
25292532
25302533
Returns
25312534
-------
@@ -2614,13 +2617,23 @@ def interp_in_freq(
26142617
update_dict["monitor"] = self.monitor.updated_copy(freqs=list(freqs))
26152618

26162619
updated_data = self.updated_copy(**update_dict)
2617-
# print(updated_data.poynting)
2618-
# print(updated_data._diff_area)
26192620
if renormalize:
26202621
updated_data._normalize_modes()
26212622

26222623
return updated_data
26232624

2625+
@property
2626+
def interpolated_copy(self) -> ModeSolverData:
2627+
"""Return a copy of the data with interpolated fields."""
2628+
if self.monitor.mode_spec.interp_spec is None or not self.monitor.reduce_data:
2629+
return self
2630+
return self.interp_in_freq(
2631+
freqs=self.monitor.freqs,
2632+
method=self.monitor.mode_spec.interp_spec.method,
2633+
renormalize=True,
2634+
monitor=self.monitor.updated_copy(reduce_data=False),
2635+
)
2636+
26242637
@property
26252638
def time_reversed_copy(self) -> FieldData:
26262639
"""Make a copy of the data with direction-reversed fields. In lossy or gyrotropic systems,

tidy3d/components/microwave/data/monitor_data.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -324,22 +324,6 @@ def _apply_mode_reorder(self, sort_inds_2d):
324324
)
325325
return main_data_reordered
326326

327-
def interp_in_freq(
328-
self,
329-
freqs: FreqArray,
330-
method: Literal["linear", "cubic", "cheb", "nearest"] = "linear",
331-
renormalize: Optional[bool] = True,
332-
) -> MicrowaveModeData:
333-
"""Interpolate mode data to new frequency points."""
334-
main_data_interp = super().interp_in_freq(freqs, method, renormalize)
335-
if self.transmission_line_data is not None:
336-
update_dict = self.transmission_line_data._interp_in_freq_update_dict(freqs, method)
337-
transmission_line_data_interp = self.transmission_line_data.updated_copy(**update_dict)
338-
main_data_interp = main_data_interp.updated_copy(
339-
transmission_line_data=transmission_line_data_interp
340-
)
341-
return main_data_interp
342-
343327

344328
class MicrowaveModeSolverData(ModeSolverData, MicrowaveModeData):
345329
"""
@@ -413,3 +397,70 @@ class MicrowaveModeSolverData(ModeSolverData, MicrowaveModeData):
413397
monitor: MicrowaveModeSolverMonitor = pd.Field(
414398
..., title="Monitor", description="Mode monitor associated with the data."
415399
)
400+
401+
def interp_in_freq(
402+
self,
403+
freqs: FreqArray,
404+
method: Literal["linear", "cubic", "cheb", "nearest"] = "linear",
405+
renormalize: Optional[bool] = True,
406+
) -> MicrowaveModeData:
407+
"""Interpolate mode data to new frequency points.
408+
409+
Interpolates all stored mode data (effective indices, field components, group indices,
410+
and dispersion) from the current frequency grid to a new set of frequencies. This is
411+
useful for obtaining mode data at many frequencies from computations at fewer frequencies,
412+
when modes vary smoothly with frequency.
413+
414+
Parameters
415+
----------
416+
freqs : FreqArray
417+
New frequency points to interpolate to. Should generally span a similar range
418+
as the original frequencies to avoid extrapolation.
419+
method : Literal["linear", "cubic", "cheb"]
420+
Interpolation method. ``"linear"`` for linear interpolation (requires 2+ source
421+
frequencies), ``"cubic"`` for cubic spline interpolation (requires 4+ source
422+
frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
423+
formula (requires 3+ source frequencies at Chebyshev nodes).
424+
For complex-valued data, real and imaginary parts are interpolated independently.
425+
renormalize : Optional[bool] = True
426+
Whether to renormalize the mode profiles to unity power after interpolation.
427+
428+
Returns
429+
-------
430+
ModeSolverData
431+
New :class:`ModeSolverData` object with data interpolated to the requested frequencies.
432+
433+
Raises
434+
------
435+
DataError
436+
If interpolation parameters are invalid (e.g., too few source frequencies for the
437+
chosen method, or source frequencies not at Chebyshev nodes for 'cheb' method).
438+
439+
Note
440+
----
441+
Interpolation assumes modes vary smoothly with frequency. Results may be inaccurate
442+
near mode crossings or regions of rapid mode variation. Use frequency tracking
443+
(``mode_spec.sort_spec.track_freq``) to help maintain mode ordering consistency.
444+
445+
For Chebyshev interpolation, source frequencies must be at Chebyshev nodes of the
446+
second kind within the frequency range.
447+
448+
Example
449+
-------
450+
>>> # Compute modes at 5 frequencies
451+
>>> import numpy as np
452+
>>> freqs_sparse = np.linspace(1e14, 2e14, 5)
453+
>>> # ... create mode_solver and compute modes ...
454+
>>> # mode_data = mode_solver.solve()
455+
>>> # Interpolate to 50 frequencies
456+
>>> freqs_dense = np.linspace(1e14, 2e14, 50)
457+
>>> # mode_data_interp = mode_data.interp(freqs=freqs_dense, method='linear')
458+
"""
459+
main_data_interp = super().interp_in_freq(freqs, method, renormalize)
460+
if self.transmission_line_data is not None:
461+
update_dict = self.transmission_line_data._interp_in_freq_update_dict(freqs, method)
462+
transmission_line_data_interp = self.transmission_line_data.updated_copy(**update_dict)
463+
main_data_interp = main_data_interp.updated_copy(
464+
transmission_line_data=transmission_line_data_interp
465+
)
466+
return main_data_interp

0 commit comments

Comments
 (0)