1818from tidy3d .components .base_sim .data .monitor_data import AbstractMonitorData
1919from tidy3d .components .grid .grid import Coords , Grid
2020from tidy3d .components .medium import Medium , MediumType
21- from tidy3d .components .mode_spec import ModeInterpSpec , ModeSortSpec
21+ from tidy3d .components .mode_spec import ModeInterpSpec , ModeSortSpec , ModeSpec
2222from tidy3d .components .monitor import (
2323 AuxFieldTimeMonitor ,
2424 DiffractionMonitor ,
4545 ArrayFloat1D ,
4646 ArrayFloat2D ,
4747 Coordinate ,
48+ Direction ,
4849 EMField ,
4950 EpsSpecType ,
5051 FreqArray ,
7980 MixedModeDataArray ,
8081 ModeAmpsDataArray ,
8182 ModeDispersionDataArray ,
83+ ModeIndexDataArray ,
8284 ScalarFieldDataArray ,
8385 ScalarFieldTimeDataArray ,
8486 TimeDataArray ,
@@ -2458,6 +2460,36 @@ class ModeSolverData(ModeData):
24582460 None , title = "Amplitudes" , description = "Unused for ModeSolverData."
24592461 )
24602462
2463+ grid_distances_primal : Union [tuple [float ], tuple [float , float ]] = pd .Field (
2464+ (0.0 ,),
2465+ title = "Distances to the Primal Grid" ,
2466+ description = "Relative distances to the primal grid locations along the normal direction in "
2467+ "the original simulation grid. Needed to recalculate grid corrections after "
2468+ "interpolating in frequency." ,
2469+ )
2470+
2471+ grid_distances_dual : Union [tuple [float ], tuple [float , float ]] = pd .Field (
2472+ (0.0 ,),
2473+ title = "Distances to the Dual Grid" ,
2474+ description = "Relative distances to the dual grid locations along the normal direction in "
2475+ "the original simulation grid. Needed to recalculate grid corrections after "
2476+ "interpolating in frequency." ,
2477+ )
2478+
2479+ @pd .validator ("eps_spec" , always = True )
2480+ @skip_if_fields_missing (["monitor" ])
2481+ def eps_spec_match_mode_spec (cls , val , values ):
2482+ """Raise validation error if frequencies in eps_spec does not match frequency list"""
2483+ if val :
2484+ mnt = values ["monitor" ]
2485+ if (mnt .reduce_data and len (val ) != mnt .mode_spec .interp_spec .num_points ) or (
2486+ not mnt .reduce_data and len (val ) != len (mnt .freqs )
2487+ ):
2488+ raise ValidationError (
2489+ "eps_spec must be provided at the same frequencies as mode solver data."
2490+ )
2491+ return val
2492+
24612493 def normalize (self , source_spectrum_fn : Callable [[float ], complex ]) -> ModeSolverData :
24622494 """Return copy of self after normalization is applied using source spectrum function."""
24632495 return self .copy ()
@@ -2468,6 +2500,63 @@ def _normalize_modes(self):
24682500 for field in self .field_components .values ():
24692501 field /= scaling
24702502
2503+ @staticmethod
2504+ def _grid_correction_factors (
2505+ primal_distances : tuple [float , ...],
2506+ dual_distances : tuple [float , ...],
2507+ mode_spec : ModeSpec ,
2508+ n_complex : ModeIndexDataArray ,
2509+ direction : Direction ,
2510+ normal_dim : str ,
2511+ ) -> tuple [FreqModeDataArray , FreqModeDataArray ]:
2512+ """Calculate the grid correction factors for the primal and dual grid.
2513+
2514+ Parameters
2515+ ----------
2516+ primal_distances : tuple[float, ...]
2517+ Relative distances to the primal grid locations along the normal direction in the original simulation grid.
2518+ dual_distances : tuple[float, ...]
2519+ Relative distances to the dual grid locations along the normal direction in the original simulation grid.
2520+ mode_spec : ModeSpec
2521+ Mode specification.
2522+ n_complex : ModeIndexDataArray
2523+ Effective indices of the modes.
2524+ direction : Direction
2525+ Direction of the propagation.
2526+ normal_dim : str
2527+ Name of the normal dimension.
2528+
2529+ Returns
2530+ -------
2531+ tuple[FreqModeDataArray, FreqModeDataArray]
2532+ Grid correction factors for the primal and dual grid.
2533+ """
2534+
2535+ distances_primal = xr .DataArray (primal_distances , coords = {normal_dim : primal_distances })
2536+ distances_dual = xr .DataArray (dual_distances , coords = {normal_dim : dual_distances })
2537+
2538+ # Propagation phase at the primal and dual locations. The k-vector is along the propagation
2539+ # direction, so angle_theta has to be taken into account. The distance along the propagation
2540+ # direction is the distance along the normal direction over cosine(theta).
2541+ cos_theta = np .cos (mode_spec .angle_theta )
2542+ k_vec = cos_theta * 2 * np .pi * n_complex * n_complex .f / C_0
2543+ if direction == "-" :
2544+ k_vec *= - 1
2545+ phase_primal = np .exp (1j * k_vec * distances_primal )
2546+ phase_dual = np .exp (1j * k_vec * distances_dual )
2547+
2548+ # Fields are modified by a linear interpolation to the exact monitor position
2549+ if distances_primal .size > 1 :
2550+ phase_primal = phase_primal .interp (** {normal_dim : 0 })
2551+ else :
2552+ phase_primal = phase_primal .squeeze (dim = normal_dim )
2553+ if distances_dual .size > 1 :
2554+ phase_dual = phase_dual .interp (** {normal_dim : 0 })
2555+ else :
2556+ phase_dual = phase_dual .squeeze (dim = normal_dim )
2557+
2558+ return FreqModeDataArray (phase_primal ), FreqModeDataArray (phase_dual )
2559+
24712560 @staticmethod
24722561 def _validate_cheb_nodes (freqs : np .ndarray ) -> None :
24732562 """Validate that frequencies are approximately at Chebyshev nodes.
@@ -2504,7 +2593,8 @@ def interp_in_freq(
25042593 self ,
25052594 freqs : FreqArray ,
25062595 method : Literal ["linear" , "cubic" , "cheb" ] = "linear" ,
2507- renormalize : Optional [bool ] = True ,
2596+ renormalize : bool = True ,
2597+ recalculate_grid_correction : bool = True ,
25082598 ) -> ModeSolverData :
25092599 """Interpolate mode data to new frequency points.
25102600
@@ -2524,10 +2614,10 @@ def interp_in_freq(
25242614 frequencies), ``"cheb"`` for Chebyshev polynomial interpolation using barycentric
25252615 formula (requires 3+ source frequencies at Chebyshev nodes).
25262616 For complex-valued data, real and imaginary parts are interpolated independently.
2527- renormalize : Optional[ bool] = True
2617+ renormalize : bool = True
25282618 Whether to renormalize the mode profiles to unity power after interpolation.
2529- recalculate_grid_correction : Optional[ bool] = True
2530- Whether to recalculate the grid correction after interpolation or use interpolated
2619+ recalculate_grid_correction : bool = True
2620+ Whether to recalculate the grid correction factors after interpolation or use interpolated
25312621 grid corrections.
25322622
25332623 Returns
@@ -2563,7 +2653,11 @@ def interp_in_freq(
25632653 """
25642654 # Validate input
25652655 freqs = np .array (freqs )
2656+
25662657 source_freqs = np .array (self .monitor .freqs )
2658+ if self .monitor .reduce_data :
2659+ # it is validated that if reduce_data is True, then interp_spec is not None
2660+ source_freqs = self .monitor .mode_spec .interp_spec .sampling_points (source_freqs )
25672661
25682662 # Validate method-specific requirements
25692663 if method == "cubic" and len (source_freqs ) < 4 :
@@ -2607,14 +2701,30 @@ def interp_in_freq(
26072701 if self .eps_spec is not None :
26082702 update_dict ["eps_spec" ] = list (
26092703 self ._interp_dataarray_in_freq (
2610- FreqDataArray (self .eps_spec , coords = {"f" : self . monitor . freqs }),
2704+ FreqDataArray (self .eps_spec , coords = {"f" : source_freqs }),
26112705 freqs ,
26122706 "nearest" ,
26132707 ).data
26142708 )
26152709
2616- # Update monitor with new frequencies
2617- update_dict ["monitor" ] = self .monitor .updated_copy (freqs = list (freqs ))
2710+ # Update monitor with new frequencies, remove interp_spec and set reduce_data to False
2711+ update_dict ["monitor" ] = self .monitor .updated_copy (
2712+ freqs = list (freqs ),
2713+ mode_spec = self .monitor .mode_spec .updated_copy (interp_spec = None ),
2714+ reduce_data = False ,
2715+ )
2716+
2717+ if recalculate_grid_correction :
2718+ update_dict ["grid_primal_correction" ], update_dict ["grid_dual_correction" ] = (
2719+ self ._grid_correction_factors (
2720+ list (self .grid_distances_primal ),
2721+ list (self .grid_distances_dual ),
2722+ self .monitor .mode_spec ,
2723+ update_dict ["n_complex" ],
2724+ self .monitor .direction ,
2725+ "xyz" [self .monitor ._normal_axis ],
2726+ )
2727+ )
26182728
26192729 updated_data = self .updated_copy (** update_dict )
26202730 if renormalize :
@@ -2627,12 +2737,13 @@ def interpolated_copy(self) -> ModeSolverData:
26272737 """Return a copy of the data with interpolated fields."""
26282738 if self .monitor .mode_spec .interp_spec is None or not self .monitor .reduce_data :
26292739 return self
2630- return self .interp_in_freq (
2740+ interpolated_data = self .interp_in_freq (
26312741 freqs = self .monitor .freqs ,
26322742 method = self .monitor .mode_spec .interp_spec .method ,
26332743 renormalize = True ,
2634- monitor = self . monitor . updated_copy ( reduce_data = False ) ,
2744+ recalculate_grid_correction = True ,
26352745 )
2746+ return interpolated_data .updated_copy (monitor = self .monitor .updated_copy (reduce_data = False ))
26362747
26372748 @property
26382749 def time_reversed_copy (self ) -> FieldData :
0 commit comments