Skip to content

Commit 4de3fd5

Browse files
committed
Implement storage funcs for triangular CTMRG expectation functions
1 parent 35c2b0c commit 4de3fd5

File tree

4 files changed

+346
-0
lines changed

4 files changed

+346
-0
lines changed

varipeps/expectation/triangular_next_nearest.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from functools import partial
33

4+
import h5py
5+
46
import jax
57
import jax.numpy as jnp
68
from jax import jit
@@ -568,3 +570,128 @@ def __call__(
568570
return result[0]
569571
else:
570572
return result
573+
574+
def save_to_group(self, grp: h5py.Group):
575+
cls = type(self)
576+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
577+
578+
grp_gates = grp.create_group("gates", track_order=True)
579+
grp_gates.attrs["len"] = len(self.nearest_horizontal_gates)
580+
for i, (
581+
h_g,
582+
v_g,
583+
d_g,
584+
nn_neg_pos_g,
585+
nn_pos_2_pos_g,
586+
nn_2_pos_pos_g,
587+
) in enumerate(
588+
zip(
589+
self.nearest_horizontal_gates,
590+
self.nearest_vertical_gates,
591+
self.nearest_diagonal_gates,
592+
self.next_nearest_neg_x_pos_y_gates,
593+
self.next_nearest_pos_x_2_pos_y_gates,
594+
self.next_nearest_2_pos_x_pos_y_gates,
595+
strict=True,
596+
)
597+
):
598+
grp_gates.create_dataset(
599+
f"nearest_horizontal_gate_{i:d}",
600+
data=h_g,
601+
compression="gzip",
602+
compression_opts=6,
603+
)
604+
grp_gates.create_dataset(
605+
f"nearest_vertical_gate_{i:d}",
606+
data=v_g,
607+
compression="gzip",
608+
compression_opts=6,
609+
)
610+
grp_gates.create_dataset(
611+
f"nearest_diagonal_gate_{i:d}",
612+
data=d_g,
613+
compression="gzip",
614+
compression_opts=6,
615+
)
616+
grp_gates.create_dataset(
617+
f"next_nearest_neg_x_pos_y_gate_{i:d}",
618+
data=nn_neg_pos_g,
619+
compression="gzip",
620+
compression_opts=6,
621+
)
622+
grp_gates.create_dataset(
623+
f"next_nearest_pos_x_2_pos_y_gate_{i:d}",
624+
data=nn_pos_2_pos_g,
625+
compression="gzip",
626+
compression_opts=6,
627+
)
628+
grp_gates.create_dataset(
629+
f"next_nearest_2_pos_x_pos_y_gate_{i:d}",
630+
data=nn_2_pos_pos_g,
631+
compression="gzip",
632+
compression_opts=6,
633+
)
634+
635+
grp.attrs["real_d"] = self.real_d
636+
grp.attrs["normalization_factor"] = self.normalization_factor
637+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
638+
639+
if self.is_spiral_peps:
640+
grp.create_dataset(
641+
"spiral_unitary_operator",
642+
data=self.spiral_unitary_operator,
643+
compression="gzip",
644+
compression_opts=6,
645+
)
646+
647+
@classmethod
648+
def load_from_group(cls, grp: h5py.Group):
649+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
650+
raise ValueError(
651+
"The HDF5 group suggests that this is not the right class to load data from it."
652+
)
653+
654+
horizontal_gates = tuple(
655+
jnp.asarray(grp["gates"][f"nearest_horizontal_gate_{i:d}"])
656+
for i in range(grp["gates"].attrs["len"])
657+
)
658+
vertical_gates = tuple(
659+
jnp.asarray(grp["gates"][f"nearest_vertical_gate_{i:d}"])
660+
for i in range(grp["gates"].attrs["len"])
661+
)
662+
diagonal_gates = tuple(
663+
jnp.asarray(grp["gates"][f"nearest_diagonal_gate_{i:d}"])
664+
for i in range(grp["gates"].attrs["len"])
665+
)
666+
next_nearest_neg_x_pos_y_gates = tuple(
667+
jnp.asarray(grp["gates"][f"next_nearest_neg_x_pos_y_gate_{i:d}"])
668+
for i in range(grp["gates"].attrs["len"])
669+
)
670+
next_nearest_pos_x_2_pos_y_gates = tuple(
671+
jnp.asarray(grp["gates"][f"next_nearest_pos_x_2_pos_y_gate_{i:d}"])
672+
for i in range(grp["gates"].attrs["len"])
673+
)
674+
next_nearest_2_pos_x_pos_y_gates = tuple(
675+
jnp.asarray(grp["gates"][f"next_nearest_2_pos_x_pos_y_gate_{i:d}"])
676+
for i in range(grp["gates"].attrs["len"])
677+
)
678+
679+
is_spiral_peps = grp.attrs["is_spiral_peps"]
680+
681+
if is_spiral_peps:
682+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
683+
else:
684+
spiral_unitary_operator = None
685+
686+
return cls(
687+
horizontal_gates=horizontal_gates,
688+
vertical_gates=vertical_gates,
689+
diagonal_gates=diagonal_gates,
690+
next_nearest_neg_x_pos_y_gates=next_nearest_neg_x_pos_y_gates,
691+
next_nearest_pos_x_2_pos_y_gates=next_nearest_pos_x_2_pos_y_gates,
692+
next_nearest_2_pos_x_pos_y_gates=next_nearest_2_pos_x_pos_y_gates,
693+
real_d=grp.attrs["real_d"],
694+
normalization_factor=grp.attrs["normalization_factor"],
695+
is_spiral_peps=is_spiral_peps,
696+
spiral_unitary_operator=spiral_unitary_operator,
697+
)

varipeps/expectation/triangular_two_sites.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from functools import partial
33

4+
import h5py
5+
46
import jax
57
import jax.numpy as jnp
68
from jax import jit
@@ -302,3 +304,79 @@ def __call__(
302304
return result[0]
303305
else:
304306
return result
307+
308+
def save_to_group(self, grp: h5py.Group):
309+
cls = type(self)
310+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
311+
312+
grp_gates = grp.create_group("gates", track_order=True)
313+
grp_gates.attrs["len"] = len(self.horizontal_gates)
314+
for i, (h_g, v_g, d_g) in enumerate(
315+
zip(
316+
self.horizontal_gates,
317+
self.vertical_gates,
318+
self.diagonal_gates,
319+
strict=True,
320+
)
321+
):
322+
grp_gates.create_dataset(
323+
f"horizontal_gate_{i:d}",
324+
data=h_g,
325+
compression="gzip",
326+
compression_opts=6,
327+
)
328+
grp_gates.create_dataset(
329+
f"vertical_gate_{i:d}", data=v_g, compression="gzip", compression_opts=6
330+
)
331+
grp_gates.create_dataset(
332+
f"diagonal_gate_{i:d}", data=d_g, compression="gzip", compression_opts=6
333+
)
334+
335+
grp.attrs["real_d"] = self.real_d
336+
grp.attrs["normalization_factor"] = self.normalization_factor
337+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
338+
339+
if self.is_spiral_peps:
340+
grp.create_dataset(
341+
"spiral_unitary_operator",
342+
data=self.spiral_unitary_operator,
343+
compression="gzip",
344+
compression_opts=6,
345+
)
346+
347+
@classmethod
348+
def load_from_group(cls, grp: h5py.Group):
349+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
350+
raise ValueError(
351+
"The HDF5 group suggests that this is not the right class to load data from it."
352+
)
353+
354+
horizontal_gates = tuple(
355+
jnp.asarray(grp["gates"][f"horizontal_gate_{i:d}"])
356+
for i in range(grp["gates"].attrs["len"])
357+
)
358+
vertical_gates = tuple(
359+
jnp.asarray(grp["gates"][f"vertical_gate_{i:d}"])
360+
for i in range(grp["gates"].attrs["len"])
361+
)
362+
diagonal_gates = tuple(
363+
jnp.asarray(grp["gates"][f"diagonal_gate_{i:d}"])
364+
for i in range(grp["gates"].attrs["len"])
365+
)
366+
367+
is_spiral_peps = grp.attrs["is_spiral_peps"]
368+
369+
if is_spiral_peps:
370+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
371+
else:
372+
spiral_unitary_operator = None
373+
374+
return cls(
375+
horizontal_gates=horizontal_gates,
376+
vertical_gates=vertical_gates,
377+
diagonal_gates=diagonal_gates,
378+
real_d=grp.attrs["real_d"],
379+
normalization_factor=grp.attrs["normalization_factor"],
380+
is_spiral_peps=is_spiral_peps,
381+
spiral_unitary_operator=spiral_unitary_operator,
382+
)

varipeps/mapping/kagome.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,3 +2090,73 @@ def __call__(
20902090
return result, single_gates_result
20912091
else:
20922092
return result
2093+
2094+
def save_to_group(self, grp: h5py.Group):
2095+
cls = type(self)
2096+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
2097+
2098+
grp_gates = grp.create_group("gates", track_order=True)
2099+
grp_gates.attrs["len"] = len(self.up_nearest_gates)
2100+
for i, (u_g, d_g) in enumerate(
2101+
zip(
2102+
self.up_nearest_gates,
2103+
self.down_nearest_gates,
2104+
strict=True,
2105+
)
2106+
):
2107+
grp_gates.create_dataset(
2108+
f"up_nearest_gate_{i:d}",
2109+
data=u_g,
2110+
compression="gzip",
2111+
compression_opts=6,
2112+
)
2113+
grp_gates.create_dataset(
2114+
f"down_nearest_gate_{i:d}",
2115+
data=d_g,
2116+
compression="gzip",
2117+
compression_opts=6,
2118+
)
2119+
2120+
grp.attrs["real_d"] = self.real_d
2121+
grp.attrs["normalization_factor"] = self.normalization_factor
2122+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
2123+
2124+
if self.is_spiral_peps:
2125+
grp.create_dataset(
2126+
"spiral_unitary_operator",
2127+
data=self.spiral_unitary_operator,
2128+
compression="gzip",
2129+
compression_opts=6,
2130+
)
2131+
2132+
@classmethod
2133+
def load_from_group(cls, grp: h5py.Group):
2134+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
2135+
raise ValueError(
2136+
"The HDF5 group suggests that this is not the right class to load data from it."
2137+
)
2138+
2139+
up_nearest_gates = tuple(
2140+
jnp.asarray(grp["gates"][f"up_nearest_gate_{i:d}"])
2141+
for i in range(grp["gates"].attrs["len"])
2142+
)
2143+
down_nearest_gates = tuple(
2144+
jnp.asarray(grp["gates"][f"down_nearest_gate_{i:d}"])
2145+
for i in range(grp["gates"].attrs["len"])
2146+
)
2147+
2148+
is_spiral_peps = grp.attrs["is_spiral_peps"]
2149+
2150+
if is_spiral_peps:
2151+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
2152+
else:
2153+
spiral_unitary_operator = None
2154+
2155+
return cls(
2156+
up_nearest_gates=up_nearest_gates,
2157+
down_nearest_gates=down_nearest_gates,
2158+
real_d=grp.attrs["real_d"],
2159+
normalization_factor=grp.attrs["normalization_factor"],
2160+
is_spiral_peps=is_spiral_peps,
2161+
spiral_unitary_operator=spiral_unitary_operator,
2162+
)

varipeps/mapping/maple_leaf.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,3 +2330,74 @@ def __call__(
23302330
return result, single_gates_result
23312331
else:
23322332
return result
2333+
2334+
def save_to_group(self, grp: h5py.Group):
2335+
cls = type(self)
2336+
grp.attrs["class"] = f"{cls.__module__}.{cls.__qualname__}"
2337+
2338+
grp_gates = grp.create_group("gates", track_order=True)
2339+
grp_gates.attrs["len"] = len(self.green_gates)
2340+
for i, (g_g, b_g, r_g) in enumerate(
2341+
zip(self.green_gates, self.blue_gates, self.red_gates, strict=True)
2342+
):
2343+
grp_gates.create_dataset(
2344+
f"green_gate_{i:d}",
2345+
data=g_g,
2346+
compression="gzip",
2347+
compression_opts=6,
2348+
)
2349+
grp_gates.create_dataset(
2350+
f"blue_gate_{i:d}", data=b_g, compression="gzip", compression_opts=6
2351+
)
2352+
grp_gates.create_dataset(
2353+
f"red_gate_{i:d}", data=r_g, compression="gzip", compression_opts=6
2354+
)
2355+
2356+
grp.attrs["real_d"] = self.real_d
2357+
grp.attrs["normalization_factor"] = self.normalization_factor
2358+
grp.attrs["is_spiral_peps"] = self.is_spiral_peps
2359+
2360+
if self.is_spiral_peps:
2361+
grp.create_dataset(
2362+
"spiral_unitary_operator",
2363+
data=self.spiral_unitary_operator,
2364+
compression="gzip",
2365+
compression_opts=6,
2366+
)
2367+
2368+
@classmethod
2369+
def load_from_group(cls, grp: h5py.Group):
2370+
if not grp.attrs["class"] == f"{cls.__module__}.{cls.__qualname__}":
2371+
raise ValueError(
2372+
"The HDF5 group suggests that this is not the right class to load data from it."
2373+
)
2374+
2375+
green_gates = tuple(
2376+
jnp.asarray(grp["gates"][f"green_gate_{i:d}"])
2377+
for i in range(grp["gates"].attrs["len"])
2378+
)
2379+
blue_gates = tuple(
2380+
jnp.asarray(grp["gates"][f"blue_gate_{i:d}"])
2381+
for i in range(grp["gates"].attrs["len"])
2382+
)
2383+
red_gates = tuple(
2384+
jnp.asarray(grp["gates"][f"red_gate_{i:d}"])
2385+
for i in range(grp["gates"].attrs["len"])
2386+
)
2387+
2388+
is_spiral_peps = grp.attrs["is_spiral_peps"]
2389+
2390+
if is_spiral_peps:
2391+
spiral_unitary_operator = jnp.asarray(grp["spiral_unitary_operator"])
2392+
else:
2393+
spiral_unitary_operator = None
2394+
2395+
return cls(
2396+
green_gates=green_gates,
2397+
blue_gates=blue_gates,
2398+
red_gates=red_gates,
2399+
real_d=grp.attrs["real_d"],
2400+
normalization_factor=grp.attrs["normalization_factor"],
2401+
is_spiral_peps=is_spiral_peps,
2402+
spiral_unitary_operator=spiral_unitary_operator,
2403+
)

0 commit comments

Comments
 (0)