Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ files = [
"tidy3d/web",
"tidy3d/config",
"tidy3d/material_library",
"tidy3d/components/geometry"
"tidy3d/components/geometry",
"tidy3d/components/data",
]
ignore_missing_imports = true
follow_imports = "skip"
Expand Down
30 changes: 21 additions & 9 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import xarray as xr
from autograd.tracer import isbox
from numpy.typing import NDArray
from pydantic.annotated_handlers import GetCoreSchemaHandler
from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
from pydantic_core import core_schema
Expand All @@ -24,7 +25,13 @@
from xarray.core.variable import as_variable

from tidy3d.compat import alignment
from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
from tidy3d.components.autograd import (
InterpolationType,
TidyArrayBox,
get_static,
interpn,
is_tidy_box,
)
from tidy3d.components.geometry.bound_ops import bounds_contains
from tidy3d.components.types import Axis, Bound
from tidy3d.constants import (
Expand Down Expand Up @@ -80,7 +87,7 @@ class DataArray(xr.DataArray):
# stores a dictionary of attributes corresponding to the data values
_data_attrs: dict[str, str] = {}

def __init__(self, data, *args: Any, **kwargs: Any) -> None:
def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
# if data is a vanilla autograd box, convert to our box
if isbox(data) and not is_tidy_box(data):
data = TidyArrayBox.from_arraybox(data)
Expand Down Expand Up @@ -217,7 +224,7 @@ def _interp_validator(self, field_name: Optional[str] = None) -> None:
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
)

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Whether two data array objects are equal."""

if not isinstance(other, xr.DataArray):
Expand All @@ -231,7 +238,7 @@ def __eq__(self, other) -> bool:
return True

@property
def values(self):
def values(self) -> NDArray:
"""
The array's data converted to a numpy.ndarray.

Expand All @@ -247,18 +254,18 @@ def values(self, value: Any) -> None:
self.variable.values = value

@property
def abs(self):
def abs(self) -> Self:
"""Absolute value of data array."""
return abs(self)

@property
def angle(self):
def angle(self) -> Self:
"""Angle or phase value of data array."""
values = np.angle(self.values)
return type(self)(values, coords=self.coords)

@property
def is_uniform(self):
def is_uniform(self) -> bool:
"""Whether each element is of equal value in the data array"""
raw_data = self.data.ravel()
return np.allclose(raw_data, raw_data[0])
Expand Down Expand Up @@ -471,7 +478,12 @@ def _ag_interp(
return self._from_temp_dataset(ds)

@staticmethod
def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
def _ag_interp_func(
var: xr.Variable,
indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]],
method: InterpolationType,
**kwargs: Any,
) -> xr.Variable:
"""
Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.

Expand All @@ -486,7 +498,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
The variable to be interpolated.
indexes_coords : dict
A dictionary mapping dimension names to coordinate values for interpolation.
method : str
method : Literal["nearest", "linear"]
The interpolation method to use.
**kwargs : dict
Additional keyword arguments to pass to the interpolation function.
Expand Down
3 changes: 2 additions & 1 deletion tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import xarray as xr
from numpy.typing import ArrayLike
from pydantic import Field

from tidy3d.components.base import Tidy3dBaseModel
Expand Down Expand Up @@ -83,7 +84,7 @@ def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArr
"""How to package the dictionary of fields computed via self.colocate()."""
return xr.Dataset(centered_fields)

def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset:
"""Colocate all of the data at a set of x, y, z coordinates.

Parameters
Expand Down
64 changes: 30 additions & 34 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from abc import ABC
from math import isclose
from os import PathLike
from typing import Any, Callable, Literal, Optional, Union, get_args
from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args

import autograd.numpy as np
import xarray as xr
from numpy.typing import NDArray
from pandas import DataFrame
from pydantic import Field, model_validator

Expand Down Expand Up @@ -103,6 +104,14 @@
# Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
COS_THETA_THRESH = 1e-5

GRID_CORRECTION_TYPE = Union[
float,
FreqDataArray,
TimeDataArray,
FreqModeDataArray,
EMEFreqModeDataArray,
]


class MonitorData(AbstractMonitorData, ABC):
"""
Expand Down Expand Up @@ -171,7 +180,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
return "-" if direction == "+" else "+"

@staticmethod
def get_amplitude(x) -> complex:
def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
"""Get the complex amplitude out of some data."""

if isinstance(x, DataArray):
Expand Down Expand Up @@ -213,7 +222,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
)

@model_validator(mode="after")
def warn_missing_grid_expanded(self):
def warn_missing_grid_expanded(self) -> Self:
"""If ``grid_expanded`` not provided and fields data is present, warn that some methods
will break."""
field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
Expand All @@ -226,15 +235,15 @@ def warn_missing_grid_expanded(self):
)
return self

_require_sym_center = required_if_symmetry_present("symmetry_center")
_require_grid_expanded = required_if_symmetry_present("grid_expanded")
_require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center")
_require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded")

def _expanded_grid_field_coords(self, field_name: str) -> Coords:
"""Coordinates in the expanded grid corresponding to a given field component."""
return self.grid_expanded[self.grid_locations[field_name]]

@property
def symmetry_expanded(self):
def symmetry_expanded(self) -> Self:
"""Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the
data array. However, if symmetry is not expanded, the returned array contains a view of
Expand Down Expand Up @@ -391,27 +400,15 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC):
"""Collection of electromagnetic fields."""

