@@ -297,15 +297,26 @@ def autosave_function_restartable(
297297
298298 grp_l_bfgs = grp_restart_data .create_group ("l_bfgs" , track_order = True )
299299 grp_l_bfgs .attrs ["len" ] = len (l_bfgs_x_cache )
300+ if len (l_bfgs_x_cache ) > 0 :
301+ grp_l_bfgs .attrs ["len_elems" ] = len (l_bfgs_x_cache [0 ])
300302 for i , (x , g ) in enumerate (
301303 zip (l_bfgs_x_cache , l_bfgs_grad_cache , strict = True )
302304 ):
303- grp_l_bfgs .create_dataset (
304- f"x_{ i :d} " , data = x , compression = "gzip" , compression_opts = 6
305- )
306- grp_l_bfgs .create_dataset (
307- f"grad_{ i :d} " , data = g , compression = "gzip" , compression_opts = 6
308- )
305+ if len (x ) != len (g ) != grp_l_bfgs .attrs ["len_elems" ]:
306+ raise ValueError ("L-BFGS list lengths mismatch." )
307+ for j in range (grp_l_bfgs .attrs ["len_elems" ]):
308+ grp_l_bfgs .create_dataset (
309+ f"x_{ i :d} _{ j :d} " ,
310+ data = x [j ],
311+ compression = "gzip" ,
312+ compression_opts = 6 ,
313+ )
314+ grp_l_bfgs .create_dataset (
315+ f"grad_{ i :d} _{ j :d} " ,
316+ data = g [j ],
317+ compression = "gzip" ,
318+ compression_opts = 6 ,
319+ )
309320
310321
311322def _autosave_wrapper (
@@ -1330,11 +1341,17 @@ def restart_from_state_file(filename: PathLike):
13301341 restart_state ["bfgs_B_inv" ] = jnp .asarray (grp_restart_data ["bfgs_B_inv" ])
13311342 elif config .optimizer_method is Optimizing_Methods .L_BFGS :
13321343 restart_state ["l_bfgs_x_cache" ] = [
1333- jnp .asarray (grp_restart_data ["l_bfgs" ][f"x_{ i :d} " ])
1344+ [
1345+ jnp .asarray (grp_restart_data ["l_bfgs" ][f"x_{ i :d} _{ j :d} " ])
1346+ for j in range (grp_restart_data ["l_bfgs" ].attrs ["len_elems" ])
1347+ ]
13341348 for i in range (grp_restart_data ["l_bfgs" ].attrs ["len" ])
13351349 ]
13361350 restart_state ["l_bfgs_grad_cache" ] = [
1337- jnp .asarray (grp_restart_data ["l_bfgs" ][f"grad_{ i :d} " ])
1351+ [
1352+ jnp .asarray (grp_restart_data ["l_bfgs" ][f"grad_{ i :d} _{ j :d} " ])
1353+ for j in range (grp_restart_data ["l_bfgs" ].attrs ["len_elems" ])
1354+ ]
13381355 for i in range (grp_restart_data ["l_bfgs" ].attrs ["len" ])
13391356 ]
13401357
0 commit comments