Skip to content

Commit 8d7a255

Browse files
WIP
1 parent 9f3c9a1 commit 8d7a255

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

tests/test_components/autograd/numerical/test_autograd_polyslab_trianglemesh_numerical.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
dimension_permutation,
1717
finite_difference,
1818
make_base_simulation,
19+
run_parameter_simulations,
1920
squeeze_dimension,
2021
)
2122
from tidy3d import config
@@ -113,34 +114,34 @@ def make_polyslab_geometry(params, box_center, axis: int) -> td.PolySlab:
113114
return td.PolySlab(vertices=tuple(vertices), slab_bounds=slab_bounds, axis=axis)
114115

115116

116-
def run_parameter_simulations(
117+
def run_parameter_simulations2(
117118
parameter_sets: list[anp.ndarray],
118119
make_geometry,
119120
box_center,
120-
tag: str,
121121
base_sim: td.Simulation,
122122
fom,
123123
tmp_path,
124124
*,
125125
local_gradient: bool,
126126
):
127-
simulation_dict = {}
127+
simulations = []
128128

129-
for idx, param_values in enumerate(parameter_sets):
129+
for param_values in parameter_sets:
130130
geometry = make_geometry(param_values, box_center)
131131
structure = td.Structure(
132132
geometry=geometry,
133133
medium=td.Medium(permittivity=PERMITTIVITY),
134134
)
135135
sim = base_sim.updated_copy(structures=[structure], validate=False)
136-
simulation_dict[f"sim_{idx}"] = sim
136+
simulations.append(sim)
137137

138-
sim_data_map = web.run(
139-
simulation_dict,
138+
sim_data = web.run(
139+
simulations,
140140
local_gradient=local_gradient,
141+
path=tmp_path,
141142
verbose=VERBOSE,
142143
)
143-
sim_fom = [fom(sim_data_map[key]) for key in simulation_dict]
144+
sim_fom = [fom(data) for data in sim_data]
144145
if len(sim_fom) == 1:
145146
sim_fom = sim_fom[0]
146147
return sim_fom

tidy3d/components/data/data_array.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,28 @@ class DataArray(xr.DataArray):
7878
_data_attrs: dict[str, str] = {}
7979

8080
def __init__(self, data, *args: Any, **kwargs: Any) -> None:
81+
# convert numpy object arrays that contain autograd boxes; keep other types as-is
82+
data = self._maybe_convert_object_boxes(data)
83+
8184
# if data is a vanilla autograd box, convert to our box
8285
if isbox(data) and not is_tidy_box(data):
8386
data = TidyArrayBox.from_arraybox(data)
8487
# do the same for xr.Variable or xr.DataArray type
85-
elif (
86-
isinstance(data, (xr.Variable, xr.DataArray))
87-
and isbox(data.data)
88-
and not is_tidy_box(data.data)
89-
):
90-
data.data = TidyArrayBox.from_arraybox(data.data)
88+
elif isinstance(data, (xr.Variable, xr.DataArray)):
89+
if isbox(data.data) and not is_tidy_box(data.data):
90+
data.data = TidyArrayBox.from_arraybox(data.data)
9191
super().__init__(data, *args, **kwargs)
9292

93+
@staticmethod
94+
def _maybe_convert_object_boxes(data):
95+
"""Convert object arrays of autograd boxes into ArrayBox instances."""
96+
97+
if isinstance(data, np.ndarray) and data.dtype == np.object_ and data.size:
98+
# only convert if at least one element is an autograd tracer
99+
if any(isbox(item) for item in data.flat):
100+
return anp.array(data.tolist())
101+
return data
102+
93103
@classmethod
94104
def __get_validators__(cls):
95105
"""Validators that get run when :class:`.DataArray` objects are added to pydantic models."""

0 commit comments

Comments
 (0)