Skip to content

Commit cb50ab5

Browse files
committed
Automatically save restartable state of optimizer and add feature to restart it from that state
1 parent fc659d0 commit cb50ab5

File tree

13 files changed

+1158
-72
lines changed

13 files changed

+1158
-72
lines changed

varipeps/config.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from enum import Enum, IntEnum, auto, unique
33

4+
import numpy as np
5+
46
from jax.tree_util import register_pytree_node_class
57

68
from typing import TypeVar, Tuple, Any, Type, NoReturn
@@ -209,13 +211,13 @@ class VariPEPS_Config:
209211
ctmrg_fail_if_not_converged: bool = True
210212
ctmrg_full_projector_method: Projector_Method = Projector_Method.FISHMAN
211213
ctmrg_increase_truncation_eps: bool = True
212-
ctmrg_increase_truncation_eps_factor: float = 100
214+
ctmrg_increase_truncation_eps_factor: float = 100.0
213215
ctmrg_increase_truncation_eps_max_value: float = 1e-6
214216
ctmrg_heuristic_increase_chi: bool = True
215217
ctmrg_heuristic_increase_chi_threshold: float = 1e-6
216-
ctmrg_heuristic_increase_chi_step_size: float = 2
218+
ctmrg_heuristic_increase_chi_step_size: int = 2
217219
ctmrg_heuristic_decrease_chi: bool = True
218-
ctmrg_heuristic_decrease_chi_step_size: float = 1
220+
ctmrg_heuristic_decrease_chi_step_size: int = 1
219221

220222
# SVD
221223
svd_sign_fix_eps: float = 1e-1
@@ -274,7 +276,37 @@ def __setattr__(self, name: str, value: Any) -> NoReturn:
274276

