Skip to content

Commit a8e97dc

Browse files
chore(tidy3d): FXC-4061-mypy-implement-type-defs-in-components-data
1 parent 2c6e3c8 commit a8e97dc

File tree

12 files changed

+174
-129
lines changed

12 files changed

+174
-129
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ python_files = "*.py"
326326
python_version = "3.10"
327327
files = [
328328
"tidy3d/web",
329+
"tidy3d/components/data",
329330
]
330331
ignore_missing_imports = true
331332
follow_imports = "skip"

tidy3d/components/data/data_array.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import xarray as xr
1515
from autograd.tracer import isbox
16+
from numpy._typing import NDArray
1617
from pydantic.annotated_handlers import GetCoreSchemaHandler
1718
from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
1819
from pydantic_core import core_schema
@@ -24,7 +25,13 @@
2425
from xarray.core.variable import as_variable
2526

2627
from tidy3d.compat import alignment
27-
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
28+
from tidy3d.components.autograd import (
29+
InterpolationType,
30+
TidyArrayBox,
31+
get_static,
32+
interpn,
33+
is_tidy_box,
34+
)
2835
from tidy3d.components.geometry.bound_ops import bounds_contains
2936
from tidy3d.components.types import Axis, Bound
3037
from tidy3d.constants import (
@@ -80,7 +87,7 @@ class DataArray(xr.DataArray):
8087
# stores a dictionary of attributes corresponding to the data values
8188
_data_attrs: dict[str, str] = {}
8289

83-
def __init__(self, data, *args: Any, **kwargs: Any) -> None:
90+
def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
8491
# if data is a vanilla autograd box, convert to our box
8592
if isbox(data) and not is_tidy_box(data):
8693
data = TidyArrayBox.from_arraybox(data)
@@ -217,7 +224,7 @@ def _interp_validator(self, field_name: Optional[str] = None) -> None:
217224
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
218225
)
219226

220-
def __eq__(self, other) -> bool:
227+
def __eq__(self, other: Any) -> bool:
221228
"""Whether two data array objects are equal."""
222229

223230
if not isinstance(other, xr.DataArray):
@@ -231,7 +238,7 @@ def __eq__(self, other) -> bool:
231238
return True
232239

233240
@property
234-
def values(self):
241+
def values(self) -> NDArray:
235242
"""
236243
The array's data converted to a numpy.ndarray.
237244
@@ -247,18 +254,18 @@ def values(self, value: Any) -> None:
247254
self.variable.values = value
248255

249256
@property
250-
def abs(self):
257+
def abs(self) -> Self:
251258
"""Absolute value of data array."""
252259
return abs(self)
253260

254261
@property
255-
def angle(self):
262+
def angle(self) -> Self:
256263
"""Angle or phase value of data array."""
257264
values = np.angle(self.values)
258265
return type(self)(values, coords=self.coords)
259266

260267
@property
261-
def is_uniform(self):
268+
def is_uniform(self) -> bool:
262269
"""Whether each element is of equal value in the data array"""
263270
raw_data = self.data.ravel()
264271
return np.allclose(raw_data, raw_data[0])
@@ -471,7 +478,12 @@ def _ag_interp(
471478
return self._from_temp_dataset(ds)
472479

473480
@staticmethod
474-
def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
481+
def _ag_interp_func(
482+
var: xr.Variable,
483+
indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]],
484+
method: InterpolationType,
485+
**kwargs: Any,
486+
) -> xr.Variable:
475487
"""
476488
Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.
477489
@@ -486,7 +498,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
486498
The variable to be interpolated.
487499
indexes_coords : dict
488500
A dictionary mapping dimension names to coordinate values for interpolation.
489-
method : str
501+
method : Literal["nearest", "linear"]
490502
The interpolation method to use.
491503
**kwargs : dict
492504
Additional keyword arguments to pass to the interpolation function.

tidy3d/components/data/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import xarray as xr
10+
from numpy._typing import ArrayLike
1011
from pydantic import Field
1112

1213
from tidy3d.components.base import Tidy3dBaseModel
@@ -83,7 +84,7 @@ def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArr
8384
"""How to package the dictionary of fields computed via self.colocate()."""
8485
return xr.Dataset(centered_fields)
8586

