Commit efc014c

fix: Updates to speed up some mode data operations
1 parent 691741b commit efc014c

3 files changed: +102 −48 lines changed

docs/notebooks

Submodule notebooks updated 85 files

tidy3d/components/data/monitor_data.py

Lines changed: 100 additions & 45 deletions
```diff
@@ -151,20 +151,6 @@ def amplitude_fn(freq: list[float]) -> complex:
 
         return self.normalize(amplitude_fn)
 
-    def _updated(self, update: dict) -> MonitorData:
-        """Similar to ``updated_copy``, but does not actually copy components, for speed.
-
-        Note
-        ----
-        This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
-        are not updated, they will point to the values in the original data. This method should
-        thus be used carefully.
-
-        """
-        data_dict = self.dict()
-        data_dict.update(update)
-        return type(self).parse_obj(data_dict)
-
     def _make_adjoint_sources(self, dataset_names: list[str], fwidth: float) -> list[Source]:
         """Generate adjoint sources for this ``MonitorData`` instance."""
 
```
```diff
@@ -263,7 +249,7 @@ def symmetry_expanded(self):
         if all(sym == 0 for sym in self.symmetry):
             return self
 
-        return self._updated(self._symmetry_update_dict)
+        return self.updated_copy(**self._symmetry_update_dict, deep=False, validate=False)
 
     @property
     def symmetry_expanded_copy(self) -> AbstractFieldData:
```
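
The two hunks above drop the private `_updated` helper, which rebuilt the model by serializing every field with `dict()` and re-parsing the result through `parse_obj`, in favor of `updated_copy(..., deep=False, validate=False)`. As a rough illustration of why skipping the re-validation and the deep copy saves time, here is a minimal sketch using a plain pydantic v1 model; `FakeMonitorData` and its fields are invented for the example and are not tidy3d's actual base class or API.

```python
import numpy as np
import pydantic.v1 as pd  # assumes pydantic v2 is installed; on pydantic v1, use "import pydantic as pd"


class FakeMonitorData(pd.BaseModel):
    """Toy stand-in for a monitor-data model holding a large array field."""

    class Config:
        arbitrary_types_allowed = True

    name: str
    values: np.ndarray  # stand-in for a large field-data array


data = FakeMonitorData(name="mode_0", values=np.zeros((256, 256, 10)))

# Rebuild-style update: serialize all fields and re-parse, re-running validation
# on every field of the (possibly large) model.
slow = FakeMonitorData.parse_obj({**data.dict(), "name": "mode_1"})

# Shallow update: reuse the existing field objects and skip validation, which is
# the behavior the diff reaches for with updated_copy(..., deep=False, validate=False).
fast = data.copy(update={"name": "mode_1"}, deep=False)
assert fast.values is data.values  # the big array is shared, not copied
```

As the removed docstring warned, the price of the shallow path is that untouched fields in the new object still reference the original data arrays, so it has to be used with care.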
```diff
@@ -782,23 +768,51 @@ def dot(
             fields_self = {key: field.conj() for key, field in fields_self.items()}
 
         fields_other = field_data._interpolated_tangential_fields(self._plane_grid_boundaries)
+        dim1, dim2 = self._tangential_dims
+        d_area = self._diff_area
 
-        # Drop size-1 dimensions in the other data
-        fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
+        # After interpolation, the tangential coordinates should match. However, the two arrays
+        # may either have the same shape along other dimensions, or be broadcastable.
+        if (
+            fields_self[next(iter(fields_self))].shape
+            == fields_other[next(iter(fields_other))].shape
+        ):
+            # Arrays are same shape, so we can use numpy
+            e_self_x_h_other = fields_self["E" + dim1].values * fields_other["H" + dim2].values
+            e_self_x_h_other -= fields_self["E" + dim2].values * fields_other["H" + dim1].values
+            h_self_x_e_other = fields_self["H" + dim1].values * fields_other["E" + dim2].values
+            h_self_x_e_other -= fields_self["H" + dim2].values * fields_other["E" + dim1].values
+            integrand = xr.DataArray(
+                e_self_x_h_other - h_self_x_e_other, coords=fields_self["E" + dim1].coords
+            )
+            integrand *= d_area
+        else:
+            # Broadcasting is needed, which may be complicated depending on the dimensions order.
+            # Use xarray to handle robustly.
 
-        # Cross products of fields
-        dim1, dim2 = self._tangential_dims
-        e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
-        e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
-        h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
-        h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
+            # Drop size-1 dimensions in the other data
+            fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}
 
-        # Integrate over plane
-        d_area = self._diff_area
-        integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
+            # Cross products of fields
+            e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
+            e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
+            h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
+            h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
+            integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
 
+        # Integrate over plane
         return ModeAmpsDataArray(0.25 * integrand.sum(dim=d_area.dims))
 
+    def _tangential_fields_match_coords(self, coords: ArrayFloat2D) -> bool:
+        """Check if the tangential fields already match given coords in the tangential plane."""
+        for field in self._tangential_fields.values():
+            for idim, dim in enumerate(self._tangential_dims):
+                if field.coords[dim].values.size != coords[idim].size or not np.all(
+                    field.coords[dim].values == coords[idim]
+                ):
+                    return False
+        return True
+
     def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]:
         """For 2D monitors, interpolate this fields to given coords in the tangential plane.
 
```
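
The rewritten `dot` adds a fast path: when both field dictionaries hold arrays of identical shape, the cross products run directly on the raw `.values` and the result is wrapped back into a `DataArray` once, while xarray's broadcasting arithmetic is kept only as the fallback. A minimal sketch of that trade-off, with illustrative array names and sizes rather than tidy3d field data:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
coords = {"x": np.arange(300), "y": np.arange(300), "f": np.arange(10)}
e_field = xr.DataArray(rng.random((300, 300, 10)), coords=coords, dims=("x", "y", "f"))
h_field = xr.DataArray(rng.random((300, 300, 10)), coords=coords, dims=("x", "y", "f"))

# Generic path: every xarray operation aligns coordinates and broadcasts.
cross_xr = e_field * h_field

# Fast path: shapes (and coordinates) already match, so multiply the bare numpy
# arrays and attach the coordinates once at the end, as the diff does.
if e_field.shape == h_field.shape:
    cross_np = xr.DataArray(e_field.values * h_field.values, coords=e_field.coords)

assert np.allclose(cross_xr.values, cross_np.values)
```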
```diff
@@ -813,6 +827,10 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]:
         """
         fields = self._tangential_fields
 
+        # If coords already match, just return the tangential fields directly.
+        if self._tangential_fields_match_coords(coords):
+            return fields
+
         # Interpolate if data has more than one coordinate along a dimension
         interp_dict = {"assume_sorted": True}
         # If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
```
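
This hunk short-circuits `_interpolated_tangential_fields` with the new `_tangential_fields_match_coords` check, so no interpolation runs when the requested coordinates are exactly the ones the data is already defined on. A small sketch of the same guard on a generic `DataArray`; `maybe_interp` and its arguments are hypothetical names for the example:

```python
import numpy as np
import xarray as xr


def maybe_interp(field: xr.DataArray, x_new: np.ndarray, y_new: np.ndarray) -> xr.DataArray:
    """Interpolate onto (x_new, y_new) unless the field is already sampled there."""
    same_x = field.coords["x"].values.size == x_new.size and np.all(
        field.coords["x"].values == x_new
    )
    same_y = field.coords["y"].values.size == y_new.size and np.all(
        field.coords["y"].values == y_new
    )
    if same_x and same_y:
        return field  # nothing to do; reuse the existing data
    return field.interp(x=x_new, y=y_new, assume_sorted=True)


field = xr.DataArray(
    np.random.default_rng(0).random((4, 3)),
    coords={"x": [0.0, 1.0, 2.0, 3.0], "y": [0.0, 0.5, 1.0]},
    dims=("x", "y"),
)
# Coordinates match exactly: the original object comes back, no interpolation performed.
assert maybe_interp(field, field.x.values, field.y.values) is field
```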
```diff
@@ -1713,10 +1731,12 @@ def overlap_sort(
 
         # Normalizing the flux to 1, does not guarantee self terms of overlap integrals
         # are also normalized to 1 when the non-conjugated product is used.
-        if self.monitor.conjugated_dot_product:
+        data_expanded = self.symmetry_expanded
+        if data_expanded.monitor.conjugated_dot_product:
             self_overlap = np.ones((num_freqs, num_modes))
         else:
-            self_overlap = np.abs(self.dot(self, self.monitor.conjugated_dot_product).values)
+            self_overlap = data_expanded.dot(data_expanded, self.monitor.conjugated_dot_product)
+            self_overlap = np.abs(self_overlap.values)
         threshold_array = overlap_thresh * self_overlap
 
         # Compute sorting order and overlaps with neighboring frequencies
```
```diff
@@ -1729,20 +1749,19 @@ def overlap_sort(
         # Sort in two directions from the base frequency
         for step, last_ind in zip([-1, 1], [-1, num_freqs]):
             # Start with the base frequency
-            data_template = self._isel(f=[f0_ind])
+            data_template = data_expanded._isel(f=[f0_ind])
 
             # March to lower/higher frequencies
             for freq_id in range(f0_ind + step, last_ind, step):
                 # Calculate threshold array for this frequency
-                if not self.monitor.conjugated_dot_product:
+                if not data_expanded.monitor.conjugated_dot_product:
                     overlap_thresh = threshold_array[freq_id, :]
                 # Get next frequency to sort
-                data_to_sort = self._isel(f=[freq_id])
+                data_to_sort = data_expanded._isel(f=[freq_id])
                 # Assign to the base frequency so that outer_dot will compare them
                 data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]])
 
                 # Compute "sorting w.r.t. to neighbor" and overlap values
-
                 sorting_one_mode, amps_one_mode = data_template._find_ordering_one_freq(
                     data_to_sort, overlap_thresh
                 )
```
```diff
@@ -1758,8 +1777,8 @@ def overlap_sort(
                 for mode_ind in list(np.nonzero(overlap[freq_id, :] < overlap_thresh)[0]):
                     log.warning(
                         f"Mode '{mode_ind}' appears to undergo a discontinuous change "
-                        f"between frequencies '{self.monitor.freqs[freq_id]}' "
-                        f"and '{self.monitor.freqs[freq_id - step]}' "
+                        f"between frequencies '{data_expanded.monitor.freqs[freq_id]}' "
+                        f"and '{data_expanded.monitor.freqs[freq_id - step]}' "
                         f"(overlap: '{overlap[freq_id, mode_ind]:.2f}')."
                     )
 
```
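
The three `overlap_sort` hunks above hoist the symmetry expansion into a single `data_expanded = self.symmetry_expanded` computed once before the frequency-marching loops, so it is not re-derived for every `_isel` call and dot product. A toy sketch of that hoisting pattern; `ToyData` and its methods are invented for the example and are not tidy3d classes:

```python
import numpy as np


class ToyData:
    """Toy container with a derived view that costs something to build."""

    def __init__(self, arr: np.ndarray):
        self.arr = arr

    @property
    def expanded(self) -> "ToyData":
        # Stand-in for symmetry expansion: per-call work proportional to the data size.
        return ToyData(np.concatenate([self.arr[::-1], self.arr], axis=0))

    def overlap(self, other: "ToyData", freq_index: int) -> float:
        return float(np.vdot(self.arr[:, freq_index], other.arr[:, freq_index]).real)


data = ToyData(np.random.default_rng(0).random((200, 50)))

# Before: the expansion is rebuilt on every loop iteration (twice per overlap here).
overlaps_slow = [data.expanded.overlap(data.expanded, i) for i in range(50)]

# After: hoist the expansion out of the loop and reuse it, as overlap_sort now does.
data_expanded = data.expanded
overlaps_fast = [data_expanded.overlap(data_expanded, i) for i in range(50)]

assert np.allclose(overlaps_slow, overlaps_fast)
```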

```diff
@@ -1798,7 +1817,7 @@ def _isel(self, **isel_kwargs: Any):
             for key, field in update_dict.items()
             if isinstance(field, DataArray)
         }
-        return self._updated(update=update_dict)
+        return self.updated_copy(**update_dict, deep=False, validate=False)
 
     def _assign_coords(self, **assign_coords_kwargs: Any):
         """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
```
```diff
@@ -1810,7 +1829,7 @@ def _assign_coords(self, **assign_coords_kwargs: Any):
         update_dict = {
             key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
         }
-        return self._updated(update=update_dict)
+        return self.updated_copy(**update_dict, deep=False, validate=False)
 
     def _find_ordering_one_freq(
         self,
```
```diff
@@ -2216,20 +2235,56 @@ def _apply_mode_reorder(self, sort_inds_2d):
             Array of shape (num_freqs, num_modes) where each row is the
             permutation to apply to the mode_index for that frequency.
         """
+        sort_inds_2d = np.asarray(sort_inds_2d, dtype=int)
         num_freqs, num_modes = sort_inds_2d.shape
+
+        # Fast no-op
+        identity = np.arange(num_modes)
+        if np.all(sort_inds_2d == identity[None, :]):
+            return self
+
         modify_data = {}
+        new_mode_index_coord = identity
+
         for key, data in self.data_arrs.items():
             if "mode_index" not in data.dims or "f" not in data.dims:
                 continue
-            dims_orig = data.dims
-            f_coord = data.coords["f"]
-            slices = []
-            for ifreq in range(num_freqs):
-                sl = data.isel(f=ifreq, mode_index=sort_inds_2d[ifreq])
-                slices.append(sl.assign_coords(mode_index=np.arange(num_modes)))
-            # Concatenate along the 'f' dimension name and then restore original frequency coordinates
-            data = xr.concat(slices, dim="f").assign_coords(f=f_coord).transpose(*dims_orig)
-            modify_data[key] = data
+
+            dims_orig = tuple(data.dims)
+            # Preserve coords (as numpy)
+            coords_out = {
+                k: (v.values if hasattr(v, "values") else np.asarray(v))
+                for k, v in data.coords.items()
+            }
+            f_axis = data.get_axis_num("f")
+            m_axis = data.get_axis_num("mode_index")
+
+            # Move axes directly to (f, ..., mode)
+            src_order = (
+                [f_axis] + [ax for ax in range(data.ndim) if ax not in (f_axis, m_axis)] + [m_axis]
+            )
+            arr = np.moveaxis(data.data, src_order, range(data.ndim))
+            nf, nm = arr.shape[0], arr.shape[-1]
+            if nf != num_freqs or nm != num_modes:
+                raise DataError(
+                    "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
+                )
+
+            # Apply sorting
+            arr2 = arr.reshape(nf, -1, nm)  # (nf, Nlead, nm)
+            inds = sort_inds_2d[:, None, :]  # (nf, 1, nm)
+            arr2_sorted = np.take_along_axis(arr2, inds, axis=2)
+            arr_sorted = arr2_sorted.reshape(arr.shape)
+
+            # Move axes back to original order
+            arr_sorted = np.moveaxis(arr_sorted, range(data.ndim), src_order)
+
+            # Update coords: keep f, reset mode_index to 0..num_modes-1
+            coords_out["mode_index"] = new_mode_index_coord
+            coords_out["f"] = data.coords["f"].values
+
+            modify_data[key] = DataArray(arr_sorted, coords=coords_out, dims=dims_orig)
+
         return self.updated_copy(**modify_data)
 
     def sort_modes(
```
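
The rewritten `_apply_mode_reorder` above replaces the per-frequency `isel`/`xr.concat` loop with a single vectorized `np.take_along_axis` call on an array reshaped to `(num_freqs, -1, num_modes)`. The sketch below, with made-up shapes, shows that a per-row index array applies a different mode permutation at each frequency and reproduces the explicit loop:

```python
import numpy as np

num_freqs, num_points, num_modes = 3, 5, 4
rng = np.random.default_rng(0)
arr = rng.random((num_freqs, num_points, num_modes))

# One permutation of the mode axis per frequency, shape (num_freqs, num_modes).
sort_inds_2d = np.stack([rng.permutation(num_modes) for _ in range(num_freqs)])

# Vectorized: the (num_freqs, 1, num_modes) index array broadcasts over num_points.
sorted_vec = np.take_along_axis(arr, sort_inds_2d[:, None, :], axis=2)

# Loop equivalent of what the old per-frequency isel/concat implementation expressed.
sorted_loop = np.stack([arr[f][:, sort_inds_2d[f]] for f in range(num_freqs)])

assert np.array_equal(sorted_vec, sorted_loop)
```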

tidy3d/components/mode/mode_solver.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1341,8 +1341,7 @@ def _colocate_data(self, mode_solver_data: ModeSolverData) -> ModeSolverData:
         mode_solver_monitor = self.to_mode_solver_monitor(name=MODE_MONITOR_NAME)
         grid_expanded = self.simulation.discretize_monitor(mode_solver_monitor)
         data_dict_colocated.update({"monitor": mode_solver_monitor, "grid_expanded": grid_expanded})
-        mode_solver_data = mode_solver_data._updated(update=data_dict_colocated)
-
+        mode_solver_data = mode_solver_data.updated_copy(**data_dict_colocated, deep=False)
         return mode_solver_data
 
     def _normalize_modes(self, mode_solver_data: ModeSolverData) -> None:
```