Skip to content

Commit f95897a

Browse files
chore(tidy3d): FXC-4061-mypy-implement-type-defs-in-components-data
1 parent f6aaf9b commit f95897a

File tree

12 files changed

+175
-130
lines changed

12 files changed

+175
-130
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ python_files = "*.py"
326326
python_version = "3.10"
327327
files = [
328328
"tidy3d/web",
329+
"tidy3d/components/data",
329330
]
330331
ignore_missing_imports = true
331332
follow_imports = "skip"

tidy3d/components/data/data_array.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import xarray as xr
1515
from autograd.tracer import isbox
16+
from numpy._typing import NDArray
1617
from pydantic.annotated_handlers import GetCoreSchemaHandler
1718
from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
1819
from pydantic_core import core_schema
@@ -24,7 +25,13 @@
2425
from xarray.core.variable import as_variable
2526

2627
from tidy3d.compat import alignment
27-
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
28+
from tidy3d.components.autograd import (
29+
InterpolationType,
30+
TidyArrayBox,
31+
get_static,
32+
interpn,
33+
is_tidy_box,
34+
)
2835
from tidy3d.components.geometry.bound_ops import bounds_contains
2936
from tidy3d.components.types import Axis, Bound
3037
from tidy3d.constants import (
@@ -80,7 +87,7 @@ class DataArray(xr.DataArray):
8087
# stores a dictionary of attributes corresponding to the data values
8188
_data_attrs: dict[str, str] = {}
8289

83-
def __init__(self, data, *args: Any, **kwargs: Any) -> None:
90+
def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
8491
# if data is a vanilla autograd box, convert to our box
8592
if isbox(data) and not is_tidy_box(data):
8693
data = TidyArrayBox.from_arraybox(data)
@@ -217,7 +224,7 @@ def _interp_validator(self, field_name: Optional[str] = None) -> None:
217224
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
218225
)
219226

220-
def __eq__(self, other) -> bool:
227+
def __eq__(self, other: Any) -> bool:
221228
"""Whether two data array objects are equal."""
222229

223230
if not isinstance(other, xr.DataArray):
@@ -231,7 +238,7 @@ def __eq__(self, other) -> bool:
231238
return True
232239

233240
@property
234-
def values(self):
241+
def values(self) -> NDArray:
235242
"""
236243
The array's data converted to a numpy.ndarray.
237244
@@ -247,18 +254,18 @@ def values(self, value: Any) -> None:
247254
self.variable.values = value
248255

249256
@property
250-
def abs(self):
257+
def abs(self) -> Self:
251258
"""Absolute value of data array."""
252259
return abs(self)
253260

254261
@property
255-
def angle(self):
262+
def angle(self) -> Self:
256263
"""Angle or phase value of data array."""
257264
values = np.angle(self.values)
258265
return type(self)(values, coords=self.coords)
259266

260267
@property
261-
def is_uniform(self):
268+
def is_uniform(self) -> bool:
262269
"""Whether each element is of equal value in the data array"""
263270
raw_data = self.data.ravel()
264271
return np.allclose(raw_data, raw_data[0])
@@ -471,7 +478,12 @@ def _ag_interp(
471478
return self._from_temp_dataset(ds)
472479

473480
@staticmethod
474-
def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
481+
def _ag_interp_func(
482+
var: xr.Variable,
483+
indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]],
484+
method: InterpolationType,
485+
**kwargs: Any,
486+
) -> xr.Variable:
475487
"""
476488
Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.
477489
@@ -486,7 +498,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
486498
The variable to be interpolated.
487499
indexes_coords : dict
488500
A dictionary mapping dimension names to coordinate values for interpolation.
489-
method : str
501+
method : Literal["nearest", "linear"]
490502
The interpolation method to use.
491503
**kwargs : dict
492504
Additional keyword arguments to pass to the interpolation function.

tidy3d/components/data/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import xarray as xr
10+
from numpy._typing import ArrayLike
1011
from pydantic import Field
1112

1213
from tidy3d.components.base import Tidy3dBaseModel
@@ -83,7 +84,7 @@ def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArr
8384
"""How to package the dictionary of fields computed via self.colocate()."""
8485
return xr.Dataset(centered_fields)
8586

