77from abc import ABC
88from math import isclose
99from os import PathLike
10- from typing import Any , Callable , Literal , Optional , Union , get_args
10+ from typing import Any , Callable , Literal , Optional , Self , SupportsComplex , Union , get_args
1111
1212import autograd .numpy as np
1313import xarray as xr
14+ from numpy .typing import NDArray
1415from pandas import DataFrame
1516from pydantic import Field , model_validator
1617
103104# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104105COS_THETA_THRESH = 1e-5
105106
107+ GRID_CORRECTION_TYPE = Union [
108+ float ,
109+ FreqDataArray ,
110+ TimeDataArray ,
111+ FreqModeDataArray ,
112+ EMEFreqModeDataArray ,
113+ ]
114+
106115
107116class MonitorData (AbstractMonitorData , ABC ):
108117 """
@@ -171,7 +180,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
171180 return "-" if direction == "+" else "+"
172181
173182 @staticmethod
174- def get_amplitude (x ) -> complex :
183+ def get_amplitude (x : Union [ DataArray , SupportsComplex ] ) -> complex :
175184 """Get the complex amplitude out of some data."""
176185
177186 if isinstance (x , DataArray ):
@@ -213,7 +222,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
213222 )
214223
215224 @model_validator (mode = "after" )
216- def warn_missing_grid_expanded (self ):
225+ def warn_missing_grid_expanded (self ) -> Self :
217226 """If ``grid_expanded`` not provided and fields data is present, warn that some methods
218227 will break."""
219228 field_comps = ["Ex" , "Ey" , "Ez" , "Hx" , "Hy" , "Hz" ]
@@ -226,15 +235,15 @@ def warn_missing_grid_expanded(self):
226235 )
227236 return self
228237
229- _require_sym_center = required_if_symmetry_present ("symmetry_center" )
230- _require_grid_expanded = required_if_symmetry_present ("grid_expanded" )
238+ _require_sym_center : Callable [[ Any ], Any ] = required_if_symmetry_present ("symmetry_center" )
239+ _require_grid_expanded : Callable [[ Any ], Any ] = required_if_symmetry_present ("grid_expanded" )
231240
232241 def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
233242 """Coordinates in the expanded grid corresponding to a given field component."""
234243 return self .grid_expanded [self .grid_locations [field_name ]]
235244
236245 @property
237- def symmetry_expanded (self ):
246+ def symmetry_expanded (self ) -> Self :
238247 """Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
239248 any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
240249 data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -391,27 +400,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
391400class ElectromagneticFieldData (AbstractFieldData , ElectromagneticFieldDataset , ABC ):
392401 """Collection of electromagnetic fields."""
393402
394- grid_primal_correction : Union [
395- float ,
396- FreqDataArray ,
397- TimeDataArray ,
398- FreqModeDataArray ,
399- EMEFreqModeDataArray ,
400- ] = Field (
403+ grid_primal_correction : GRID_CORRECTION_TYPE = Field (
401404 1.0 ,
402405 title = "Field correction factor" ,
403406 description = "Correction factor that needs to be applied for data corresponding to a 2D "
404407 "monitor to take into account the finite grid in the normal direction in the simulation in "
405408 "which the data was computed. The factor is applied to fields defined on the primal grid "
406409 "locations along the normal direction." ,
407410 )
408- grid_dual_correction : Union [
409- float ,
410- FreqDataArray ,
411- TimeDataArray ,
412- FreqModeDataArray ,
413- EMEFreqModeDataArray ,
414- ] = Field (
411+ grid_dual_correction : GRID_CORRECTION_TYPE = Field (
415412 1.0 ,
416413 title = "Field correction factor" ,
417414 description = "Correction factor that needs to be applied for data corresponding to a 2D "
@@ -420,15 +417,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
420417 "locations along the normal direction." ,
421418 )
422419
423- def _expanded_grid_field_coords (self , field_name : str ):
420+ def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
424421 """Coordinates in the expanded grid corresponding to a given field component."""
425422 if self .monitor .colocate :
426423 bounds_dict = self .grid_expanded .boundaries .to_dict
427424 return Coords (** {key : val [:- 1 ] for key , val in bounds_dict .items ()})
428425 return self .grid_expanded [self .grid_locations [field_name ]]
429426
430427 @property
431- def _grid_correction_dict (self ):
428+ def _grid_correction_dict (self ) -> dict [ str , GRID_CORRECTION_TYPE ] :
432429 """Return the primal and dual finite grid correction factors as a dictionary."""
433430 return {
434431 "grid_primal_correction" : self .grid_primal_correction ,
@@ -937,7 +934,7 @@ def outer_dot(
937934 d_area = self ._diff_area .expand_dims (dim = {"f" : f }, axis = 2 ).to_numpy ()
938935
939936 # function to apply at each pair of mode indices before integrating
940- def fn (fields_1 , fields_2 ) :
937+ def fn (fields_1 : dict [ str , NDArray ], fields_2 : dict [ str , NDArray ]) -> NDArray :
941938 e_self_1 = fields_1 [e_1 ]
942939 e_self_2 = fields_1 [e_2 ]
943940 h_self_1 = fields_1 [h_1 ]
@@ -978,7 +975,7 @@ def _outer_fn_summation(
978975 outer_dim_1 : str ,
979976 outer_dim_2 : str ,
980977 sum_dims : list [str ],
981- fn : Callable ,
978+ fn : Callable [[ dict [ str , NDArray ], NDArray ], NDArray ] ,
982979 ) -> DataArray :
983980 """
984981 Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1676,7 +1673,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16761673 )
16771674
16781675 @model_validator (mode = "after" )
1679- def eps_spec_match_mode_spec (self ):
1676+ def eps_spec_match_mode_spec (self ) -> Self :
16801677 """Raise validation error if frequencies in eps_spec does not match frequency list"""
16811678 if self .eps_spec :
16821679 mode_data_freqs = self .monitor .freqs
@@ -1686,7 +1683,7 @@ def eps_spec_match_mode_spec(self):
16861683 )
16871684 return self
16881685
1689- def normalize (self , source_spectrum_fn ) -> ModeData :
1686+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> Self :
16901687 """Return copy of self after normalization is applied using source spectrum function."""
16911688 source_freq_amps = source_spectrum_fn (self .amps .f )[None , :, None ]
16921689 new_amps = (self .amps / source_freq_amps ).astype (self .amps .dtype )
@@ -1808,7 +1805,7 @@ def overlap_sort(
18081805
18091806 return data_reordered .updated_copy (monitor = monitor_updated , deep = False , validate = False )
18101807
1811- def _isel (self , ** isel_kwargs : Any ):
1808+ def _isel (self , ** isel_kwargs : Any ) -> Self :
18121809 """Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
18131810 mode index. Used in ``overlap_sort`` but not officially supported since for example
18141811 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1822,12 +1819,11 @@ def _isel(self, **isel_kwargs: Any):
18221819 }
18231820 return self .updated_copy (** update_dict , deep = False , validate = False )
18241821
1825- def _assign_coords (self , ** assign_coords_kwargs : Any ):
1822+ def _assign_coords (self , ** assign_coords_kwargs : Any ) -> Self :
18261823 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18271824 mode index. Used in ``overlap_sort`` but not officially supported since for example
18281825 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
18291826 newly created data."""
1830-
18311827 update_dict = dict (self ._grid_correction_dict , ** self .field_components )
18321828 update_dict = {
18331829 key : field .assign_coords (** assign_coords_kwargs ) for key , field in update_dict .items ()
@@ -2243,7 +2239,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22432239
22442240 return src_adj
22452241
2246- def _apply_mode_reorder (self , sort_inds_2d ) :
2242+ def _apply_mode_reorder (self , sort_inds_2d : NDArray ) -> Self :
22472243 """Apply a mode reordering along mode_index for all frequency indices.
22482244
22492245 Parameters
@@ -2342,7 +2338,7 @@ def sort_modes(
23422338 sort_inds_2d = np .tile (identity , (num_freqs , 1 ))
23432339
23442340 # Helper to compute ordered indices within a subset
2345- def _order_indices (indices , vals_all ) :
2341+ def _order_indices (indices : NDArray , vals_all : DataArray ) -> NDArray :
23462342 if indices .size == 0 :
23472343 return indices
23482344 vals = vals_all .isel (mode_index = indices )
@@ -2461,7 +2457,7 @@ class ModeSolverData(ModeData):
24612457 description = "Unused for ModeSolverData." ,
24622458 )
24632459
2464- def normalize (self , source_spectrum_fn : Callable [[float ], complex ]) -> ModeSolverData :
2460+ def normalize (self , source_spectrum_fn : Callable [[DataArray ], NDArray ]) -> ModeSolverData :
24652461 """Return copy of self after normalization is applied using source spectrum function."""
24662462 return self .copy ()
24672463
@@ -2554,7 +2550,7 @@ def _make_adjoint_sources(
25542550 "computation."
25552551 )
25562552
2557- def normalize (self , source_spectrum_fn ) -> FluxData :
2553+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> FluxData :
25582554 """Return copy of self after normalization is applied using source spectrum function."""
25592555 source_freq_amps = source_spectrum_fn (self .flux .f )
25602556 source_power = abs (source_freq_amps ) ** 2
@@ -3164,7 +3160,7 @@ def z(self) -> np.ndarray:
31643160 return self .Etheta .z .values
31653161
31663162 @property
3167- def tangential_dims (self ):
3163+ def tangential_dims (self ) -> list [ str ] :
31683164 tangential_dims = ["x" , "y" , "z" ]
31693165 tangential_dims .pop (self .monitor .proj_axis )
31703166 return tangential_dims
0 commit comments