Skip to content

Commit 5ba903c

Browse files
committed
Fix storage of L-BFGS data in state file
1 parent fa0e1d0 commit 5ba903c

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

varipeps/optimization/optimizer.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,15 +297,26 @@ def autosave_function_restartable(
297297

298298
grp_l_bfgs = grp_restart_data.create_group("l_bfgs", track_order=True)
299299
grp_l_bfgs.attrs["len"] = len(l_bfgs_x_cache)
300+
if len(l_bfgs_x_cache) > 0:
301+
grp_l_bfgs.attrs["len_elems"] = len(l_bfgs_x_cache[0])
300302
for i, (x, g) in enumerate(
301303
zip(l_bfgs_x_cache, l_bfgs_grad_cache, strict=True)
302304
):
303-
grp_l_bfgs.create_dataset(
304-
f"x_{i:d}", data=x, compression="gzip", compression_opts=6
305-
)
306-
grp_l_bfgs.create_dataset(
307-
f"grad_{i:d}", data=g, compression="gzip", compression_opts=6
308-
)
305+
if len(x) != len(g) != grp_l_bfgs.attrs["len_elems"]:
306+
raise ValueError("L-BFGS list lengths mismatch.")
307+
for j in range(grp_l_bfgs.attrs["len_elems"]):
308+
grp_l_bfgs.create_dataset(
309+
f"x_{i:d}_{j:d}",
310+
data=x[j],
311+
compression="gzip",
312+
compression_opts=6,
313+
)
314+
grp_l_bfgs.create_dataset(
315+
f"grad_{i:d}_{j:d}",
316+
data=g[j],
317+
compression="gzip",
318+
compression_opts=6,
319+
)
309320

310321

311322
def _autosave_wrapper(
@@ -1330,11 +1341,17 @@ def restart_from_state_file(filename: PathLike):
13301341
restart_state["bfgs_B_inv"] = jnp.asarray(grp_restart_data["bfgs_B_inv"])
13311342
elif config.optimizer_method is Optimizing_Methods.L_BFGS:
13321343
restart_state["l_bfgs_x_cache"] = [
1333-
jnp.asarray(grp_restart_data["l_bfgs"][f"x_{i:d}"])
1344+
[
1345+
jnp.asarray(grp_restart_data["l_bfgs"][f"x_{i:d}_{j:d}"])
1346+
for j in range(grp_restart_data["l_bfgs"].attrs["len_elems"])
1347+
]
13341348
for i in range(grp_restart_data["l_bfgs"].attrs["len"])
13351349
]
13361350
restart_state["l_bfgs_grad_cache"] = [
1337-
jnp.asarray(grp_restart_data["l_bfgs"][f"grad_{i:d}"])
1351+
[
1352+
jnp.asarray(grp_restart_data["l_bfgs"][f"grad_{i:d}_{j:d}"])
1353+
for j in range(grp_restart_data["l_bfgs"].attrs["len_elems"])
1354+
]
13381355
for i in range(grp_restart_data["l_bfgs"].attrs["len"])
13391356
]
13401357

0 commit comments

Comments
 (0)