86-
def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
87+
def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset:
8788
"""Colocate all of the data at a set of x, y, z coordinates.
8889
8990
Parameters

tidy3d/components/data/monitor_data.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import struct
66
import warnings
77
from abc import ABC
8+
from collections.abc import Mapping
89
from math import isclose
910
from os import PathLike
10-
from typing import Any, Callable, Literal, Optional, Union, get_args
11+
from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args
1112

1213
import autograd.numpy as np
1314
import xarray as xr
15+
from numpy._typing import NDArray
1416
from pandas import DataFrame
1517
from pydantic import Field, model_validator
1618

@@ -103,6 +105,14 @@
103105
# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104106
COS_THETA_THRESH = 1e-5
105107

108+
GRID_CORRECTION_TYPE = Union[
109+
float,
110+
FreqDataArray,
111+
TimeDataArray,
112+
FreqModeDataArray,
113+
EMEFreqModeDataArray,
114+
]
115+
106116

107117
class MonitorData(AbstractMonitorData, ABC):
108118
"""
@@ -171,7 +181,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
171181
return "-" if direction == "+" else "+"
172182

173183
@staticmethod
174-
def get_amplitude(x) -> complex:
184+
def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
175185
"""Get the complex amplitude out of some data."""
176186

177187
if isinstance(x, DataArray):
@@ -213,7 +223,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
213223
)
214224

215225
@model_validator(mode="after")
216-
def warn_missing_grid_expanded(self):
226+
def warn_missing_grid_expanded(self) -> Self:
217227
"""If ``grid_expanded`` not provided and fields data is present, warn that some methods
218228
will break."""
219229
field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
@@ -226,15 +236,15 @@ def warn_missing_grid_expanded(self):
226236
)
227237
return self
228238

229-
_require_sym_center = required_if_symmetry_present("symmetry_center")
230-
_require_grid_expanded = required_if_symmetry_present("grid_expanded")
239+
_require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center")
240+
_require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded")
231241

232242
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
233243
"""Coordinates in the expanded grid corresponding to a given field component."""
234244
return self.grid_expanded[self.grid_locations[field_name]]
235245

236246
@property
237-
def symmetry_expanded(self):
247+
def symmetry_expanded(self) -> Self:
238248
"""Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
239249
any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
240250
data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -391,27 +401,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
391401
class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC):
392402
"""Collection of electromagnetic fields."""
393403

394-
grid_primal_correction: Union[
395-
float,
396-
FreqDataArray,
397-
TimeDataArray,
398-
FreqModeDataArray,
399-
EMEFreqModeDataArray,
400-
] = Field(
404+
grid_primal_correction: GRID_CORRECTION_TYPE = Field(
401405
1.0,
402406
title="Field correction factor",
403407
description="Correction factor that needs to be applied for data corresponding to a 2D "
404408
"monitor to take into account the finite grid in the normal direction in the simulation in "
405409
"which the data was computed. The factor is applied to fields defined on the primal grid "
406410
"locations along the normal direction.",
407411
)
408-
grid_dual_correction: Union[
409-
float,
410-
FreqDataArray,
411-
TimeDataArray,
412-
FreqModeDataArray,
413-
EMEFreqModeDataArray,
414-
] = Field(
412+
grid_dual_correction: GRID_CORRECTION_TYPE = Field(
415413
1.0,
416414
title="Field correction factor",
417415
description="Correction factor that needs to be applied for data corresponding to a 2D "
@@ -420,15 +418,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
420418
"locations along the normal direction.",
421419
)
422420

423-
def _expanded_grid_field_coords(self, field_name: str):
421+
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
424422
"""Coordinates in the expanded grid corresponding to a given field component."""
425423
if self.monitor.colocate:
426424
bounds_dict = self.grid_expanded.boundaries.to_dict
427425
return Coords(**{key: val[:-1] for key, val in bounds_dict.items()})
428426
return self.grid_expanded[self.grid_locations[field_name]]
429427

430428
@property
431-
def _grid_correction_dict(self):
429+
def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]:
432430
"""Return the primal and dual finite grid correction factors as a dictionary."""
433431
return {
434432
"grid_primal_correction": self.grid_primal_correction,
@@ -937,7 +935,7 @@ def outer_dot(
937935
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()
938936

939937
# function to apply at each pair of mode indices before integrating
940-
def fn(fields_1, fields_2):
938+
def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray:
941939
e_self_1 = fields_1[e_1]
942940
e_self_2 = fields_1[e_2]
943941
h_self_1 = fields_1[h_1]
@@ -978,7 +976,7 @@ def _outer_fn_summation(
978976
outer_dim_1: str,
979977
outer_dim_2: str,
980978
sum_dims: list[str],
981-
fn: Callable,
979+
fn: Callable[[dict[str, NDArray], NDArray], NDArray],
982980
) -> DataArray:
983981
"""
984982
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1676,7 +1674,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16761674
)
16771675

