diff --git a/pyproject.toml b/pyproject.toml index 62720ea126..d2e188c1d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -328,7 +328,8 @@ files = [ "tidy3d/web", "tidy3d/config", "tidy3d/material_library", - "tidy3d/components/geometry" + "tidy3d/components/geometry", + "tidy3d/components/data", ] ignore_missing_imports = true follow_imports = "skip" diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index b8d1204212..858ee33e28 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -13,6 +13,7 @@ import numpy as np import xarray as xr from autograd.tracer import isbox +from numpy.typing import NDArray from pydantic.annotated_handlers import GetCoreSchemaHandler from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue from pydantic_core import core_schema @@ -24,7 +25,13 @@ from xarray.core.variable import as_variable from tidy3d.compat import alignment -from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box +from tidy3d.components.autograd import ( + InterpolationType, + TidyArrayBox, + get_static, + interpn, + is_tidy_box, +) from tidy3d.components.geometry.bound_ops import bounds_contains from tidy3d.components.types import Axis, Bound from tidy3d.constants import ( @@ -80,7 +87,7 @@ class DataArray(xr.DataArray): # stores a dictionary of attributes corresponding to the data values _data_attrs: dict[str, str] = {} - def __init__(self, data, *args: Any, **kwargs: Any) -> None: + def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None: # if data is a vanilla autograd box, convert to our box if isbox(data) and not is_tidy_box(data): data = TidyArrayBox.from_arraybox(data) @@ -217,7 +224,7 @@ def _interp_validator(self, field_name: Optional[str] = None) -> None: f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'." ) - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: """Whether two data array objects are equal.""" if not isinstance(other, xr.DataArray): @@ -231,7 +238,7 @@ def __eq__(self, other) -> bool: return True @property - def values(self): + def values(self) -> NDArray: """ The array's data converted to a numpy.ndarray. @@ -247,18 +254,18 @@ def values(self, value: Any) -> None: self.variable.values = value @property - def abs(self): + def abs(self) -> Self: """Absolute value of data array.""" return abs(self) @property - def angle(self): + def angle(self) -> Self: """Angle or phase value of data array.""" values = np.angle(self.values) return type(self)(values, coords=self.coords) @property - def is_uniform(self): + def is_uniform(self) -> bool: """Whether each element is of equal value in the data array""" raw_data = self.data.ravel() return np.allclose(raw_data, raw_data[0]) @@ -471,7 +478,12 @@ def _ag_interp( return self._from_temp_dataset(ds) @staticmethod - def _ag_interp_func(var, indexes_coords, method, **kwargs: Any): + def _ag_interp_func( + var: xr.Variable, + indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]], + method: InterpolationType, + **kwargs: Any, + ) -> xr.Variable: """ Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`. @@ -486,7 +498,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs: Any): The variable to be interpolated. indexes_coords : dict A dictionary mapping dimension names to coordinate values for interpolation. - method : str + method : Literal["nearest", "linear"] The interpolation method to use. 
        **kwargs : dict
            Additional keyword arguments to pass to the interpolation function.
diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py
index a5e65c998e..716497d7b0 100644
--- a/tidy3d/components/data/dataset.py
+++ b/tidy3d/components/data/dataset.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import xarray as xr
+from numpy.typing import ArrayLike
 from pydantic import Field
 
 from tidy3d.components.base import Tidy3dBaseModel
@@ -83,7 +84,12 @@ def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArr
         """How to package the dictionary of fields computed via self.colocate()."""
         return xr.Dataset(centered_fields)
 
-    def colocate(self, x=None, y=None, z=None) -> xr.Dataset:
+    def colocate(
+        self,
+        x: Optional[ArrayLike] = None,
+        y: Optional[ArrayLike] = None,
+        z: Optional[ArrayLike] = None,
+    ) -> xr.Dataset:
         """Colocate all of the data at a set of x, y, z coordinates.
 
         Parameters
diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py
index cf22f9faf4..833bb9eb20 100644
--- a/tidy3d/components/data/monitor_data.py
+++ b/tidy3d/components/data/monitor_data.py
@@ -7,10 +7,11 @@
 
 from abc import ABC
 from math import isclose
 from os import PathLike
-from typing import Any, Callable, Literal, Optional, Union, get_args
+from typing import Any, Callable, Literal, Optional, Self, SupportsComplex, Union, get_args
 
 import autograd.numpy as np
 import xarray as xr
+from numpy.typing import NDArray
 from pandas import DataFrame
 from pydantic import Field, model_validator
@@ -103,6 +104,14 @@
 # Threshold for cos(theta) to avoid unphysically large amplitudes near grazing angles
 COS_THETA_THRESH = 1e-5
 
+GRID_CORRECTION_TYPE = Union[
+    float,
+    FreqDataArray,
+    TimeDataArray,
+    FreqModeDataArray,
+    EMEFreqModeDataArray,
+]
+
 
 class MonitorData(AbstractMonitorData, ABC):
     """
@@ -171,7 +180,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str:
         return "-" if direction == "+" else "+"
 
     @staticmethod
-    def get_amplitude(x) -> complex:
+    def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
         """Get the complex amplitude out of some data."""
 
         if isinstance(x, DataArray):
@@ -213,7 +222,7 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC):
     )
 
     @model_validator(mode="after")
-    def warn_missing_grid_expanded(self):
+    def warn_missing_grid_expanded(self) -> Self:
         """If ``grid_expanded`` not provided and fields data is present, warn that some methods
         will break."""
         field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
@@ -226,15 +235,15 @@ def warn_missing_grid_expanded(self):
             )
         return self
 
-    _require_sym_center = required_if_symmetry_present("symmetry_center")
-    _require_grid_expanded = required_if_symmetry_present("grid_expanded")
+    _require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center")
+    _require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded")
 
     def _expanded_grid_field_coords(self, field_name: str) -> Coords:
         """Coordinates in the expanded grid corresponding to a given field component."""
         return self.grid_expanded[self.grid_locations[field_name]]
 
     @property
-    def symmetry_expanded(self):
+    def symmetry_expanded(self) -> Self:
         """Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If
         any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of
         the data array. However, if symmetry is not expanded, the returned array contains a view of
@@ -391,13 +400,7 @@ def at_coords(self, coords: Coords) -> xr.Dataset:
 class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC):
     """Collection of electromagnetic fields."""
 
-    grid_primal_correction: Union[
-        float,
-        FreqDataArray,
-        TimeDataArray,
-        FreqModeDataArray,
-        EMEFreqModeDataArray,
-    ] = Field(
+    grid_primal_correction: GRID_CORRECTION_TYPE = Field(
         1.0,
         title="Field correction factor",
         description="Correction factor that needs to be applied for data corresponding to a 2D "
@@ -405,13 +408,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
         "which the data was computed. The factor is applied to fields defined on the primal grid "
         "locations along the normal direction.",
     )
-    grid_dual_correction: Union[
-        float,
-        FreqDataArray,
-        TimeDataArray,
-        FreqModeDataArray,
-        EMEFreqModeDataArray,
-    ] = Field(
+    grid_dual_correction: GRID_CORRECTION_TYPE = Field(
         1.0,
         title="Field correction factor",
         description="Correction factor that needs to be applied for data corresponding to a 2D "
@@ -420,7 +417,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A
         "locations along the normal direction.",
     )
 
-    def _expanded_grid_field_coords(self, field_name: str):
+    def _expanded_grid_field_coords(self, field_name: str) -> Coords:
         """Coordinates in the expanded grid corresponding to a given field component."""
         if self.monitor.colocate:
             bounds_dict = self.grid_expanded.boundaries.to_dict
@@ -428,7 +425,7 @@ def _expanded_grid_field_coords(self, field_name: str):
         return self.grid_expanded[self.grid_locations[field_name]]
 
     @property
-    def _grid_correction_dict(self):
+    def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]:
         """Return the primal and dual finite grid correction factors as a dictionary."""
         return {
             "grid_primal_correction": self.grid_primal_correction,
@@ -937,7 +934,7 @@ def outer_dot(
         d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()
 
         # function to apply at each pair of mode indices before integrating
-        def fn(fields_1, fields_2):
+        def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray:
             e_self_1 = fields_1[e_1]
             e_self_2 = fields_1[e_2]
             h_self_1 = fields_1[h_1]
@@ -978,7 +975,7 @@ def _outer_fn_summation(
         outer_dim_1: str,
         outer_dim_2: str,
         sum_dims: list[str],
-        fn: Callable,
+        fn: Callable[[dict[str, NDArray], dict[str, NDArray]], NDArray],
     ) -> DataArray:
         """
         Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
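The two correction-factor fields above collapse duplicated five-member Unions into the single GRID_CORRECTION_TYPE alias defined near the top of monitor_data.py. A minimal runnable sketch of the same pattern with stand-in types (GridCorrection and CorrectedFieldData are illustrative names, not tidy3d classes):

    from typing import Union

    from pydantic import BaseModel, Field

    # stand-in for GRID_CORRECTION_TYPE, which unions float with four
    # frequency/time/mode data-array types in the real patch
    GridCorrection = Union[float, list[float]]


    class CorrectedFieldData(BaseModel):
        # both fields reference the shared alias, so the Union is written once
        grid_primal_correction: GridCorrection = Field(1.0)
        grid_dual_correction: GridCorrection = Field(1.0)

Declaring the Union once also keeps the two field annotations and the dict[str, GRID_CORRECTION_TYPE] return type of _grid_correction_dict from drifting apart.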
@@ -1676,7 +1673,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData):
     )
 
     @model_validator(mode="after")
-    def eps_spec_match_mode_spec(self):
-        """Raise validation error if frequencies in eps_spec does not match frequency list"""
+    def eps_spec_match_mode_spec(self) -> Self:
+        """Raise validation error if frequencies in eps_spec do not match the frequency list."""
         if self.eps_spec:
             mode_data_freqs = self.monitor.freqs
@@ -1686,7 +1683,7 @@ def eps_spec_match_mode_spec(self):
             )
         return self
 
-    def normalize(self, source_spectrum_fn) -> ModeData:
+    def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self:
         """Return copy of self after normalization is applied using source spectrum function."""
         source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None]
         new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype)
@@ -1808,7 +1805,7 @@ def overlap_sort(
 
         return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False)
 
-    def _isel(self, **isel_kwargs: Any):
+    def _isel(self, **isel_kwargs: Any) -> Self:
         """Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency
         and mode index. Used in ``overlap_sort`` but not officially supported since for example
         ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
@@ -1822,12 +1819,12 @@
         }
         return self.updated_copy(**update_dict, deep=False, validate=False)
 
-    def _assign_coords(self, **assign_coords_kwargs: Any):
+    def _assign_coords(self, **assign_coords_kwargs: Any) -> Self:
         """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over
         frequency and mode index. Used in ``overlap_sort`` but not officially supported since
         for example ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be
         matching the newly created data."""
         update_dict = dict(self._grid_correction_dict, **self.field_components)
         update_dict = {
             key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
         }
@@ -2243,7 +2240,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
 
         return src_adj
 
-    def _apply_mode_reorder(self, sort_inds_2d):
+    def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self:
         """Apply a mode reordering along mode_index for all frequency indices.
 
         Parameters
@@ -2342,7 +2339,7 @@ def sort_modes(
         sort_inds_2d = np.tile(identity, (num_freqs, 1))
 
         # Helper to compute ordered indices within a subset
-        def _order_indices(indices, vals_all):
+        def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray:
             if indices.size == 0:
                 return indices
             vals = vals_all.isel(mode_index=indices)
@@ -2461,7 +2458,7 @@ class ModeSolverData(ModeData):
         description="Unused for ModeSolverData.",
     )
 
-    def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData:
+    def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData:
         """Return copy of self after normalization is applied using source spectrum function."""
         return self.copy()
 
@@ -2554,7 +2551,7 @@ def _make_adjoint_sources(
                 "computation."
             )
 
-    def normalize(self, source_spectrum_fn) -> FluxData:
+    def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData:
         """Return copy of self after normalization is applied using source spectrum function."""
         source_freq_amps = source_spectrum_fn(self.flux.f)
         source_power = abs(source_freq_amps) ** 2
@@ -3164,7 +3161,7 @@ def z(self) -> np.ndarray:
         return self.Etheta.z.values
 
     @property
-    def tangential_dims(self):
+    def tangential_dims(self) -> list[str]:
         tangential_dims = ["x", "y", "z"]
         tangential_dims.pop(self.monitor.proj_axis)
         return tangential_dims
diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py
index a89df2c5d5..e25db96484 100644
--- a/tidy3d/components/data/sim_data.py
+++ b/tidy3d/components/data/sim_data.py
@@ -13,6 +13,7 @@
 import h5py
 import numpy as np
 import xarray as xr
+from numpy.typing import NDArray
 from pydantic import Field
 
 from tidy3d.components.autograd.utils import split_list
@@ -25,7 +26,6 @@
 from tidy3d.components.source.time import GaussianPulse
 from tidy3d.components.source.utils import SourceType
 from tidy3d.components.structure import Structure
-from tidy3d.components.types.base import discriminated_union
 from tidy3d.components.types import (
     Ax,
     Axis,
@@ -33,12 +33,13 @@
     FieldVal,
     PlotScale,
 )
+from tidy3d.components.types.base import discriminated_union
 from tidy3d.components.types.monitor_data import MonitorDataType, MonitorDataTypes
 from tidy3d.components.viz import add_ax_if_none, equal_aspect
 from tidy3d.exceptions import DataError, FileError, SetupError, Tidy3dKeyError
 from tidy3d.log import log
 
-from .data_array import FreqDataArray, TimeDataArray
+from .data_array import DataArray, FreqDataArray, TimeDataArray
 from .monitor_data import AbstractFieldData, FieldTimeData
 
 DATA_TYPE_MAP = {data.model_fields["monitor"].annotation: data for data in MonitorDataTypes}
@@ -249,7 +250,7 @@ def _get_scalar_field(
         field_name: str,
         val: FieldVal,
         phase: float = 0.0,
-    ):
+    ) -> xr.DataArray:
         """return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers.
 
         Parameters
@@ -280,7 +281,7 @@ def _get_scalar_field_from_data(
         field_name: str,
         val: FieldVal,
         phase: float = 0.0,
-    ):
+    ) -> xr.DataArray:
         """return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers.
 
         Parameters
@@ -351,7 +352,7 @@ def _get_scalar_field_from_data(
                     f"'val' of {val} not supported. "
                     "Must be one of 'real', 'imag', 'abs', 'abs^2', or 'phase'."
                 )
             return derived_data
 
         raise Tidy3dKeyError(
@@ -984,7 +985,7 @@ def source_spectrum(self, source_index: int) -> Callable:
         dt = self.simulation.dt
 
-        # plug in mornitor_data frequency domain information
-        def source_spectrum_fn(freqs):
+        # plug in monitor_data frequency domain information
+        def source_spectrum_fn(freqs: DataArray) -> NDArray:
             """Source amplitude as function of frequency."""
 
             spectrum = source_time.spectrum(times, freqs, dt)
@@ -1012,7 +1013,7 @@ def renormalize(self, normalize_index: int) -> SimulationData:
                 f"of length {num_sources}"
             )
 
-        def source_spectrum_fn(freqs):
+        def source_spectrum_fn(freqs: DataArray) -> NDArray:
             """Normalization function that also removes previous normalization if needed."""
             new_spectrum_fn = self.source_spectrum(normalize_index)
             old_spectrum_fn = self.source_spectrum(self.simulation.normalize_index)
diff --git a/tidy3d/components/data/unstructured/base.py b/tidy3d/components/data/unstructured/base.py
index 3bd68e958c..1c46449f9e 100644
--- a/tidy3d/components/data/unstructured/base.py
+++ b/tidy3d/components/data/unstructured/base.py
@@ -5,17 +5,19 @@
 import numbers
 from abc import ABC, abstractmethod
 from os import PathLike
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Self, Union
 
 import numpy as np
+from numpy.typing import DTypeLike, NDArray
 from pandas import RangeIndex
 from pydantic import Field, PositiveInt, field_validator, model_validator
 from xarray import DataArray as XrDataArray
 
 from tidy3d.components.base import cached_property
 from tidy3d.components.data.data_array import (
     DATA_ARRAY_MAP,
     CellDataArray,
+    DataArray,
     IndexedDataArray,
     IndexedDataArrayTypes,
     PointDataArray,
@@ -28,6 +30,16 @@
 from tidy3d.log import log
 from tidy3d.packaging import requires_vtk, vtk
 
+if TYPE_CHECKING:
+    from vtkmodules.vtkCommonCore import vtkPoints
+    from vtkmodules.vtkCommonDataModel import (
+        vtkCellArray,
+        vtkDataSet,
+        vtkPointData,
+        vtkPolyData,
+        vtkUnstructuredGrid,
+    )
+
 DEFAULT_MAX_SAMPLES_PER_STEP = 10_000
 DEFAULT_MAX_CELLS_PER_STEP = 10_000
 DEFAULT_TOLERANCE_CELL_FINDING = 1e-6
@@ -68,7 +80,7 @@ def _cell_num_vertices(cls) -> PositiveInt:
 
     @field_validator("points")
     @classmethod
-    def points_right_dims(cls, val):
+    def points_right_dims(cls, val: PointDataArray) -> PointDataArray:
         """Check that point coordinates have the right dimensionality."""
         # currently support only the standard axis ordering, that is 01(2)
         axis_coords_expected = np.arange(cls._point_dims())
@@ -81,7 +93,7 @@ def points_right_dims(cls, val):
         return val
 
     @field_validator("points")
-    def points_right_indexing(val):
-        """Check that points are indexed corrrectly."""
+    def points_right_indexing(val: PointDataArray) -> PointDataArray:
+        """Check that points are indexed correctly."""
         indices_expected = np.arange(len(val.data))
         indices_given = val.index.data
@@ -94,14 +106,14 @@ def points_right_indexing(val):
         return val
 
     @field_validator("values")
-    def first_values_dim_is_index(val):
+    def first_values_dim_is_index(val: IndexedDataArrayTypes) -> IndexedDataArrayTypes:
         """Check that the number of data values matches the number of grid points."""
         if val.dims[0] != "index":
             raise ValidationError("First dimension of array 'values' must be 'index'.")
         return val
 
     @field_validator("values")
-    def values_right_indexing(val):
+    def values_right_indexing(val: IndexedDataArrayTypes) -> IndexedDataArrayTypes:
         """Check that data values are indexed correctly."""
         # currently support only simple ordered indexing of points, that is, 0, 1, 2, ...
         indices_expected = np.arange(len(val.index.data))
@@ -115,7 +127,7 @@ def values_right_indexing(val):
         return val
 
     @model_validator(mode="after")
-    def number_of_values_matches_points(self):
+    def number_of_values_matches_points(self) -> Self:
         """Check that the number of data values matches the number of grid points."""
         num_values = len(self.values.index)
         num_points = len(self.points)
@@ -128,7 +140,7 @@ def number_of_values_matches_points(self):
         return self
 
     @field_validator("cells")
-    def match_cells_to_vtk_type(val):
-        """Check that cell connections does not have duplicate points."""
+    def match_cells_to_vtk_type(val: CellDataArray) -> CellDataArray:
+        """Cast cell connections to the data type expected by VTK."""
         if vtk is None:
             return val
@@ -138,7 +150,7 @@ def match_cells_to_vtk_type(val):
 
     @field_validator("cells")
     @classmethod
-    def cells_right_type(cls, val):
-        """Check that cell are of the right type."""
+    def cells_right_type(cls, val: CellDataArray) -> CellDataArray:
+        """Check that cells are of the right type."""
         # only supporting the standard ordering of cell vertices 012(3)
         vertex_coords_expected = np.arange(cls._cell_num_vertices())
@@ -151,7 +163,7 @@ def cells_right_type(cls, val):
         return val
 
     @model_validator(mode="after")
-    def check_cell_vertex_range(self):
+    def check_cell_vertex_range(self) -> Self:
         """Check that cell connections use only defined points."""
         val = getattr(self, "cells", None)
         if val is None:
@@ -173,7 +185,7 @@ def check_cell_vertex_range(self):
         return self
 
     @field_validator("cells")
-    def warn_degenerate_cells(cls, val):
-        """Check that cell connections does not have duplicate points."""
+    def warn_degenerate_cells(cls, val: CellDataArray) -> CellDataArray:
+        """Warn if cell connections contain duplicate point indices."""
         degenerate_cells = cls._find_degenerate_cells(val)
         num_degenerate_cells = len(degenerate_cells)
@@ -217,7 +229,7 @@ def _warn_if_none(cls, data: Any) -> Any:
 
     @model_validator(mode="before")
     def _add_default_coords(cls, data: dict) -> dict:
-        def _add_default_coords(da):
+        def _add_default_coords(da: DataArray) -> DataArray:
             """Add 0..N-1 coordinates to any dimension that does not already have one.
 
             Note: We use a pandas `RangeIndex` here for constant memory.
""" @@ -233,7 +245,7 @@ def _add_default_coords(da): return data @model_validator(mode="after") - def _warn_unused_points(self): + def _warn_unused_points(self) -> Self: """Warn if some points are unused.""" point_indices = set(np.arange(len(self.points.data))) used_indices = set(self.cells.values.ravel()) @@ -264,34 +276,34 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.values) @property - def _double_type(self): + def _double_type(self) -> DTypeLike: """Corresponding double data type.""" return np.complex128 if self.is_complex else np.float64 @property - def is_uniform(self): + def is_uniform(self) -> bool: """Whether each element is of equal value in ``values``.""" return self.values.is_uniform @cached_property - def _values_coords_dict(self): + def _values_coords_dict(self) -> dict[str, Any]: """Non-spatial dimensions are corresponding coordinate values of stored data.""" coord_dict = {dim: self.values.coords[dim].data for dim in self.values.dims} _ = coord_dict.pop("index") return coord_dict @cached_property - def _fields_shape(self): + def _fields_shape(self) -> list[int]: """Shape in which fields are stored.""" return [len(coord) for coord in self._values_coords_dict.values()] @cached_property - def _num_fields(self): + def _num_fields(self) -> int: """Total number of stored fields.""" return 1 if len(self._fields_shape) == 0 else np.prod(self._fields_shape) @cached_property - def _values_type(self): + def _values_type(self) -> type: """Type of array storing values.""" return type(self.values) @@ -308,7 +320,7 @@ def _points_3d_array(self) -> None: """ Grid cleaning """ @classmethod - def _find_degenerate_cells(cls, cells: CellDataArray): + def _find_degenerate_cells(cls, cells: CellDataArray) -> set[int]: """Find explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """ @@ -318,14 +330,13 @@ def _find_degenerate_cells(cls, cells: CellDataArray): if len(indices) > 0: for i in range(cls._cell_num_vertices() - 1): for j in range(i + 1, cls._cell_num_vertices()): - degenerate_cell_inds = degenerate_cell_inds.union( - np.where(indices[:, i] == indices[:, j])[0] - ) + new_inds = np.where(indices[:, i] == indices[:, j])[0] + degenerate_cell_inds |= {int(k) for k in new_inds} return degenerate_cell_inds @classmethod - def _remove_degenerate_cells(cls, cells: CellDataArray): + def _remove_degenerate_cells(cls, cells: CellDataArray) -> CellDataArray: """Remove explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """ @@ -341,7 +352,7 @@ def _remove_degenerate_cells(cls, cells: CellDataArray): @classmethod def _remove_unused_points( cls, points: PointDataArray, values: IndexedDataArrayTypes, cells: CellDataArray - ): + ) -> tuple[PointDataArray, IndexedDataArrayTypes, CellDataArray]: """Remove unused points if any. That is, points that are not used in any grid cell. 
""" @@ -364,7 +375,9 @@ def _remove_unused_points( return points, values, cells - def clean(self, remove_degenerate_cells=True, remove_unused_points=True): + def clean( + self, remove_degenerate_cells: bool = True, remove_unused_points: bool = True + ) -> Self: """Remove degenerate cells and/or unused points.""" if remove_degenerate_cells: cells = self._remove_degenerate_cells(cells=self.cells) @@ -381,7 +394,9 @@ def clean(self, remove_degenerate_cells=True, remove_unused_points=True): """ Arithmetic operations """ - def __array_ufunc__(self, ufunc, method, *inputs: Any, **kwargs: Any): + def __array_ufunc__( + self, ufunc: np.ufunc, method: str, *inputs: Union[Self, numbers.Number], **kwargs: Any + ) -> Optional[Union[Self, tuple[Self, ...]]]: """Override of numpy functions.""" out = kwargs.get("out", ()) @@ -416,7 +431,7 @@ def __array_ufunc__(self, ufunc, method, *inputs: Any, **kwargs: Any): return self.updated_copy(values=result) @property - def real(self) -> UnstructuredGridDataset: + def real(self) -> Self: """Real part of dataset.""" return self.updated_copy(values=self.values.real) @@ -449,7 +464,7 @@ def _vtk_offsets(self) -> ArrayLike: @property @requires_vtk - def _vtk_cells(self): + def _vtk_cells(self) -> vtkCellArray: """VTK cell array to use in the VTK representation.""" cells = vtk["mod"].vtkCellArray() cells.SetData( @@ -460,7 +475,7 @@ def _vtk_cells(self): @property @requires_vtk - def _vtk_points(self): + def _vtk_points(self) -> vtkPoints: """VTK point array to use in the VTK representation.""" pts = vtk["mod"].vtkPoints() pts.SetData(vtk["numpy_to_vtk"](self._points_3d_array)) @@ -468,7 +483,7 @@ def _vtk_points(self): @property @requires_vtk - def _vtk_obj(self): + def _vtk_obj(self) -> vtkUnstructuredGrid: """A VTK representation (vtkUnstructuredGrid) of the grid.""" grid = vtk["mod"].vtkUnstructuredGrid() @@ -496,7 +511,7 @@ def _vtk_obj(self): @staticmethod @requires_vtk - def _read_vtkUnstructuredGrid(fname: PathLike): + def _read_vtkUnstructuredGrid(fname: PathLike) -> vtkUnstructuredGrid: """Load a :class:`vtkUnstructuredGrid` from a file.""" fname = str(fname) reader = vtk["mod"].vtkXMLUnstructuredGridReader() @@ -508,7 +523,7 @@ def _read_vtkUnstructuredGrid(fname: PathLike): @staticmethod @requires_vtk - def _read_vtkLegacyFile(fname: PathLike): + def _read_vtkLegacyFile(fname: PathLike) -> vtkUnstructuredGrid: """Load a grid from a legacy `.vtk` file.""" fname = str(fname) reader = vtk["mod"].vtkGenericDataObjectReader() @@ -523,20 +538,20 @@ def _read_vtkLegacyFile(fname: PathLike): @requires_vtk def _from_vtk_obj( cls, - vtk_obj, + vtk_obj: vtkUnstructuredGrid, field: Optional[str] = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, - values_type=IndexedDataArray, - expect_complex=None, - ignore_invalid_cells=False, + values_type: type = IndexedDataArray, + expect_complex: Optional[bool] = None, + ignore_invalid_cells: bool = False, ) -> UnstructuredGridDataset: """Initialize from a vtk object.""" @requires_vtk def _from_vtk_obj_internal( self, - vtk_obj, + vtk_obj: vtkUnstructuredGrid, remove_degenerate_cells: bool = True, remove_unused_points: bool = True, ) -> UnstructuredGridDataset: @@ -649,8 +664,8 @@ def to_vtu(self, fname: PathLike) -> None: @requires_vtk def _cell_to_point_data( cls, - vtk_obj, - ): + vtk_obj: vtkCellArray, + ) -> vtkPointData: """Get point data values from a VTK object.""" cellDataToPointData = vtk["mod"].vtkCellDataToPointData() @@ -663,11 +678,11 @@ def _cell_to_point_data( @requires_vtk def 
_get_values_from_vtk( cls, - vtk_obj, + vtk_obj: vtkDataSet, num_points: PositiveInt, field: Optional[str] = None, - values_type=IndexedDataArray, - expect_complex=None, + values_type: type = IndexedDataArray, + expect_complex: Optional[bool] = None, ) -> IndexedDataArray: """Get point data values from a VTK object.""" @@ -734,7 +749,7 @@ def _get_values_from_vtk( return values - def get_cell_values(self, **kwargs: Any): + def get_cell_values(self, **kwargs: Any) -> NDArray: """This function returns the cell values for the fields stored in the UnstructuredGridDataset. If multiple fields are stored per point, like in an IndexedVoltageDataArray, cell values will be provided for each of the fields unless a selection argument is provided, e.g., voltage=0.2 @@ -759,7 +774,7 @@ def get_cell_volumes(self) -> None: """ Grid operations """ @requires_vtk - def _plane_slice_raw(self, axis: Axis, pos: float): + def _plane_slice_raw(self, axis: Axis, pos: float) -> vtkPolyData: """Slice data with a plane and return the resulting VTK object.""" if pos > self.bounds[1][axis] or pos < self.bounds[0][axis]: @@ -983,7 +998,12 @@ def interp( return result - def _non_spatial_interp(self, method="linear", fill_value=np.nan, **coords_kwargs: Any): + def _non_spatial_interp( + self, + method: Literal["linear", "nearest"] = "linear", + fill_value: Union[float, Literal["extrapolate"]] = np.nan, + **coords_kwargs: Any, + ) -> Self: """Interpolate data at non-spatial dimensions using xarray's interp() function. Parameters @@ -1009,7 +1029,7 @@ def _non_spatial_interp(self, method="linear", fill_value=np.nan, **coords_kwarg return self.updated_copy( values=self.values.interp( **coords_kwargs_only_lists, - method="linear", + method=method, kwargs={"fill_value": fill_value}, ) ) @@ -1753,13 +1773,15 @@ def sel( def _non_spatial_sel( self, - method=None, + method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None, **sel_kwargs: Any, ) -> XrDataArray: """Select/interpolate data along one or more non-Cartesian directions. Parameters ---------- + method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None + Method to use in xarray sel() function. **sel_kwargs : dict Keyword arguments to pass to the xarray sel() function. diff --git a/tidy3d/components/data/unstructured/tetrahedral.py b/tidy3d/components/data/unstructured/tetrahedral.py index cc8c822b3a..f7a3ea9b44 100644 --- a/tidy3d/components/data/unstructured/tetrahedral.py +++ b/tidy3d/components/data/unstructured/tetrahedral.py @@ -2,10 +2,11 @@ from __future__ import annotations -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import numpy as np from pydantic import PositiveInt +from xarray import DataArray from xarray import DataArray as XrDataArray from tidy3d.components.base import cached_property @@ -17,6 +18,9 @@ from .base import UnstructuredGridDataset from .triangular import TriangularGridDataset +if TYPE_CHECKING: + from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid + class TetrahedralGridDataset(UnstructuredGridDataset): """Dataset for storing tetrahedral grid data. 
@@ -79,7 +83,7 @@ def _points_3d_array(self) -> Bound:
 
     @classmethod
     @requires_vtk
-    def _vtk_cell_type(cls):
+    def _vtk_cell_type(cls) -> int:
         """VTK cell type to use in the VTK representation."""
         return vtk["mod"].VTK_TETRA
 
@@ -87,11 +91,11 @@ def _vtk_cell_type(cls):
     @requires_vtk
     def _from_vtk_obj(
         cls,
-        vtk_obj,
-        field=None,
+        vtk_obj: vtkUnstructuredGrid,
+        field: Optional[str] = None,
         remove_degenerate_cells: bool = False,
         remove_unused_points: bool = False,
-        values_type=IndexedDataArray,
+        values_type: type = IndexedDataArray,
         expect_complex: bool = False,
         ignore_invalid_cells: bool = False,
     ) -> TetrahedralGridDataset:
@@ -297,7 +301,7 @@ def sel(
         x: Union[float, ArrayLike] = None,
         y: Union[float, ArrayLike] = None,
         z: Union[float, ArrayLike] = None,
-        method=None,
+        method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None,
         **sel_kwargs: Any,
     ) -> Union[TriangularGridDataset, XrDataArray]:
         """Extract/interpolate data along one or more spatial or non-spatial directions. Must provide at least one argument
@@ -313,7 +317,7 @@ def sel(
         y : Union[float, ArrayLike] = None
             y-coordinate of the slice.
         z : Union[float, ArrayLike] = None
             z-coordinate of the slice.
-        method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None
+        method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None
             Method to use in xarray sel() function.
         **sel_kwargs : dict
             Keyword arguments to pass to the xarray sel() function.
@@ -357,7 +361,7 @@ def sel(
 
         return self_after_non_spatial_sel
 
-    def get_cell_volumes(self):
-        """Get the volumes associated to each cell in the grid"""
+    def get_cell_volumes(self) -> DataArray:
+        """Get the volume associated with each cell in the grid."""
         v0 = self.points[self.cells.sel(vertex_index=0)]
         e01 = self.points[self.cells.sel(vertex_index=1)] - v0
diff --git a/tidy3d/components/data/unstructured/triangular.py b/tidy3d/components/data/unstructured/triangular.py
index 50d836f6f6..672e14fe2b 100644
--- a/tidy3d/components/data/unstructured/triangular.py
+++ b/tidy3d/components/data/unstructured/triangular.py
@@ -2,10 +2,11 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Self, Union
 
 import numpy as np
 from pydantic import Field, PositiveInt
+from xarray import DataArray
 from xarray import DataArray as XrDataArray
 
 try:
@@ -35,6 +36,9 @@
     UnstructuredGridDataset,
 )
 
+if TYPE_CHECKING:
+    from vtkmodules.vtkCommonDataModel import vtkPointSet
+
 
 class TriangularGridDataset(UnstructuredGridDataset):
     """Dataset for storing triangular grid data. Data values are associated with the nodes of
@@ -115,7 +119,7 @@ def _points_3d_array(self) -> ArrayLike:
 
     @classmethod
     @requires_vtk
-    def _vtk_cell_type(cls):
+    def _vtk_cell_type(cls) -> int:
         """VTK cell type to use in the VTK representation."""
         return vtk["mod"].VTK_TRIANGLE
 
@@ -123,14 +127,14 @@ def _vtk_cell_type(cls):
     @requires_vtk
     def _from_vtk_obj(
         cls,
-        vtk_obj,
-        field=None,
+        vtk_obj: vtkPointSet,
+        field: Optional[str] = None,
         remove_degenerate_cells: bool = False,
         remove_unused_points: bool = False,
-        values_type=IndexedDataArray,
-        expect_complex=None,
+        values_type: type = IndexedDataArray,
+        expect_complex: Optional[bool] = None,
         ignore_invalid_cells: bool = False,
-    ):
+    ) -> Self:
         """Initialize from a vtkUnstructuredGrid instance."""
 
         # get points cells data from vtk object
@@ -667,7 +671,7 @@ def plot(
         ax.set_title(f"{normal_axis_name} = {self.normal_pos}")
         return ax
 
-    def get_cell_volumes(self):
-        """Get areas associated to each cell of the grid."""
+    def get_cell_volumes(self) -> DataArray:
+        """Get the area associated with each cell of the grid."""
         v0 = self.points[self.cells.sel(vertex_index=0)]
         e01 = self.points[self.cells.sel(vertex_index=1)] - v0
diff --git a/tidy3d/components/data/validators.py b/tidy3d/components/data/validators.py
index e5c3f93789..09f6211dc9 100644
--- a/tidy3d/components/data/validators.py
+++ b/tidy3d/components/data/validators.py
@@ -1,10 +1,13 @@
 # special validators for Datasets
 from __future__ import annotations
 
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import numpy as np
 from pydantic import field_validator
+from pydantic_core.core_schema import ValidationInfo
 
 from tidy3d.exceptions import ValidationError
+
+if TYPE_CHECKING:
+    from tidy3d.components.data.dataset import AbstractFieldDataset
@@ -13,20 +16,20 @@
 
 
 # this can't go in validators.py because that file imports dataset.py
-def validate_no_nans(*field_names: str):
+def validate_no_nans(*field_names: str) -> Callable[[Any, ValidationInfo], Any]:
     """Raise validation error if nans found in Dataset, or other data-containing item."""
 
     @field_validator(*field_names)
-    def no_nans(val, info):
+    def no_nans(val: Any, info: ValidationInfo) -> Any:
         """Raise validation error if nans found in Dataset, or other data-containing item."""
 
         if val is None:
             return val
 
-        def error_if_has_nans(value, identifier: Optional[str] = None) -> None:
+        def error_if_has_nans(value: Any, identifier: Optional[str] = None) -> None:
             """Recursively check if value (or iterable) has nans and error if so."""
 
-            def has_nans(values) -> bool:
+            def has_nans(values: Any) -> bool:
                 """Base case: do these values contain NaN?"""
                 try:
                     return np.any(np.isnan(values))
@@ -65,7 +68,9 @@ def has_nans(values) -> bool:
     return no_nans
 
 
-def validate_can_interpolate(*field_names: str):
+def validate_can_interpolate(
+    *field_names: str,
+) -> Callable[[AbstractFieldDataset], AbstractFieldDataset]:
     """Make sure the data in ``field_name`` can be interpolated."""
 
     @field_validator(*field_names)
diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py
index 58a7098584..4491aabbac 100644
--- a/tidy3d/components/geometry/base.py
+++ b/tidy3d/components/geometry/base.py
@@ -10,10 +10,10 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import autograd.numpy as np
+import pydantic
 import shapely
 import xarray as xr
-from numpy._typing import ArrayLike, NDArray
-from typing_extensions import Self
+from numpy.typing import ArrayLike, NDArray
 from pydantic import (
     Field,
     NonNegativeFloat,
@@ -22,6 +22,7 @@
     field_validator,
     model_validator,
 )
+from typing_extensions import Self
 
 from tidy3d.compat import _package_is_older_than
 from tidy3d.components.autograd import (
@@ -1172,7 +1173,10 @@ def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]:
 
     @staticmethod
     @verify_packages_import(["gdstk"])
     def load_gds_vertices_gdstk(
-        gds_cell: Cell, gds_layer: int, gds_dtype: Optional[int] = None, gds_scale: PositiveFloat = 1.0
+        gds_cell: Cell,
+        gds_layer: int,
+        gds_dtype: Optional[int] = None,
+        gds_scale: PositiveFloat = 1.0,
     ) -> list[ArrayFloat2D]:
         """Load polygon vertices from a ``gdstk.Cell``.
 
@@ -1560,7 +1564,7 @@ class Centered(Geometry, ABC):
 
     @field_validator("center", mode="before")
     @classmethod
-    def _center_default(cls, val):
-        """Make sure center is not infinitiy."""
+    def _center_default(cls, val: Any) -> Any:
+        """Make sure center is not infinity."""
         if val is None:
             val = (0.0, 0.0, 0.0)
@@ -3488,9 +3492,7 @@ class GeometryGroup(Geometry):
     )
 
     @field_validator("geometries")
-    def _geometries_not_empty(
-        val: tuple[annotate_type(GeometryType), ...]
-    ) -> tuple[annotate_type(GeometryType), ...]:
+    def _geometries_not_empty(val: tuple[GeometryType, ...]) -> tuple[GeometryType, ...]:
         """make sure geometries are not empty."""
         if not len(val) > 0:
             raise ValidationError("GeometryGroup.geometries must not be empty.")
diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py
index 9658394a84..c0152368eb 100644
--- a/tidy3d/components/geometry/polyslab.py
+++ b/tidy3d/components/geometry/polyslab.py
@@ -5,7 +5,7 @@
 import math
 from copy import copy
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Self, Union
 
 import autograd.numpy as np
 import shapely
@@ -140,7 +140,7 @@ def correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D:
         return val
 
     @model_validator(mode="after")
-    def no_complex_self_intersecting_polygon_at_reference_plane(self):
+    def no_complex_self_intersecting_polygon_at_reference_plane(self) -> Self:
         """At the reference plane, check if the polygon is self-intersecting.
 
         There are two types of self-intersection that can occur during dilation:
@@ -190,7 +190,7 @@ def no_complex_self_intersecting_polygon_at_reference_plane(self):
         return self
 
     @model_validator(mode="after")
-    def no_self_intersecting_polygon_during_extrusion(self):
+    def no_self_intersecting_polygon_during_extrusion(self) -> Self:
         """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
         any normal cross section of the PolySlab cannot be self-intersecting. This part checks if
-        any self-interction will occur during extrusion with non-zero sidewall angle.
+        any self-intersection will occur during extrusion with non-zero sidewall angle.
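The bulk of this patch applies one validator convention, visible in the polyslab hunks above: pydantic mode="after" model validators now declare a -> Self return. A self-contained sketch of the pattern with a hypothetical model (Slab is not a tidy3d class); note that typing.Self needs Python 3.11+, whereas the "from typing_extensions import Self" spelling that geometry/base.py keeps also works on older interpreters:

    from typing import Self  # Python >= 3.11; otherwise: from typing_extensions import Self

    from pydantic import BaseModel, model_validator


    class Slab(BaseModel):
        thickness: float

        @model_validator(mode="after")
        def _check_thickness(self) -> Self:
            # "after" validators receive the constructed instance and must
            # return it; annotating the return as Self keeps subclasses precise
            if self.thickness < 0:
                raise ValueError("thickness must be non-negative")
            return self

Because the annotation is Self rather than the concrete class name, a subclass that inherits or overrides the validator, like ComplexPolySlabBase in the next hunk, is typed as the subclass rather than its base.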
@@ -2355,7 +2355,7 @@ class ComplexPolySlabBase(PolySlab):
     :class:`plugins.polyslab.ComplexPolySlab`."""
 
     @model_validator(mode="after")
-    def no_self_intersecting_polygon_during_extrusion(self):
+    def no_self_intersecting_polygon_during_extrusion(self) -> Self:
         """Turn off the validation for this class."""
         return self
 
diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py
index ebb34471bb..8884f9fd41 100644
--- a/tidy3d/components/geometry/primitives.py
+++ b/tidy3d/components/geometry/primitives.py
@@ -3,13 +3,13 @@
 from __future__ import annotations
 
 from math import isclose
-from typing import Any, Optional
+from typing import Any, Optional, Self
 
 import autograd.numpy as anp
 import numpy as np
 import shapely
-from shapely.geometry.base import BaseGeometry
 from pydantic import Field, model_validator
+from shapely.geometry.base import BaseGeometry
 
 from tidy3d.components.autograd import AutogradFieldMap, TracedSize1D
 from tidy3d.components.autograd.derivative_utils import DerivativeInfo
@@ -211,7 +211,7 @@ class Cylinder(base.Centered, base.Circular, base.Planar):
     )
 
     @model_validator(mode="after")
-    def _only_middle_for_infinite_length_slanted_cylinder(self):
+    def _only_middle_for_infinite_length_slanted_cylinder(self) -> Self:
         """For a slanted cylinder of infinite length, ``reference_plane`` can only be ``middle``;
         otherwise, the radius at ``center`` is either td.inf or 0.
         """
diff --git a/tidy3d/material_library/material_library.py b/tidy3d/material_library/material_library.py
index de9280d9b0..ee252cded7 100644
--- a/tidy3d/material_library/material_library.py
+++ b/tidy3d/material_library/material_library.py
@@ -4,7 +4,7 @@
 
 import json
 from os import PathLike
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Optional, Self, Union
 
 from pydantic import Field, model_validator
 from rich.panel import Panel
@@ -127,7 +127,7 @@ class MaterialItem(Tidy3dBaseModel):
     default: str = Field(title="default variant", description="The default type of variant.")
 
     @model_validator(mode="after")
-    def _default_in_variants(self):
+    def _default_in_variants(self) -> Self:
         """Make sure the default variant is already included in the ``variants``."""
         if self.default not in self.variants:
             raise SetupError(
diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py
index 455f6d2546..b05fedf11c 100644
--- a/tidy3d/web/api/container.py
+++ b/tidy3d/web/api/container.py
@@ -600,13 +600,13 @@ def _check_path_dir(path: PathLike) -> None:
             parent_dir.mkdir(parents=True, exist_ok=True)
 
     @model_validator(mode="before")
-    def set_task_name_if_none(data):
+    def set_task_name_if_none(data: Any) -> Any:
         """
         Auto-assign a task_name if user did not provide one.
         """
         if not isinstance(data, dict):
             return data
 
         if data.get("task_name") is None:
             sim = data.get("simulation")
             stub = Tidy3dStub(simulation=sim)
@@ -754,7 +754,8 @@ class Batch(WebContainer):
     """
 
     simulations: Union[
-        dict[TaskName, discriminated_union(WorkflowType)], tuple[discriminated_union(WorkflowType), ...]
+        dict[TaskName, discriminated_union(WorkflowType)],
+        tuple[discriminated_union(WorkflowType), ...],
     ] = Field(
         title="Simulations",
         description="Mapping of task names to Simulations to run as a batch.",
diff --git a/tidy3d/web/api/material_library.py b/tidy3d/web/api/material_library.py
index f6e6aa164a..c6037aeabb 100644
--- a/tidy3d/web/api/material_library.py
+++ b/tidy3d/web/api/material_library.py
@@ -3,12 +3,12 @@
 
 from __future__ import annotations
 
 import json
-from typing import Optional
+from typing import Any, Optional
 
 from pydantic import Field, TypeAdapter, field_validator
 
 from tidy3d.components.medium import MediumType
-from tidy3d.web.core.http_util import http
+from tidy3d.web.core.http_util import JSONType, http
 from tidy3d.web.core.types import Queryable
 
@@ -44,7 +44,7 @@ class MaterialLibrary(Queryable):
 
     @field_validator("medium", "json_input", mode="before")
     @classmethod
-    def parse_result(cls, values):
-        """Automatically parsing medium and json_input from string to object."""
+    def parse_result(cls, values: Any) -> JSONType:
+        """Automatically parse medium and json_input from string to object."""
         return json.loads(values)
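parse_result runs in mode="before", so it receives the raw server value before pydantic validates the field's declared type. A self-contained sketch of that flow, with a hypothetical Record model standing in for MaterialLibrary:

    import json
    from typing import Any

    from pydantic import BaseModel, field_validator


    class Record(BaseModel):
        payload: dict

        @field_validator("payload", mode="before")
        @classmethod
        def _parse_json(cls, value: Any) -> Any:
            # the HTTP layer hands back JSON-encoded strings; decode them so
            # the field validates against its declared dict type
            return json.loads(value) if isinstance(value, str) else value


    record = Record(payload='{"n": 1}')
    assert record.payload == {"n": 1}

The isinstance guard is an extra assumption in this sketch; the tidy3d validator calls json.loads unconditionally because those fields always arrive JSON-encoded.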