Skip to content

Commit 2c11697

Browse files
chore(tidy3d): FXC-4061-mypy-implement-type-defs-in-components-data
1 parent 383ef12 commit 2c11697

File tree

15 files changed

+189
-143
lines changed

15 files changed

+189
-143
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ files = [
328328
"tidy3d/web",
329329
"tidy3d/config",
330330
"tidy3d/material_library",
331-
"tidy3d/components/geometry"
331+
"tidy3d/components/geometry",
332+
"tidy3d/components/data",
332333
]
333334
ignore_missing_imports = true
334335
follow_imports = "skip"

tidy3d/components/data/data_array.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import xarray as xr
1515
from autograd.tracer import isbox
16+
from numpy.typing import NDArray
1617
from pydantic.annotated_handlers import GetCoreSchemaHandler
1718
from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
1819
from pydantic_core import core_schema
@@ -24,7 +25,13 @@
2425
from xarray.core.variable import as_variable
2526

2627
from tidy3d.compat import alignment
27-
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
28+
from tidy3d.components.autograd import (
29+
InterpolationType,
30+
TidyArrayBox,
31+
get_static,
32+
interpn,
33+
is_tidy_box,
34+
)
2835
from tidy3d.components.geometry.bound_ops import bounds_contains
2936
from tidy3d.components.types import Axis, Bound
3037
from tidy3d.constants import (
@@ -80,7 +87,7 @@ class DataArray(xr.DataArray):
8087
# stores a dictionary of attributes corresponding to the data values
8188
_data_attrs: dict[str, str] = {}
8289

83-
def __init__(self, data, *args: Any, **kwargs: Any) -> None:
90+
def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
8491
# if data is a vanilla autograd box, convert to our box
8592
if isbox(data) and not is_tidy_box(data):
8693
data = TidyArrayBox.from_arraybox(data)
@@ -217,7 +224,7 @@ def _interp_validator(self, field_name: Optional[str] = None) -> None:
217224
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
218225
)
219226

220-
def __eq__(self, other) -> bool:
227+
def __eq__(self, other: Any) -> bool:
221228
"""Whether two data array objects are equal."""
222229

223230
if not isinstance(other, xr.DataArray):
@@ -231,7 +238,7 @@ def __eq__(self, other) -> bool:
231238
return True
232239

233240
@property
234-
def values(self):
241+
def values(self) -> NDArray:
235242
"""
236243
The array's data converted to a numpy.ndarray.
237244
@@ -247,18 +254,18 @@ def values(self, value: Any) -> None:
247254
self.variable.values = value
248255

249256
@property
250-
def abs(self):
257+
def abs(self) -> Self:
251258
"""Absolute value of data array."""
252259
return abs(self)
253260

254261
@property
255-
def angle(self):
262+
def angle(self) -> Self:
256263
"""Angle or phase value of data array."""
257264
values = np.angle(self.values)
258265
return type(self)(values, coords=self.coords)
259266

260267
@property
261-
def is_uniform(self):
268+
def is_uniform(self) -> bool:
262269
"""Whether each element is of equal value in the data array"""
263270
raw_data = self.data.ravel()
264271
return np.allclose(raw_data, raw_data[0])
@@ -471,7 +478,12 @@ def _ag_interp(
471478
return self._from_temp_dataset(ds)
472479

473480
@staticmethod
474-
def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
481+
def _ag_interp_func(
482+
var: xr.Variable,
483+
indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]],
484+
method: InterpolationType,
485+
**kwargs: Any,
486+
) -> xr.Variable:
475487
"""
476488
Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.
477489
@@ -486,7 +498,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
486498
The variable to be interpolated.
487499
indexes_coords : dict
488500
A dictionary mapping dimension names to coordinate values for interpolation.
489-
method : str
501+
method : Literal["nearest", "linear"]
490502
The interpolation method to use.
491503
**kwargs : dict
492504
Additional keyword arguments to pass to the interpolation function.

tidy3d/components/data/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import xarray as xr
10+
from numpy.typing import ArrayLike
1011
from pydantic import Field
1112

1213
from tidy3d.components.base import Tidy3dBaseModel
@@ -83,7 +84,7 @@ def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArr
8384
"""How to package the dictionary of fields computed via self.colocate()."""
8485
return xr.Dataset(centered_fields)
8586

