Skip to content

Commit 7489852

Browse files
committed
Fix storage and loading of the old gradient and descent direction from the state file when they are None
1 parent 5ba903c commit 7489852

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

varipeps/optimization/optimizer.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -248,21 +248,28 @@ def autosave_function_restartable(
248248
if convert_to_unitcell_func is not None:
249249
pass
250250

251-
grp_old_grad = grp_restart_data.create_group("old_gradient", track_order=True)
252-
grp_old_grad.attrs["len"] = len(old_gradient)
253-
for i, g in enumerate(old_gradient):
254-
grp_old_grad.create_dataset(
255-
f"old_grad_{i:d}", data=g, compression="gzip", compression_opts=6
251+
if old_gradient is not None:
252+
grp_old_grad = grp_restart_data.create_group(
253+
"old_gradient", track_order=True
256254
)
255+
grp_old_grad.attrs["len"] = len(old_gradient)
256+
for i, g in enumerate(old_gradient):
257+
grp_old_grad.create_dataset(
258+
f"old_grad_{i:d}", data=g, compression="gzip", compression_opts=6
259+
)
257260

258-
grp_old_des_dir = grp_restart_data.create_group(
259-
"old_descent_dir", track_order=True
260-
)
261-
grp_old_des_dir.attrs["len"] = len(old_descent_dir)
262-
for i, d in enumerate(old_descent_dir):
263-
grp_old_des_dir.create_dataset(
264-
f"old_descent_dir_{i:d}", data=d, compression="gzip", compression_opts=6
261+
if old_descent_dir is not None:
262+
grp_old_des_dir = grp_restart_data.create_group(
263+
"old_descent_dir", track_order=True
265264
)
265+
grp_old_des_dir.attrs["len"] = len(old_descent_dir)
266+
for i, d in enumerate(old_descent_dir):
267+
grp_old_des_dir.create_dataset(
268+
f"old_descent_dir_{i:d}",
269+
data=d,
270+
compression="gzip",
271+
compression_opts=6,
272+
)
266273

267274
if best_unitcell is not None:
268275
grp_best_t = grp_restart_data.create_group("best_tensors", track_order=True)
@@ -1281,14 +1288,23 @@ def restart_from_state_file(filename: PathLike):
12811288

12821289
restart_state = {}
12831290

1284-
restart_state["old_gradient"] = [
1285-
jnp.asarray(grp_restart_data["old_gradient"][f"old_grad_{i:d}"])
1286-
for i in range(grp_restart_data["old_gradient"].attrs["len"])
1287-
]
1288-
restart_state["old_descent_dir"] = [
1289-
jnp.asarray(grp_restart_data["old_descent_dir"][f"old_descent_dir_{i:d}"])
1290-
for i in range(grp_restart_data["old_descent_dir"].attrs["len"])
1291-
]
1291+
if grp_restart_data.get("old_gradient") is not None:
1292+
restart_state["old_gradient"] = [
1293+
jnp.asarray(grp_restart_data["old_gradient"][f"old_grad_{i:d}"])
1294+
for i in range(grp_restart_data["old_gradient"].attrs["len"])
1295+
]
1296+
else:
1297+
restart_state["old_gradient"] = None
1298+
1299+
if grp_restart_data.get("old_descent_dir") is not None:
1300+
restart_state["old_descent_dir"] = [
1301+
jnp.asarray(
1302+
grp_restart_data["old_descent_dir"][f"old_descent_dir_{i:d}"]
1303+
)
1304+
for i in range(grp_restart_data["old_descent_dir"].attrs["len"])
1305+
]
1306+
else:
1307+
restart_state["old_descent_dir"] = None
12921308

12931309
restart_state["best_run"] = auxiliary_data["best_run"]
12941310

0 commit comments

Comments (0)