grid_primal_correction: Union[
float,
FreqDataArray,
TimeDataArray,
FreqModeDataArray,
EMEFreqModeDataArray,
] = Field(
grid_primal_correction: GRID_CORRECTION_TYPE = Field(
1.0,
title="Field correction factor",
description="Correction factor that needs to be applied for data corresponding to a 2D "
"monitor to take into account the finite grid in the normal direction in the simulation in "
"which the data was computed. The factor is applied to fields defined on the primal grid "
"locations along the normal direction.",
)
grid_dual_correction: Union[
float,
FreqDataArray,
TimeDataArray,
FreqModeDataArray,
EMEFreqModeDataArray,
] = Field(
grid_dual_correction: GRID_CORRECTION_TYPE = Field(
1.0,
title="Field correction factor",
description="Correction factor that needs to be applied for data corresponding to a 2D "
Expand All @@ -420,15 +417,15 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
"locations along the normal direction.",
)

def _expanded_grid_field_coords(self, field_name: str):
def _expanded_grid_field_coords(self, field_name: str) -> Coords:
"""Coordinates in the expanded grid corresponding to a given field component."""
if self.monitor.colocate:
bounds_dict = self.grid_expanded.boundaries.to_dict
return Coords(**{key: val[:-1] for key, val in bounds_dict.items()})
return self.grid_expanded[self.grid_locations[field_name]]

@property
def _grid_correction_dict(self):
def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]:
"""Return the primal and dual finite grid correction factors as a dictionary."""
return {
"grid_primal_correction": self.grid_primal_correction,
Expand Down Expand Up @@ -937,7 +934,7 @@ def outer_dot(
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()

# function to apply at each pair of mode indices before integrating
def fn(fields_1, fields_2):
def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray:
e_self_1 = fields_1[e_1]
e_self_2 = fields_1[e_2]
h_self_1 = fields_1[h_1]
Expand Down Expand Up @@ -978,7 +975,7 @@ def _outer_fn_summation(
outer_dim_1: str,
outer_dim_2: str,
sum_dims: list[str],
fn: Callable,
fn: Callable[[dict[str, NDArray], NDArray], NDArray],
) -> DataArray:
"""
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
Expand Down Expand Up @@ -1676,7 +1673,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
)

@model_validator(mode="after")
def eps_spec_match_mode_spec(self):
def eps_spec_match_mode_spec(self) -> Self:
"""Raise validation error if frequencies in eps_spec does not match frequency list"""
if self.eps_spec:
mode_data_freqs = self.monitor.freqs
Expand All @@ -1686,7 +1683,7 @@ def eps_spec_match_mode_spec(self):
)
return self

def normalize(self, source_spectrum_fn) -> ModeData:
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self:
"""Return copy of self after normalization is applied using source spectrum function."""
source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None]
new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype)
Expand Down Expand Up @@ -1808,7 +1805,7 @@ def overlap_sort(

return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False)

def _isel(self, **isel_kwargs: Any):
def _isel(self, **isel_kwargs: Any) -> Self:
"""Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and
mode index. Used in ``overlap_sort`` but not officially supported since for example
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
Expand All @@ -1822,12 +1819,11 @@ def _isel(self, **isel_kwargs: Any):
}
return self.updated_copy(**update_dict, deep=False, validate=False)

def _assign_coords(self, **assign_coords_kwargs: Any):
def _assign_coords(self, **assign_coords_kwargs: Any) -> Self:
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
mode index. Used in ``overlap_sort`` but not officially supported since for example
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
newly created data."""

update_dict = dict(self._grid_correction_dict, **self.field_components)
update_dict = {
key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
Expand Down Expand Up @@ -2243,7 +2239,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:

return src_adj

def _apply_mode_reorder(self, sort_inds_2d):
def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self:
"""Apply a mode reordering along mode_index for all frequency indices.

Parameters
Expand Down Expand Up @@ -2342,7 +2338,7 @@ def sort_modes(
sort_inds_2d = np.tile(identity, (num_freqs, 1))

# Helper to compute ordered indices within a subset
def _order_indices(indices, vals_all):
def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray:
if indices.size == 0:
return indices
vals = vals_all.isel(mode_index=indices)
Expand Down Expand Up @@ -2461,7 +2457,7 @@ class ModeSolverData(ModeData):
description="Unused for ModeSolverData.",
)

def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData:
"""Return copy of self after normalization is applied using source spectrum function."""
return self.copy()

Expand Down Expand Up @@ -2554,7 +2550,7 @@ def _make_adjoint_sources(
"computation."
)

def normalize(self, source_spectrum_fn) -> FluxData:
def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData:
"""Return copy of self after normalization is applied using source spectrum function."""
source_freq_amps = source_spectrum_fn(self.flux.f)
source_power = abs(source_freq_amps) ** 2
Expand Down Expand Up @@ -3164,7 +3160,7 @@ def z(self) -> np.ndarray:
return self.Etheta.z.values

@property
def tangential_dims(self):
def tangential_dims(self) -> list[str]:
tangential_dims = ["x", "y", "z"]
tangential_dims.pop(self.monitor.proj_axis)
return tangential_dims
Expand Down
Loading