86-
def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
87+
def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset:
8788
"""Colocate all of the data at a set of x, y, z coordinates.
8889
8990
Parameters

tidy3d/components/data/monitor_data.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from abc import ABC
88
from math import isclose
99
from os import PathLike
10-
from typing import Any, Callable, Literal, Optional, Union, get_args
10+
from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args
1111

1212
import autograd.numpy as np
1313
import xarray as xr
14+
from numpy.typing import NDArray
1415
from pandas import DataFrame
1516
from pydantic import Field, model_validator
1617

@@ -103,6 +104,14 @@
103104
# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104105
COS_THETA_THRESH = 1e-5
105106

107+
GRID_CORRECTION_TYPE = Union[
108+
float,
109+
FreqDataArray,
110+
TimeDataArray,
111+
FreqModeDataArray,
112+
EMEFreqModeDataArray,
113+
]
114+
106115

107116
class MonitorData(AbstractMonitorData, ABC):
108117
"""
@@ -171,7 +180,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
171180
return "-" if direction == "+" else "+"
172181

173182
@staticmethod
174-
def get_amplitude(x) -> complex:
183+
def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
175184
"""Get the complex amplitude out of some data."""
176185

177186
if isinstance(x, DataArray):
@@ -213,7 +222,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
213222
)
214223

215224
@model_validator(mode="after")
216-
def warn_missing_grid_expanded(self):
225+
def warn_missing_grid_expanded(self) -> Self:
217226
"""If ``grid_expanded`` not provided and fields data is present, warn that some methods
218227
will break."""
219228
field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
@@ -226,15 +235,15 @@ def warn_missing_grid_expanded(self):
226235
)
227236
return self
228237

229-
_require_sym_center = required_if_symmetry_present("symmetry_center")
230-
_require_grid_expanded = required_if_symmetry_present("grid_expanded")
238+
_require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center")
239+
_require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded")
231240

232241
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
233242
"""Coordinates in the expanded grid corresponding to a given field component."""
234243
return self.grid_expanded[self.grid_locations[field_name]]
235244

236245
@property
237-
def symmetry_expanded(self):
246+
def symmetry_expanded(self) -> Self:
238247
"""Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
239248
any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
240249
data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -391,27 +400,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
391400
class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC):
392401
"""Collection of electromagnetic fields."""
393402

394-
grid_primal_correction: Union[
395-
float,
396-
FreqDataArray,
397-
TimeDataArray,
398-
FreqModeDataArray,
399-
EMEFreqModeDataArray,
400-
] = Field(
403+
grid_primal_correction: GRID_CORRECTION_TYPE = Field(
401404
1.0,
402405
title="Field correction factor",
403406
description="Correction factor that needs to be applied for data corresponding to a 2D "
404407
"monitor to take into account the finite grid in the normal direction in the simulation in "
405408
"which the data was computed. The factor is applied to fields defined on the primal grid "
406409
"locations along the normal direction.",
407410
)
408-
grid_dual_correction: Union[
409-
float,
410-
FreqDataArray,
411-
TimeDataArray,
412-
FreqModeDataArray,
413-
EMEFreqModeDataArray,
414-
] = Field(
411+
grid_dual_correction: GRID_CORRECTION_TYPE = Field(
415412
1.0,
416413
title="Field correction factor",
417414
description="Correction factor that needs to be applied for data corresponding to a 2D "
@@ -420,15 +417,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
420417
"locations along the normal direction.",
421418
)
422419

423-
def _expanded_grid_field_coords(self, field_name: str):
420+
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
424421
"""Coordinates in the expanded grid corresponding to a given field component."""
425422
if self.monitor.colocate:
426423
bounds_dict = self.grid_expanded.boundaries.to_dict
427424
return Coords(**{key: val[:-1] for key, val in bounds_dict.items()})
428425
return self.grid_expanded[self.grid_locations[field_name]]
429426

430427
@property
431-
def _grid_correction_dict(self):
428+
def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]:
432429
"""Return the primal and dual finite grid correction factors as a dictionary."""
433430
return {
434431
"grid_primal_correction": self.grid_primal_correction,
@@ -937,7 +934,7 @@ def outer_dot(
937934
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()
938935

939936
# function to apply at each pair of mode indices before integrating
940-
def fn(fields_1, fields_2):
937+
def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray:
941938
e_self_1 = fields_1[e_1]
942939
e_self_2 = fields_1[e_2]
943940
h_self_1 = fields_1[h_1]
@@ -978,7 +975,7 @@ def _outer_fn_summation(
978975
outer_dim_1: str,
979976
outer_dim_2: str,
980977
sum_dims: list[str],
981-
fn: Callable,
978+
fn: Callable[[dict[str, NDArray], NDArray], NDArray],
982979
) -> DataArray:
983980
"""
984981
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1676,7 +1673,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16761673
)
16771674

