@@ -248,21 +248,28 @@ def autosave_function_restartable(
248248 if convert_to_unitcell_func is not None :
249249 pass
250250
251- grp_old_grad = grp_restart_data .create_group ("old_gradient" , track_order = True )
252- grp_old_grad .attrs ["len" ] = len (old_gradient )
253- for i , g in enumerate (old_gradient ):
254- grp_old_grad .create_dataset (
255- f"old_grad_{ i :d} " , data = g , compression = "gzip" , compression_opts = 6
251+ if old_gradient is not None :
252+ grp_old_grad = grp_restart_data .create_group (
253+ "old_gradient" , track_order = True
256254 )
255+ grp_old_grad .attrs ["len" ] = len (old_gradient )
256+ for i , g in enumerate (old_gradient ):
257+ grp_old_grad .create_dataset (
258+ f"old_grad_{ i :d} " , data = g , compression = "gzip" , compression_opts = 6
259+ )
257260
258- grp_old_des_dir = grp_restart_data .create_group (
259- "old_descent_dir" , track_order = True
260- )
261- grp_old_des_dir .attrs ["len" ] = len (old_descent_dir )
262- for i , d in enumerate (old_descent_dir ):
263- grp_old_des_dir .create_dataset (
264- f"old_descent_dir_{ i :d} " , data = d , compression = "gzip" , compression_opts = 6
261+ if old_descent_dir is not None :
262+ grp_old_des_dir = grp_restart_data .create_group (
263+ "old_descent_dir" , track_order = True
265264 )
265+ grp_old_des_dir .attrs ["len" ] = len (old_descent_dir )
266+ for i , d in enumerate (old_descent_dir ):
267+ grp_old_des_dir .create_dataset (
268+ f"old_descent_dir_{ i :d} " ,
269+ data = d ,
270+ compression = "gzip" ,
271+ compression_opts = 6 ,
272+ )
266273
267274 if best_unitcell is not None :
268275 grp_best_t = grp_restart_data .create_group ("best_tensors" , track_order = True )
@@ -1281,14 +1288,23 @@ def restart_from_state_file(filename: PathLike):
12811288
12821289 restart_state = {}
12831290
1284- restart_state ["old_gradient" ] = [
1285- jnp .asarray (grp_restart_data ["old_gradient" ][f"old_grad_{ i :d} " ])
1286- for i in range (grp_restart_data ["old_gradient" ].attrs ["len" ])
1287- ]
1288- restart_state ["old_descent_dir" ] = [
1289- jnp .asarray (grp_restart_data ["old_descent_dir" ][f"old_descent_dir_{ i :d} " ])
1290- for i in range (grp_restart_data ["old_descent_dir" ].attrs ["len" ])
1291- ]
1291+ if grp_restart_data .get ("old_gradient" ) is not None :
1292+ restart_state ["old_gradient" ] = [
1293+ jnp .asarray (grp_restart_data ["old_gradient" ][f"old_grad_{ i :d} " ])
1294+ for i in range (grp_restart_data ["old_gradient" ].attrs ["len" ])
1295+ ]
1296+ else :
1297+ restart_state ["old_gradient" ] = None
1298+
1299+ if grp_restart_data .get ("old_descent_dir" ) is not None :
1300+ restart_state ["old_descent_dir" ] = [
1301+ jnp .asarray (
1302+ grp_restart_data ["old_descent_dir" ][f"old_descent_dir_{ i :d} " ]
1303+ )
1304+ for i in range (grp_restart_data ["old_descent_dir" ].attrs ["len" ])
1305+ ]
1306+ else :
1307+ restart_state ["old_descent_dir" ] = None
12921308
12931309 restart_state ["best_run" ] = auxiliary_data ["best_run" ]
12941310
0 commit comments