86-
def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
87+
def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset:
8788
"""Colocate all of the data at a set of x, y, z coordinates.
8889
8990
Parameters

tidy3d/components/data/monitor_data.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import struct
66
import warnings
77
from abc import ABC
8+
from collections.abc import Mapping
89
from math import isclose
910
from os import PathLike
10-
from typing import Any, Callable, Literal, Optional, Union, get_args
11+
from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args
1112

1213
import autograd.numpy as np
1314
import xarray as xr
15+
from numpy._typing import NDArray
1416
from pandas import DataFrame
1517
from pydantic import Field, model_validator
1618

@@ -103,6 +105,14 @@
103105
# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
104106
COS_THETA_THRESH = 1e-5
105107

108+
GRID_CORRECTION_TYPE = Union[
109+
float,
110+
FreqDataArray,
111+
TimeDataArray,
112+
FreqModeDataArray,
113+
EMEFreqModeDataArray,
114+
]
115+
106116

107117
class MonitorData(AbstractMonitorData, ABC):
108118
"""
@@ -149,7 +159,7 @@ def amplitude_fn(freq: list[float]) -> complex:
149159

150160
return self.normalize(amplitude_fn)
151161

152-
def _updated(self, update: dict) -> MonitorData:
162+
def _updated(self, update: dict) -> Self:
153163
"""Similar to ``updated_copy``, but does not actually copy components, for speed.
154164
155165
Note
@@ -185,7 +195,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
185195
return "-" if direction == "+" else "+"
186196

187197
@staticmethod
188-
def get_amplitude(x) -> complex:
198+
def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
189199
"""Get the complex amplitude out of some data."""
190200

191201
if isinstance(x, DataArray):
@@ -227,7 +237,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
227237
)
228238

229239
@model_validator(mode="after")
230-
def warn_missing_grid_expanded(self):
240+
def warn_missing_grid_expanded(self) -> Self:
231241
"""If ``grid_expanded`` not provided and fields data is present, warn that some methods
232242
will break."""
233243
field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
@@ -240,15 +250,15 @@ def warn_missing_grid_expanded(self):
240250
)
241251
return self
242252

243-
_require_sym_center = required_if_symmetry_present("symmetry_center")
244-
_require_grid_expanded = required_if_symmetry_present("grid_expanded")
253+
_require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center")
254+
_require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded")
245255

246256
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
247257
"""Coordinates in the expanded grid corresponding to a given field component."""
248258
return self.grid_expanded[self.grid_locations[field_name]]
249259

250260
@property
251-
def symmetry_expanded(self):
261+
def symmetry_expanded(self) -> Self:
252262
"""Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
253263
any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
254264
data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -405,27 +415,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
405415
class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC):
406416
"""Collection of electromagnetic fields."""
407417