275277
if not type(value) is field.type:
276278
if field.type is float and type(value) is int:
277-
pass
279+
value = float(value)
280+
elif (
281+
field.type is float
282+
and hasattr(value, "dtype")
283+
and (
284+
np.issubdtype(value.dtype, np.floating)
285+
or np.issubdtype(value.dtype, np.integer)
286+
)
287+
and value.size == 1
288+
):
289+
if value.ndim > 0:
290+
value = value.reshape(-1)[0]
291+
value = float(value)
292+
elif (
293+
field.type is int
294+
and hasattr(value, "dtype")
295+
and np.issubdtype(value.dtype, np.integer)
296+
and value.size == 1
297+
):
298+
if value.ndim > 0:
299+
value = value.reshape(-1)[0]
300+
value = int(value)
301+
elif (
302+
field.type is bool
303+
and hasattr(value, "dtype")
304+
and np.isdtype(value.dtype, np.bool)
305+
and value.size == 1
306+
):
307+
if value.ndim > 0:
308+
value = value.reshape(-1)[0]
309+
value = bool(value)
278310
else:
279311
raise TypeError(
280312
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."

varipeps/expectation/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
33

4+
import h5py
5+
46
import jax.numpy as jnp
57

68
from varipeps.peps import PEPS_Unit_Cell
@@ -54,3 +56,12 @@ def __call__(
5456
is applied.
5557
"""
5658
pass
59+
60+
@abstractmethod
61+
def save_to_group(self, grp: h5py.Group):
62+
pass
63+
64+
@classmethod
65+
@abstractmethod
66+
def load_from_group(cls, grp: h5py.Group):
67+
pass

varipeps/expectation/one_site.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass
22

3+
import h5py
4+
35
import jax.numpy as jnp
46
from jax import jit
57

@@ -171,3 +173,28 @@ def __call__(
171173
return result[0]
172174
else:
173175
return result
176+
177+
def save_to_group(self, grp: h5py.Group):
178+
cls = type(self)
179+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
180+
181+
grp_gates = grp.create_group("gates", track_order=True)
182+
grp_gates.attrs["len"] = len(self.gates)
183+
for i, g in enumerate(self.gates):
184+
grp_gates.create_dataset(
185+
f"gate_{i:d}", data=g, compression="gzip", compression_opts=6
186+
)
187+
188+
@classmethod
189+
def load_from_group(cls, grp: h5py.Group):
190+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
191+
raise ValueError(
192+
"The HDF5 group suggests that this is not the right class to load data from it."
193+
)
194+
195+
gates = tuple(
196+
jnp.asarray(grp["gates"][f"gate_{i:d}"])
197+
for i in range(grp["gates"].attrs["len"])
198+
)
199+
200+
return cls(gates=gates)

varipeps/expectation/two_sites.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from functools import partial
33

4+
import h5py
5+
46
import numpy as np
57

68
import jax.numpy as jnp
@@ -699,3 +701,62 @@ def __call__(
699701
return result[0]
700702
else:
701703
return result
704+
705+
def save_to_group(self, grp: h5py.Group):
706+
cls = type(self)
707+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
708+
709+
grp_gates = grp.create_group("gates", track_order=True)
710+
grp_gates.attrs["len"] = len(self.gates)
711+
for i, (h_g, v_g) in enumerate(
712+
zip(self.horizontal_gates, self.vertical_gates, strict=True)
713+
):
714+
grp_gates.create_dataset(
715+
f"horizontal_gate_{i:d}",
716+
data=h_g,
717+
compression="gzip",
718+
compression_opts=6,
719+
)
720+
grp_gates.create_dataset(
721+
f"vertical_gate_{i:d}", data=v_g, compression="gzip", compression_opts=6
722+
)
723+
724+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
725+
726+
if self.is_spiral_peps:
727+
grp.create_dataset(
728+
"spiral_unitary_operator",
729+
data=self.spiral_unitary_operator,
730+
compression="gzip",
731+
compression_opts=6,
732+
)
733+
734+
@classmethod
735+
def load_from_group(cls, grp: h5py.Group):
736+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
737+
raise ValueError(
738+
"The HDF5 group suggests that this is not the right class to load data from it."
739+
)
740+
741+
horizontal_gates = tuple(
742+
jnp.asarray(grp["gates"][f"horizontal_gate_{i:d}"])
743+
for i in range(grp["gates"].attrs["len"])
744+
)
745+
vertical_gates = tuple(
746+
jnp.asarray(grp["gates"][f"vertical_gate_{i:d}"])
747+
for i in range(grp["gates"].attrs["len"])
748+
)
749+
750+
is_spiral_peps = grp.attrs["is_spiral_peps"]
751+
752+
if is_spiral_peps:
753+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
754+
else:
755+
spiral_unitary_operator = None
756+
757+
return cls(
758+
horizontal_gates=horizontal_gates,
759+
vertical_gates=vertical_gates,
760+
is_spiral_peps=is_spiral_peps,
761+
spiral_unitary_operator=spiral_unitary_operator,
762+
)

varipeps/mapping/florett_pentagon.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,3 +1249,74 @@ def __call__(
12491249
return result, single_gates_result
12501250
else:
12511251
return result
1252+
1253+
def save_to_group(self, grp: h5py.Group):
1254+
cls = type(self)
1255+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
1256+
1257+
grp_gates = grp.create_group("gates", track_order=True)
1258+
grp_gates.attrs["len"] = len(self.green_gates)
1259+
for i, (g_g, b_g, bl_g) in enumerate(
1260+
zip(self.green_gates, self.blue_gates, self.black_gates, strict=True)
1261+
):
1262+
grp_gates.create_dataset(
1263+
f"green_gate_{i:d}",
1264+
data=g_g,
1265+
compression="gzip",
1266+
compression_opts=6,
1267+
)
1268+
grp_gates.create_dataset(
1269+
f"blue_gate_{i:d}", data=b_g, compression="gzip", compression_opts=6
1270+
)
1271+
grp_gates.create_dataset(
1272+
f"black_gate_{i:d}", data=bl_g, compression="gzip", compression_opts=6
1273+
)
1274+
1275+
grp.attrs["real_d"] = self.real_d
1276+
grp.attrs["normalization_factor"] = self.normalization_factor
1277+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
1278+
1279+
if self.is_spiral_peps:
1280+
grp.create_dataset(
1281+
"spiral_unitary_operator",
1282+
data=self.spiral_unitary_operator,
1283+
compression="gzip",
1284+
compression_opts=6,
1285+
)
1286+
1287+
@classmethod
1288+
def load_from_group(cls, grp: h5py.Group):
1289+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
1290+
raise ValueError(
1291+
"The HDF5 group suggests that this is not the right class to load data from it."
1292+
)
1293+
1294+
green_gates = tuple(
1295+
jnp.asarray(grp["gates"][f"green_gate_{i:d}"])
1296+
for i in range(grp["gates"].attrs["len"])
1297+
)
1298+
blue_gates = tuple(
1299+
jnp.asarray(grp["gates"][f"blue_gate_{i:d}"])
1300+
for i in range(grp["gates"].attrs["len"])
1301+
)
1302+
black_gates = tuple(
1303+
jnp.asarray(grp["gates"][f"black_gate_{i:d}"])
1304+
for i in range(grp["gates"].attrs["len"])
1305+
)
1306+
1307+
is_spiral_peps = grp.attrs["is_spiral_peps"]
1308+
1309+
if is_spiral_peps:
1310+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
1311+
else:
1312+
spiral_unitary_operator = None
1313+
1314+
return cls(
1315+
green_gates=green_gates,
1316+
blue_gates=blue_gates,
1317+
black_gates=black_gates,
1318+
real_d=grp.attrs["real_d"],
1319+
normalization_factor=grp.attrs["normalization_factor"],
1320+
is_spiral_peps=is_spiral_peps,
1321+
spiral_unitary_operator=spiral_unitary_operator,
1322+
)

varipeps/mapping/honeycomb.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,77 @@ def __call__(
304304
else:
305305
return result
306306

307+
def save_to_group(self, grp: h5py.Group):
308+
cls = type(self)
309+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
310+
311+
grp_gates = grp.create_group("gates", track_order=True)
312+
grp_gates.attrs["len"] = len(self.x_gates)
313+
for i, (x_g, y_g, z_g) in enumerate(
314+
zip(self.x_gates, self.y_gates, self.z_gates, strict=True)
315+
):
316+
grp_gates.create_dataset(
317+
f"x_gate_{i:d}",
318+
data=x_g,
319+
compression="gzip",
320+
compression_opts=6,
321+
)
322+
grp_gates.create_dataset(
323+
f"y_gate_{i:d}", data=y_g, compression="gzip", compression_opts=6
324+
)
325+
grp_gates.create_dataset(
326+
f"z_gate_{i:d}", data=z_g, compression="gzip", compression_opts=6
327+
)
328+
329+
grp.attrs["real_d"] = self.real_d
330+
grp.attrs["normalization_factor"] = self.normalization_factor
331+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
332+
333+
if self.is_spiral_peps:
334+
grp.create_dataset(
335+
"spiral_unitary_operator",
336+
data=self.spiral_unitary_operator,
337+
compression="gzip",
338+
compression_opts=6,
339+
)
340+
341+
@classmethod
342+
def load_from_group(cls, grp: h5py.Group):
343+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
344+
raise ValueError(
345+
"The HDF5 group suggests that this is not the right class to load data from it."
346+
)
347+
348+
x_gates = tuple(
349+
jnp.asarray(grp["gates"][f"x_gate_{i:d}"])
350+
for i in range(grp["gates"].attrs["len"])
351+
)
352+
y_gates = tuple(
353+
jnp.asarray(grp["gates"][f"y_gate_{i:d}"])
354+
for i in range(grp["gates"].attrs["len"])
355+
)
356+
z_gates = tuple(
357+
jnp.asarray(grp["gates"][f"z_gate_{i:d}"])
358+
for i in range(grp["gates"].attrs["len"])
359+
)
360+
361+
is_spiral_peps = grp.attrs["is_spiral_peps"]
362+
363+
if is_spiral_peps:
364+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
365+
else:
366+
spiral_unitary_operator = None
367+
368+
return cls(
369+
x_gates=x_gates,
370+
y_gates=y_gates,
371+
z_gates=z_gates,
372+
real_d=grp.attrs["real_d"],
373+
normalization_factor=grp.attrs["normalization_factor"],
374+
is_spiral_peps=is_spiral_peps,
375+
spiral_unitary_operator=spiral_unitary_operator,
376+
)
377+
307378

308379
@dataclass
309380
class Honeycomb_Map_To_Square(Map_To_PEPS_Model):

0 commit comments

Comments
 (0)