@@ -151,20 +151,6 @@ def amplitude_fn(freq: list[float]) -> complex:
151151
152152 return self .normalize (amplitude_fn )
153153
154- def _updated (self , update : dict ) -> MonitorData :
155- """Similar to ``updated_copy``, but does not actually copy components, for speed.
156-
157- Note
158- ----
159- This does **not** produce a copy of mutable objects, so e.g. if some of the data arrays
160- are not updated, they will point to the values in the original data. This method should
161- thus be used carefully.
162-
163- """
164- data_dict = self .dict ()
165- data_dict .update (update )
166- return type (self ).parse_obj (data_dict )
167-
168154 def _make_adjoint_sources (self , dataset_names : list [str ], fwidth : float ) -> list [Source ]:
169155 """Generate adjoint sources for this ``MonitorData`` instance."""
170156
@@ -263,7 +249,7 @@ def symmetry_expanded(self):
263249 if all (sym == 0 for sym in self .symmetry ):
264250 return self
265251
266- return self ._updated ( self ._symmetry_update_dict )
252+ return self .updated_copy ( ** self ._symmetry_update_dict , deep = False , validate = False )
267253
268254 @property
269255 def symmetry_expanded_copy (self ) -> AbstractFieldData :
@@ -782,23 +768,51 @@ def dot(
782768 fields_self = {key : field .conj () for key , field in fields_self .items ()}
783769
784770 fields_other = field_data ._interpolated_tangential_fields (self ._plane_grid_boundaries )
771+ dim1 , dim2 = self ._tangential_dims
772+ d_area = self ._diff_area
785773
786- # Drop size-1 dimensions in the other data
787- fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
774+ # After interpolation, the tangential coordinates should match. However, the two arrays
775+ # may either have the same shape along other dimensions, or be broadcastable.
776+ if (
777+ fields_self [next (iter (fields_self ))].shape
778+ == fields_other [next (iter (fields_other ))].shape
779+ ):
780+ # Arrays are same shape, so we can use numpy
781+ e_self_x_h_other = fields_self ["E" + dim1 ].values * fields_other ["H" + dim2 ].values
782+ e_self_x_h_other -= fields_self ["E" + dim2 ].values * fields_other ["H" + dim1 ].values
783+ h_self_x_e_other = fields_self ["H" + dim1 ].values * fields_other ["E" + dim2 ].values
784+ h_self_x_e_other -= fields_self ["H" + dim2 ].values * fields_other ["E" + dim1 ].values
785+ integrand = xr .DataArray (
786+ e_self_x_h_other - h_self_x_e_other , coords = fields_self ["E" + dim1 ].coords
787+ )
788+ integrand *= d_area
789+ else :
790+ # Broadcasting is needed, which may be complicated depending on the dimensions order.
791+ # Use xarray to handle robustly.
788792
789- # Cross products of fields
790- dim1 , dim2 = self ._tangential_dims
791- e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
792- e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
793- h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
794- h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
793+ # Drop size-1 dimensions in the other data
794+ fields_other = {key : field .squeeze (drop = True ) for key , field in fields_other .items ()}
795795
796- # Integrate over plane
797- d_area = self ._diff_area
798- integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
796+ # Cross products of fields
797+ e_self_x_h_other = fields_self ["E" + dim1 ] * fields_other ["H" + dim2 ]
798+ e_self_x_h_other -= fields_self ["E" + dim2 ] * fields_other ["H" + dim1 ]
799+ h_self_x_e_other = fields_self ["H" + dim1 ] * fields_other ["E" + dim2 ]
800+ h_self_x_e_other -= fields_self ["H" + dim2 ] * fields_other ["E" + dim1 ]
801+ integrand = (e_self_x_h_other - h_self_x_e_other ) * d_area
799802
803+ # Integrate over plane
800804 return ModeAmpsDataArray (0.25 * integrand .sum (dim = d_area .dims ))
801805
806+ def _tangential_fields_match_coords (self , coords : ArrayFloat2D ) -> bool :
807+ """Check if the tangential fields already match given coords in the tangential plane."""
808+ for field in self ._tangential_fields .values ():
809+ for idim , dim in enumerate (self ._tangential_dims ):
810+ if field .coords [dim ].values .size != coords [idim ].size or not np .all (
811+ field .coords [dim ].values == coords [idim ]
812+ ):
813+ return False
814+ return True
815+
802816 def _interpolated_tangential_fields (self , coords : ArrayFloat2D ) -> dict [str , DataArray ]:
803817 """For 2D monitors, interpolate this fields to given coords in the tangential plane.
804818
@@ -813,6 +827,10 @@ def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, Dat
813827 """
814828 fields = self ._tangential_fields
815829
830+ # If coords already match, just return the tangential fields directly.
831+ if self ._tangential_fields_match_coords (coords ):
832+ return fields
833+
816834 # Interpolate if data has more than one coordinate along a dimension
817835 interp_dict = {"assume_sorted" : True }
818836 # If single coordinate, just sel "nearest", i.e. just propagate the same data everywhere
@@ -1713,10 +1731,12 @@ def overlap_sort(
17131731
17141732 # Normalizing the flux to 1, does not guarantee self terms of overlap integrals
17151733 # are also normalized to 1 when the non-conjugated product is used.
1716- if self .monitor .conjugated_dot_product :
1734+ data_expanded = self .symmetry_expanded
1735+ if data_expanded .monitor .conjugated_dot_product :
17171736 self_overlap = np .ones ((num_freqs , num_modes ))
17181737 else :
1719- self_overlap = np .abs (self .dot (self , self .monitor .conjugated_dot_product ).values )
1738+ self_overlap = data_expanded .dot (data_expanded , self .monitor .conjugated_dot_product )
1739+ self_overlap = np .abs (self_overlap .values )
17201740 threshold_array = overlap_thresh * self_overlap
17211741
17221742 # Compute sorting order and overlaps with neighboring frequencies
@@ -1729,20 +1749,19 @@ def overlap_sort(
17291749 # Sort in two directions from the base frequency
17301750 for step , last_ind in zip ([- 1 , 1 ], [- 1 , num_freqs ]):
17311751 # Start with the base frequency
1732- data_template = self ._isel (f = [f0_ind ])
1752+ data_template = data_expanded ._isel (f = [f0_ind ])
17331753
17341754 # March to lower/higher frequencies
17351755 for freq_id in range (f0_ind + step , last_ind , step ):
17361756 # Calculate threshold array for this frequency
1737- if not self .monitor .conjugated_dot_product :
1757+ if not data_expanded .monitor .conjugated_dot_product :
17381758 overlap_thresh = threshold_array [freq_id , :]
17391759 # Get next frequency to sort
1740- data_to_sort = self ._isel (f = [freq_id ])
1760+ data_to_sort = data_expanded ._isel (f = [freq_id ])
17411761 # Assign to the base frequency so that outer_dot will compare them
17421762 data_to_sort = data_to_sort ._assign_coords (f = [self .monitor .freqs [f0_ind ]])
17431763
17441764 # Compute "sorting w.r.t. to neighbor" and overlap values
1745-
17461765 sorting_one_mode , amps_one_mode = data_template ._find_ordering_one_freq (
17471766 data_to_sort , overlap_thresh
17481767 )
@@ -1758,8 +1777,8 @@ def overlap_sort(
17581777 for mode_ind in list (np .nonzero (overlap [freq_id , :] < overlap_thresh )[0 ]):
17591778 log .warning (
17601779 f"Mode '{ mode_ind } ' appears to undergo a discontinuous change "
1761- f"between frequencies '{ self .monitor .freqs [freq_id ]} ' "
1762- f"and '{ self .monitor .freqs [freq_id - step ]} ' "
1780+ f"between frequencies '{ data_expanded .monitor .freqs [freq_id ]} ' "
1781+ f"and '{ data_expanded .monitor .freqs [freq_id - step ]} ' "
17631782 f"(overlap: '{ overlap [freq_id , mode_ind ]:.2f} ')."
17641783 )
17651784
@@ -1798,7 +1817,7 @@ def _isel(self, **isel_kwargs: Any):
17981817 for key , field in update_dict .items ()
17991818 if isinstance (field , DataArray )
18001819 }
1801- return self ._updated ( update = update_dict )
1820+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18021821
18031822 def _assign_coords (self , ** assign_coords_kwargs : Any ):
18041823 """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
@@ -1810,7 +1829,7 @@ def _assign_coords(self, **assign_coords_kwargs: Any):
18101829 update_dict = {
18111830 key : field .assign_coords (** assign_coords_kwargs ) for key , field in update_dict .items ()
18121831 }
1813- return self ._updated ( update = update_dict )
1832+ return self .updated_copy ( ** update_dict , deep = False , validate = False )
18141833
18151834 def _find_ordering_one_freq (
18161835 self ,
@@ -2216,20 +2235,56 @@ def _apply_mode_reorder(self, sort_inds_2d):
22162235 Array of shape (num_freqs, num_modes) where each row is the
22172236 permutation to apply to the mode_index for that frequency.
22182237 """
2238+ sort_inds_2d = np .asarray (sort_inds_2d , dtype = int )
22192239 num_freqs , num_modes = sort_inds_2d .shape
2240+
2241+ # Fast no-op
2242+ identity = np .arange (num_modes )
2243+ if np .all (sort_inds_2d == identity [None , :]):
2244+ return self
2245+
22202246 modify_data = {}
2247+ new_mode_index_coord = identity
2248+
22212249 for key , data in self .data_arrs .items ():
22222250 if "mode_index" not in data .dims or "f" not in data .dims :
22232251 continue
2224- dims_orig = data .dims
2225- f_coord = data .coords ["f" ]
2226- slices = []
2227- for ifreq in range (num_freqs ):
2228- sl = data .isel (f = ifreq , mode_index = sort_inds_2d [ifreq ])
2229- slices .append (sl .assign_coords (mode_index = np .arange (num_modes )))
2230- # Concatenate along the 'f' dimension name and then restore original frequency coordinates
2231- data = xr .concat (slices , dim = "f" ).assign_coords (f = f_coord ).transpose (* dims_orig )
2232- modify_data [key ] = data
2252+
2253+ dims_orig = tuple (data .dims )
2254+ # Preserve coords (as numpy)
2255+ coords_out = {
2256+ k : (v .values if hasattr (v , "values" ) else np .asarray (v ))
2257+ for k , v in data .coords .items ()
2258+ }
2259+ f_axis = data .get_axis_num ("f" )
2260+ m_axis = data .get_axis_num ("mode_index" )
2261+
2262+ # Move axes directly to (f, ..., mode)
2263+ src_order = (
2264+ [f_axis ] + [ax for ax in range (data .ndim ) if ax not in (f_axis , m_axis )] + [m_axis ]
2265+ )
2266+ arr = np .moveaxis (data .data , src_order , range (data .ndim ))
2267+ nf , nm = arr .shape [0 ], arr .shape [- 1 ]
2268+ if nf != num_freqs or nm != num_modes :
2269+ raise DataError (
2270+ "sort_inds_2d shape does not match array shape in _apply_mode_reorder."
2271+ )
2272+
2273+ # Apply sorting
2274+ arr2 = arr .reshape (nf , - 1 , nm ) # (nf, Nlead, nm)
2275+ inds = sort_inds_2d [:, None , :] # (nf, 1, nm)
2276+ arr2_sorted = np .take_along_axis (arr2 , inds , axis = 2 )
2277+ arr_sorted = arr2_sorted .reshape (arr .shape )
2278+
2279+ # Move axes back to original order
2280+ arr_sorted = np .moveaxis (arr_sorted , range (data .ndim ), src_order )
2281+
2282+ # Update coords: keep f, reset mode_index to 0..num_modes-1
2283+ coords_out ["mode_index" ] = new_mode_index_coord
2284+ coords_out ["f" ] = data .coords ["f" ].values
2285+
2286+ modify_data [key ] = DataArray (arr_sorted , coords = coords_out , dims = dims_orig )
2287+
22332288 return self .updated_copy (** modify_data )
22342289
22352290 def sort_modes (
0 commit comments