55import struct
66import warnings
77from abc import ABC
8+ from collections .abc import Mapping
89from math import isclose
910from os import PathLike
10- from typing import Any , Callable , Literal , Optional , Union , get_args
11+ from typing import Any , Callable , Literal , Optional , Self , SupportsComplex , Union , get_args
1112
1213import autograd .numpy as np
1314import xarray as xr
15+ from numpy ._typing import NDArray
1416from pandas import DataFrame
1517from pydantic import Field , model_validator
1618
103105# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104106COS_THETA_THRESH = 1e-5
105107
108+ GRID_CORRECTION_TYPE = Union [
109+ float ,
110+ FreqDataArray ,
111+ TimeDataArray ,
112+ FreqModeDataArray ,
113+ EMEFreqModeDataArray ,
114+ ]
115+
106116
107117class MonitorData (AbstractMonitorData , ABC ):
108118 """
@@ -149,7 +159,7 @@ def amplitude_fn(freq: list[float]) -> complex:
149159
150160 return self .normalize (amplitude_fn )
151161
152- def _updated (self , update : dict ) -> MonitorData :
162+ def _updated (self , update : dict ) -> Self :
153163 """Similar to ``updated_copy``, but does not actually copy components, for speed.
154164
155165 Note
@@ -185,7 +195,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
185195 return "-" if direction == "+" else "+"
186196
187197 @staticmethod
188- def get_amplitude (x ) -> complex :
198+ def get_amplitude (x : Union [ DataArray , SupportsComplex ] ) -> complex :
189199 """Get the complex amplitude out of some data."""
190200
191201 if isinstance (x , DataArray ):
@@ -227,7 +237,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
227237 )
228238
229239 @model_validator (mode = "after" )
230- def warn_missing_grid_expanded (self ):
240+ def warn_missing_grid_expanded (self ) -> Self :
231241 """If ``grid_expanded`` not provided and fields data is present, warn that some methods
232242 will break."""
233243 field_comps = ["Ex" , "Ey" , "Ez" , "Hx" , "Hy" , "Hz" ]
@@ -240,15 +250,15 @@ def warn_missing_grid_expanded(self):
240250 )
241251 return self
242252
243- _require_sym_center = required_if_symmetry_present ("symmetry_center" )
244- _require_grid_expanded = required_if_symmetry_present ("grid_expanded" )
253+ _require_sym_center : Callable [[ Any ], Any ] = required_if_symmetry_present ("symmetry_center" )
254+ _require_grid_expanded : Callable [[ Any ], Any ] = required_if_symmetry_present ("grid_expanded" )
245255
246256 def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
247257 """Coordinates in the expanded grid corresponding to a given field component."""
248258 return self .grid_expanded [self .grid_locations [field_name ]]
249259
250260 @property
251- def symmetry_expanded (self ):
261+ def symmetry_expanded (self ) -> Self :
252262 """Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
253263 any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
254264 data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -405,27 +415,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
405415class ElectromagneticFieldData (AbstractFieldData , ElectromagneticFieldDataset , ABC ):
406416 """Collection of electromagnetic fields."""
407417
408- grid_primal_correction : Union [
409- float ,
410- FreqDataArray ,
411- TimeDataArray ,
412- FreqModeDataArray ,
413- EMEFreqModeDataArray ,
414- ] = Field (
418+ grid_primal_correction : GRID_CORRECTION_TYPE = Field (
415419 1.0 ,
416420 title = "Field correction factor" ,
417421 description = "Correction factor that needs to be applied for data corresponding to a 2D "
418422 "monitor to take into account the finite grid in the normal direction in the simulation in "
419423 "which the data was computed. The factor is applied to fields defined on the primal grid "
420424 "locations along the normal direction." ,
421425 )
422- grid_dual_correction : Union [
423- float ,
424- FreqDataArray ,
425- TimeDataArray ,
426- FreqModeDataArray ,
427- EMEFreqModeDataArray ,
428- ] = Field (
426+ grid_dual_correction : GRID_CORRECTION_TYPE = Field (
429427 1.0 ,
430428 title = "Field correction factor" ,
431429 description = "Correction factor that needs to be applied for data corresponding to a 2D "
@@ -434,15 +432,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
434432 "locations along the normal direction." ,
435433 )
436434
437- def _expanded_grid_field_coords (self , field_name : str ):
435+ def _expanded_grid_field_coords (self , field_name : str ) -> Coords :
438436 """Coordinates in the expanded grid corresponding to a given field component."""
439437 if self .monitor .colocate :
440438 bounds_dict = self .grid_expanded .boundaries .to_dict
441439 return Coords (** {key : val [:- 1 ] for key , val in bounds_dict .items ()})
442440 return self .grid_expanded [self .grid_locations [field_name ]]
443441
444442 @property
445- def _grid_correction_dict (self ):
443+ def _grid_correction_dict (self ) -> dict [ str , GRID_CORRECTION_TYPE ] :
446444 """Return the primal and dual finite grid correction factors as a dictionary."""
447445 return {
448446 "grid_primal_correction" : self .grid_primal_correction ,
@@ -919,7 +917,7 @@ def outer_dot(
919917 d_area = self ._diff_area .expand_dims (dim = {"f" : f }, axis = 2 ).to_numpy ()
920918
921919 # function to apply at each pair of mode indices before integrating
922- def fn (fields_1 , fields_2 ) :
920+ def fn (fields_1 : dict [ str , NDArray ], fields_2 : dict [ str , NDArray ]) -> NDArray :
923921 e_self_1 = fields_1 [e_1 ]
924922 e_self_2 = fields_1 [e_2 ]
925923 h_self_1 = fields_1 [h_1 ]
@@ -960,7 +958,7 @@ def _outer_fn_summation(
960958 outer_dim_1 : str ,
961959 outer_dim_2 : str ,
962960 sum_dims : list [str ],
963- fn : Callable ,
961+ fn : Callable [[ dict [ str , NDArray ], NDArray ], NDArray ] ,
964962 ) -> DataArray :
965963 """
966964 Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1658,7 +1656,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16581656 )
16591657
16601658 @model_validator (mode = "after" )
1661- def eps_spec_match_mode_spec (self ):
1659+ def eps_spec_match_mode_spec (self ) -> Self :
16621660 """Raise validation error if frequencies in eps_spec does not match frequency list"""
16631661 if self .eps_spec :
16641662 mode_data_freqs = self .monitor .freqs
@@ -1668,7 +1666,7 @@ def eps_spec_match_mode_spec(self):
16681666 )
16691667 return self
16701668
1671- def normalize (self , source_spectrum_fn ) -> ModeData :
1669+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> Self :
16721670 """Return copy of self after normalization is applied using source spectrum function."""
16731671 source_freq_amps = source_spectrum_fn (self .amps .f )[None , :, None ]
16741672 new_amps = (self .amps / source_freq_amps ).astype (self .amps .dtype )
@@ -1789,7 +1787,7 @@ def overlap_sort(
17891787
17901788 return data_reordered .updated_copy (monitor = monitor_updated , deep = False , validate = False )
17911789
1792- def _isel (self , ** isel_kwargs : Any ):
1790+ def _isel (self , ** isel_kwargs : Any ) -> Self :
17931791 """Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
17941792 mode index. Used in ``overlap_sort`` but not officially supported since for example
17951793 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1803,7 +1801,7 @@ def _isel(self, **isel_kwargs: Any):
18031801 }
18041802 return self ._updated (update = update_dict )
18051803
1806- def _assign_coords (self , ** assign_coords_kwargs : Any ) :
1804+ def _assign_coords (self , ** assign_coords_kwargs : Optional [ Mapping ]) -> Self :
18071805 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18081806 mode index. Used in ``overlap_sort`` but not officially supported since for example
18091807 ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -2210,7 +2208,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22102208
22112209 return src_adj
22122210
2213- def _apply_mode_reorder (self , sort_inds_2d ) :
2211+ def _apply_mode_reorder (self , sort_inds_2d : NDArray ) -> Self :
22142212 """Apply a mode reordering along mode_index for all frequency indices.
22152213
22162214 Parameters
@@ -2273,7 +2271,7 @@ def sort_modes(
22732271 sort_inds_2d = np .tile (identity , (num_freqs , 1 ))
22742272
22752273 # Helper to compute ordered indices within a subset
2276- def _order_indices (indices , vals_all ) :
2274+ def _order_indices (indices : NDArray , vals_all : DataArray ) -> NDArray :
22772275 if indices .size == 0 :
22782276 return indices
22792277 vals = vals_all .isel (mode_index = indices )
@@ -2392,7 +2390,7 @@ class ModeSolverData(ModeData):
23922390 description = "Unused for ModeSolverData." ,
23932391 )
23942392
2395- def normalize (self , source_spectrum_fn : Callable [[float ], complex ]) -> ModeSolverData :
2393+ def normalize (self , source_spectrum_fn : Callable [[DataArray ], NDArray ]) -> ModeSolverData :
23962394 """Return copy of self after normalization is applied using source spectrum function."""
23972395 return self .copy ()
23982396
@@ -2485,7 +2483,7 @@ def _make_adjoint_sources(
24852483 "computation."
24862484 )
24872485
2488- def normalize (self , source_spectrum_fn ) -> FluxData :
2486+ def normalize (self , source_spectrum_fn : Callable [[ DataArray ], NDArray ] ) -> FluxData :
24892487 """Return copy of self after normalization is applied using source spectrum function."""
24902488 source_freq_amps = source_spectrum_fn (self .flux .f )
24912489 source_power = abs (source_freq_amps ) ** 2
@@ -3095,7 +3093,7 @@ def z(self) -> np.ndarray:
30953093 return self .Etheta .z .values
30963094
30973095 @property
3098- def tangential_dims (self ):
3096+ def tangential_dims (self ) -> list [ str ] :
30993097 tangential_dims = ["x" , "y" , "z" ]
31003098 tangential_dims .pop (self .monitor .proj_axis )
31013099 return tangential_dims
0 commit comments