16781676
@model_validator(mode="after")
1679-
def eps_spec_match_mode_spec(self):
1677+
def eps_spec_match_mode_spec(self) -> Self:
16801678
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
16811679
if self.eps_spec:
16821680
mode_data_freqs = self.monitor.freqs
@@ -1686,7 +1684,7 @@ def eps_spec_match_mode_spec(self):
16861684
)
16871685
return self
16881686

1689-
def normalize(self, source_spectrum_fn) -> ModeData:
1687+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self:
16901688
"""Return copy of self after normalization is applied using source spectrum function."""
16911689
source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None]
16921690
new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype)
@@ -1808,7 +1806,7 @@ def overlap_sort(
18081806

18091807
return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False)
18101808

1811-
def _isel(self, **isel_kwargs: Any):
1809+
def _isel(self, **isel_kwargs: Any) -> Self:
18121810
"""Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
18131811
mode index. Used in ``overlap_sort`` but not officially supported since for example
18141812
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1822,7 +1820,7 @@ def _isel(self, **isel_kwargs: Any):
18221820
}
18231821
return self.updated_copy(**update_dict, deep=False, validate=False)
18241822

1825-
def _assign_coords(self, **assign_coords_kwargs: Any):
1823+
def _assign_coords(self, **assign_coords_kwargs: Optional[Mapping]) -> Self:
18261824
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18271825
mode index. Used in ``overlap_sort`` but not officially supported since for example
18281826
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -2243,7 +2241,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22432241

22442242
return src_adj
22452243

2246-
def _apply_mode_reorder(self, sort_inds_2d):
2244+
def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self:
22472245
"""Apply a mode reordering along mode_index for all frequency indices.
22482246
22492247
Parameters
@@ -2342,7 +2340,7 @@ def sort_modes(
23422340
sort_inds_2d = np.tile(identity, (num_freqs, 1))
23432341

23442342
# Helper to compute ordered indices within a subset
2345-
def _order_indices(indices, vals_all):
2343+
def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray:
23462344
if indices.size == 0:
23472345
return indices
23482346
vals = vals_all.isel(mode_index=indices)
@@ -2461,7 +2459,7 @@ class ModeSolverData(ModeData):
24612459
description="Unused for ModeSolverData.",
24622460
)
24632461

2464-
def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
2462+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData:
24652463
"""Return copy of self after normalization is applied using source spectrum function."""
24662464
return self.copy()
24672465

@@ -2554,7 +2552,7 @@ def _make_adjoint_sources(
25542552
"computation."
25552553
)
25562554

2557-
def normalize(self, source_spectrum_fn) -> FluxData:
2555+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData:
25582556
"""Return copy of self after normalization is applied using source spectrum function."""
25592557
source_freq_amps = source_spectrum_fn(self.flux.f)
25602558
source_power = abs(source_freq_amps) ** 2
@@ -3164,7 +3162,7 @@ def z(self) -> np.ndarray:
31643162
return self.Etheta.z.values
31653163

31663164
@property
3167-
def tangential_dims(self):
3165+
def tangential_dims(self) -> list[str]:
31683166
tangential_dims = ["x", "y", "z"]
31693167
tangential_dims.pop(self.monitor.proj_axis)
31703168
return tangential_dims

0 commit comments

Comments
 (0)