16781675
@model_validator(mode="after")
1679-
def eps_spec_match_mode_spec(self):
1676+
def eps_spec_match_mode_spec(self) -> Self:
16801677
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
16811678
if self.eps_spec:
16821679
mode_data_freqs = self.monitor.freqs
@@ -1686,7 +1683,7 @@ def eps_spec_match_mode_spec(self):
16861683
)
16871684
return self
16881685

1689-
def normalize(self, source_spectrum_fn) -> ModeData:
1686+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self:
16901687
"""Return copy of self after normalization is applied using source spectrum function."""
16911688
source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None]
16921689
new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype)
@@ -1808,7 +1805,7 @@ def overlap_sort(
18081805

18091806
return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False)
18101807

1811-
def _isel(self, **isel_kwargs: Any):
1808+
def _isel(self, **isel_kwargs: Any) -> Self:
18121809
"""Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
18131810
mode index. Used in ``overlap_sort`` but not officially supported since for example
18141811
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1822,12 +1819,11 @@ def _isel(self, **isel_kwargs: Any):
18221819
}
18231820
return self.updated_copy(**update_dict, deep=False, validate=False)
18241821

1825-
def _assign_coords(self, **assign_coords_kwargs: Any):
1822+
def _assign_coords(self, **assign_coords_kwargs: Any) -> Self:
18261823
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18271824
mode index. Used in ``overlap_sort`` but not officially supported since for example
18281825
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
18291826
newly created data."""
1830-
18311827
update_dict = dict(self._grid_correction_dict, **self.field_components)
18321828
update_dict = {
18331829
key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
@@ -2243,7 +2239,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22432239

22442240
return src_adj
22452241

2246-
def _apply_mode_reorder(self, sort_inds_2d):
2242+
def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self:
22472243
"""Apply a mode reordering along mode_index for all frequency indices.
22482244
22492245
Parameters
@@ -2342,7 +2338,7 @@ def sort_modes(
23422338
sort_inds_2d = np.tile(identity, (num_freqs, 1))
23432339

23442340
# Helper to compute ordered indices within a subset
2345-
def _order_indices(indices, vals_all):
2341+
def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray:
23462342
if indices.size == 0:
23472343
return indices
23482344
vals = vals_all.isel(mode_index=indices)
@@ -2461,7 +2457,7 @@ class ModeSolverData(ModeData):
24612457
description="Unused for ModeSolverData.",
24622458
)
24632459

2464-
def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
2460+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData:
24652461
"""Return copy of self after normalization is applied using source spectrum function."""
24662462
return self.copy()
24672463

@@ -2554,7 +2550,7 @@ def _make_adjoint_sources(
25542550
"computation."
25552551
)
25562552

2557-
def normalize(self, source_spectrum_fn) -> FluxData:
2553+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData:
25582554
"""Return copy of self after normalization is applied using source spectrum function."""
25592555
source_freq_amps = source_spectrum_fn(self.flux.f)
25602556
source_power = abs(source_freq_amps) ** 2
@@ -3164,7 +3160,7 @@ def z(self) -> np.ndarray:
31643160
return self.Etheta.z.values
31653161

31663162
@property
3167-
def tangential_dims(self):
3163+
def tangential_dims(self) -> list[str]:
31683164
tangential_dims = ["x", "y", "z"]
31693165
tangential_dims.pop(self.monitor.proj_axis)
31703166
return tangential_dims

0 commit comments

Comments
 (0)