55import struct
66import warnings
77from abc import ABC
8+ from collections .abc import Mapping
89from math import isclose
910from os import PathLike
10- from typing import Any , Callable , Literal , Optional , Union , get_args
11+ from typing import Any , Callable , Literal , Optional , Self , SupportsComplex , Union , get_args
1112
1213import autograd .numpy as np
1314import xarray as xr
15+ from numpy ._typing import NDArray
1416from pandas import DataFrame
1517from pydantic import Field , model_validator
1618
103105# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104106COS_THETA_THRESH = 1e-5
105107
108+ GRID_CORRECTION_TYPE = Union [
109+ float ,
110+ FreqDataArray ,
111+ TimeDataArray ,
112+ FreqModeDataArray ,
113+ EMEFreqModeDataArray ,
114+ ]
115+
106116
107117class MonitorData (AbstractMonitorData , ABC ):
108118 """
@@ -171,7 +181,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
171181 return "-" if direction == "+" else "+"
172182
173183 @staticmethod
174- def get_amplitude (x ) -> complex :
184+ def get_amplitude (x : Union [ DataArray , SupportsComplex ] ) -> complex :
175185 """Get the complex amplitude out of some data."""
176186
177187 if isinstance (x , DataArray ):
@@ -213,7 +223,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
213223 )
214224
215225 @model_validator (mode = "after" )
216- def warn_missing_grid_expanded (self ):
226+ def warn_missing_grid_expanded (self ) -> Self :
217227 """If ``grid_expanded`` not provided and fields data is present, warn that some methods
218228 will break."""
219229 field_comps = ["Ex" , "Ey" , "Ez" , "Hx" , "Hy" , "Hz" ]
@@ -226,15 +236,15 @@ def warn_missing_grid_expanded(self):
226236 )
227237 return self
228238
229- _require_sym_center = required_if_symmetry_present ("symmetry_center" )
230- _require_grid_expanded = required_if_symmetry_present ("grid_expanded" )
239+ _require_sym_center : Callable [[ Any ], Any ] = required_if_symmetry_present ("symmetry_center" )
240+ _require_grid_expanded : Callable [[ Any ], Any ] = required_if_symmetry_present ("grid_expanded" )
231241
232242 def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
233243 """Coordinates in the expanded grid corresponding to a given field component."""
234244 return self .grid_expanded [self .grid_locations [field_name ]]
235245
236246 @property
237- def symmetry_expanded (self ):
247+ def symmetry_expanded (self ) -> Self :
238248 """Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
239249 any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
240250 data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -391,27 +401,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
391401class ElectromagneticFieldData (AbstractFieldData , ElectromagneticFieldDataset , ABC ):
392402 """Collection of electromagnetic fields."""
393403
394- grid_primal_correction : Union [
395- float ,
396- FreqDataArray ,
397- TimeDataArray ,
398- FreqModeDataArray ,
399- EMEFreqModeDataArray ,
400- ] = Field (
404+ grid_primal_correction : GRID_CORRECTION_TYPE = Field (
401405 1.0 ,
402406 title = "Field correction factor" ,
403407 description = "Correction factor that needs to be applied for data corresponding to a 2D "
404408 "monitor to take into account the finite grid in the normal direction in the simulation in "
405409 "which the data was computed. The factor is applied to fields defined on the primal grid "
406410 "locations along the normal direction." ,
407411 )
408- grid_dual_correction : Union [
409- float ,
410- FreqDataArray ,
411- TimeDataArray ,
412- FreqModeDataArray ,
413- EMEFreqModeDataArray ,
414- ] = Field (
412+ grid_dual_correction : GRID_CORRECTION_TYPE = Field (
415413 1.0 ,
416414 title = "Field correction factor" ,
417415 description = "Correction factor that needs to be applied for data corresponding to a 2D "
@@ -420,15 +418,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
420418 "locations along the normal direction." ,
421419 )
422420
423- def _expanded_grid_field_coords (self , field_name : str ):
421+ def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
424422 """Coordinates in the expanded grid corresponding to a given field component."""
425423 if self .monitor .colocate :
426424 bounds_dict = self .grid_expanded .boundaries .to_dict
427425 return Coords (** {key : val [:- 1 ] for key , val in bounds_dict .items ()})
428426 return self .grid_expanded [self .grid_locations [field_name ]]
429427
430428 @property
431- def _grid_correction_dict (self ):
429+ def _grid_correction_dict (self ) -> dict [ str , GRID_CORRECTION_TYPE ] :
432430 """Return the primal and dual finite grid correction factors as a dictionary."""
433431 return {
434432 "grid_primal_correction" : self .grid_primal_correction ,
@@ -937,7 +935,7 @@ def outer_dot(
937935 d_area = self ._diff_area .expand_dims (dim = {"f" : f }, axis = 2 ).to_numpy ()
938936
939937 # function to apply at each pair of mode indices before integrating
940- def fn (fields_1 , fields_2 ) :
938+ def fn (fields_1 : dict [ str , NDArray ], fields_2 : dict [ str , NDArray ]) -> NDArray :
941939 e_self_1 = fields_1 [e_1 ]
942940 e_self_2 = fields_1 [e_2 ]
943941 h_self_1 = fields_1 [h_1 ]
@@ -978,7 +976,7 @@ def _outer_fn_summation(
978976 outer_dim_1 : str ,
979977 outer_dim_2 : str ,
980978 sum_dims : list [str ],
981- fn : Callable ,
979+ fn : Callable [[ dict [ str , NDArray ], NDArray ], NDArray ] ,
982980 ) -> DataArray :
983981 """
984982 Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1676,7 +1674,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16761674 )
16771675
16781676 @model_validator (mode = "after" )
1679- def eps_spec_match_mode_spec (self ):
1677+ def eps_spec_match_mode_spec (self ) -> Self :
16801678 """Raise validation error if frequencies in eps_spec does not match frequency list"""
16811679 if self .eps_spec :
16821680 mode_data_freqs = self .monitor .freqs
@@ -1686,7 +1684,7 @@ def eps_spec_match_mode_spec(self):
16861684 )
16871685 return self
16881686
1689- def normalize (self , source_spectrum_fn ) -> ModeData :
1687+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> Self :
16901688 """Return copy of self after normalization is applied using source spectrum function."""
16911689 source_freq_amps = source_spectrum_fn (self .amps .f )[None , :, None ]
16921690 new_amps = (self .amps / source_freq_amps ).astype (self .amps .dtype )
@@ -1808,7 +1806,7 @@ def overlap_sort(
18081806
18091807 return data_reordered .updated_copy (monitor = monitor_updated , deep = False , validate = False )
18101808
1811- def _isel (self , ** isel_kwargs : Any ):
1809+ def _isel (self , ** isel_kwargs : Any ) -> Self :
18121810 """Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
18131811 mode index. Used in ``overlap_sort`` but not officially supported since for example
18141812 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1822,7 +1820,7 @@ def _isel(self, **isel_kwargs: Any):
18221820 }
18231821 return self .updated_copy (** update_dict , deep = False , validate = False )
18241822
1825- def _assign_coords (self , ** assign_coords_kwargs : Any ) :
1823+ def _assign_coords (self , ** assign_coords_kwargs : Optional [ Mapping ]) -> Self :
18261824 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18271825 mode index. Used in ``overlap_sort`` but not officially supported since for example
18281826 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -2243,7 +2241,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22432241
22442242 return src_adj
22452243
2246- def _apply_mode_reorder (self , sort_inds_2d ) :
2244+ def _apply_mode_reorder (self , sort_inds_2d : NDArray ) -> Self :
22472245 """Apply a mode reordering along mode_index for all frequency indices.
22482246
22492247 Parameters
@@ -2342,7 +2340,7 @@ def sort_modes(
23422340 sort_inds_2d = np .tile (identity , (num_freqs , 1 ))
23432341
23442342 # Helper to compute ordered indices within a subset
2345- def _order_indices (indices , vals_all ) :
2343+ def _order_indices (indices : NDArray , vals_all : DataArray ) -> NDArray :
23462344 if indices .size == 0 :
23472345 return indices
23482346 vals = vals_all .isel (mode_index = indices )
@@ -2461,7 +2459,7 @@ class ModeSolverData(ModeData):
24612459 description = "Unused for ModeSolverData." ,
24622460 )
24632461
2464- def normalize (self , source_spectrum_fn : Callable [[float ], complex ]) -> ModeSolverData :
2462+ def normalize (self , source_spectrum_fn : Callable [[DataArray ], NDArray ]) -> ModeSolverData :
24652463 """Return copy of self after normalization is applied using source spectrum function."""
24662464 return self .copy ()
24672465
@@ -2554,7 +2552,7 @@ def _make_adjoint_sources(
25542552 "computation."
25552553 )
25562554
2557- def normalize (self , source_spectrum_fn ) -> FluxData :
2555+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> FluxData :
25582556 """Return copy of self after normalization is applied using source spectrum function."""
25592557 source_freq_amps = source_spectrum_fn (self .flux .f )
25602558 source_power = abs (source_freq_amps ) ** 2
@@ -3164,7 +3162,7 @@ def z(self) -> np.ndarray:
31643162 return self .Etheta .z .values
31653163
31663164 @property
3167- def tangential_dims (self ):
3165+ def tangential_dims (self ) -> list [ str ] :
31683166 tangential_dims = ["x" , "y" , "z" ]
31693167 tangential_dims .pop (self .monitor .proj_axis )
31703168 return tangential_dims
0 commit comments