408-
grid_primal_correction: Union[
409-
float,
410-
FreqDataArray,
411-
TimeDataArray,
412-
FreqModeDataArray,
413-
EMEFreqModeDataArray,
414-
] = Field(
418+
grid_primal_correction: GRID_CORRECTION_TYPE = Field(
415419
1.0,
416420
title="Field correction factor",
417421
description="Correction factor that needs to be applied for data corresponding to a 2D "
418422
"monitor to take into account the finite grid in the normal direction in the simulation in "
419423
"which the data was computed. The factor is applied to fields defined on the primal grid "
420424
"locations along the normal direction.",
421425
)
422-
grid_dual_correction: Union[
423-
float,
424-
FreqDataArray,
425-
TimeDataArray,
426-
FreqModeDataArray,
427-
EMEFreqModeDataArray,
428-
] = Field(
426+
grid_dual_correction: GRID_CORRECTION_TYPE = Field(
429427
1.0,
430428
title="Field correction factor",
431429
description="Correction factor that needs to be applied for data corresponding to a 2D "
@@ -434,15 +432,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
434432
"locations along the normal direction.",
435433
)
436434

437-
def _expanded_grid_field_coords(self, field_name: str):
435+
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
438436
"""Coordinates in the expanded grid corresponding to a given field component."""
439437
if self.monitor.colocate:
440438
bounds_dict = self.grid_expanded.boundaries.to_dict
441439
return Coords(**{key: val[:-1] for key, val in bounds_dict.items()})
442440
return self.grid_expanded[self.grid_locations[field_name]]
443441

444442
@property
445-
def _grid_correction_dict(self):
443+
def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]:
446444
"""Return the primal and dual finite grid correction factors as a dictionary."""
447445
return {
448446
"grid_primal_correction": self.grid_primal_correction,
@@ -919,7 +917,7 @@ def outer_dot(
919917
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()
920918

921919
# function to apply at each pair of mode indices before integrating
922-
def fn(fields_1, fields_2):
920+
def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray:
923921
e_self_1 = fields_1[e_1]
924922
e_self_2 = fields_1[e_2]
925923
h_self_1 = fields_1[h_1]
@@ -960,7 +958,7 @@ def _outer_fn_summation(
960958
outer_dim_1: str,
961959
outer_dim_2: str,
962960
sum_dims: list[str],
963-
fn: Callable,
961+
fn: Callable[[dict[str, NDArray], NDArray], NDArray],
964962
) -> DataArray:
965963
"""
966964
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
@@ -1658,7 +1656,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
16581656
)
16591657

16601658
@model_validator(mode="after")
1661-
def eps_spec_match_mode_spec(self):
1659+
def eps_spec_match_mode_spec(self) -> Self:
16621660
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
16631661
if self.eps_spec:
16641662
mode_data_freqs = self.monitor.freqs
@@ -1668,7 +1666,7 @@ def eps_spec_match_mode_spec(self):
16681666
)
16691667
return self
16701668

1671-
def normalize(self, source_spectrum_fn) -> ModeData:
1669+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self:
16721670
"""Return copy of self after normalization is applied using source spectrum function."""
16731671
source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None]
16741672
new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype)
@@ -1789,7 +1787,7 @@ def overlap_sort(
17891787

17901788
return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False)
17911789

1792-
def _isel(self, **isel_kwargs: Any):
1790+
def _isel(self, **isel_kwargs: Any) -> Self:
17931791
"""Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
17941792
mode index. Used in ``overlap_sort`` but not officially supported since for example
17951793
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1803,7 +1801,7 @@ def _isel(self, **isel_kwargs: Any):
18031801
}
18041802
return self._updated(update=update_dict)
18051803

1806-
def _assign_coords(self, **assign_coords_kwargs: Any):
1804+
def _assign_coords(self, **assign_coords_kwargs: Optional[Mapping]) -> Self:
18071805
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
18081806
mode index. Used in ``overlap_sort`` but not officially supported since for example
18091807
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -2210,7 +2208,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
22102208

22112209
return src_adj
22122210

2213-
def _apply_mode_reorder(self, sort_inds_2d):
2211+
def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self:
22142212
"""Apply a mode reordering along mode_index for all frequency indices.
22152213
22162214
Parameters
@@ -2273,7 +2271,7 @@ def sort_modes(
22732271
sort_inds_2d = np.tile(identity, (num_freqs, 1))
22742272

22752273
# Helper to compute ordered indices within a subset
2276-
def _order_indices(indices, vals_all):
2274+
def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray:
22772275
if indices.size == 0:
22782276
return indices
22792277
vals = vals_all.isel(mode_index=indices)
@@ -2392,7 +2390,7 @@ class ModeSolverData(ModeData):
23922390
description="Unused for ModeSolverData.",
23932391
)
23942392

2395-
def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
2393+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData:
23962394
"""Return copy of self after normalization is applied using source spectrum function."""
23972395
return self.copy()
23982396

@@ -2485,7 +2483,7 @@ def _make_adjoint_sources(
24852483
"computation."
24862484
)
24872485

2488-
def normalize(self, source_spectrum_fn) -> FluxData:
2486+
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData:
24892487
"""Return copy of self after normalization is applied using source spectrum function."""
24902488
source_freq_amps = source_spectrum_fn(self.flux.f)
24912489
source_power = abs(source_freq_amps) ** 2
@@ -3095,7 +3093,7 @@ def z(self) -> np.ndarray:
30953093
return self.Etheta.z.values
30963094

30973095
@property
3098-
def tangential_dims(self):
3096+
def tangential_dims(self) -> list[str]:
30993097
tangential_dims = ["x", "y", "z"]
31003098
tangential_dims.pop(self.monitor.proj_axis)
31013099
return tangential_dims

0 commit comments

Comments
 (0)