@@ -1249,3 +1249,74 @@ def __call__(
12491249 return result , single_gates_result
12501250 else :
12511251 return result
1252+
1253+ def save_to_group (self , grp : h5py .Group ):
1254+ cls = type (self )
1255+ grp .attrs ["class" ] = f"{ cls .__module__ } .{ cls .__qualname__ } "
1256+
1257+ grp_gates = grp .create_group ("gates" , track_order = True )
1258+ grp_gates .attrs ["len" ] = len (self .green_gates )
1259+ for i , (g_g , b_g , bl_g ) in enumerate (
1260+ zip (self .green_gates , self .blue_gates , self .black_gates , strict = True )
1261+ ):
1262+ grp_gates .create_dataset (
1263+ f"green_gate_{ i :d} " ,
1264+ data = g_g ,
1265+ compression = "gzip" ,
1266+ compression_opts = 6 ,
1267+ )
1268+ grp_gates .create_dataset (
1269+ f"blue_gate_{ i :d} " , data = b_g , compression = "gzip" , compression_opts = 6
1270+ )
1271+ grp_gates .create_dataset (
1272+ f"black_gate_{ i :d} " , data = bl_g , compression = "gzip" , compression_opts = 6
1273+ )
1274+
1275+ grp .attrs ["real_d" ] = self .real_d
1276+ grp .attrs ["normalization_factor" ] = self .normalization_factor
1277+ grp .attrs ["is_spiral_peps" ] = self .is_spiral_peps
1278+
1279+ if self .is_spiral_peps :
1280+ grp .create_dataset (
1281+ "spiral_unitary_operator" ,
1282+ data = self .spiral_unitary_operator ,
1283+ compression = "gzip" ,
1284+ compression_opts = 6 ,
1285+ )
1286+
1287+ @classmethod
1288+ def load_from_group (cls , grp : h5py .Group ):
1289+ if not grp .attrs ["class" ] == f"{ cls .__module__ } .{ cls .__qualname__ } " :
1290+ raise ValueError (
1291+ "The HDF5 group suggests that this is not the right class to load data from it."
1292+ )
1293+
1294+ green_gates = tuple (
1295+ jnp .asarray (grp ["gates" ][f"green_gate_{ i :d} " ])
1296+ for i in range (grp ["gates" ].attrs ["len" ])
1297+ )
1298+ blue_gates = tuple (
1299+ jnp .asarray (grp ["gates" ][f"blue_gate_{ i :d} " ])
1300+ for i in range (grp ["gates" ].attrs ["len" ])
1301+ )
1302+ black_gates = tuple (
1303+ jnp .asarray (grp ["gates" ][f"black_gate_{ i :d} " ])
1304+ for i in range (grp ["gates" ].attrs ["len" ])
1305+ )
1306+
1307+ is_spiral_peps = grp .attrs ["is_spiral_peps" ]
1308+
1309+ if is_spiral_peps :
1310+ spiral_unitary_operator = jnp .asarray (grp ["spiral_unitary_operator" ])
1311+ else :
1312+ spiral_unitary_operator = None
1313+
1314+ return cls (
1315+ green_gates = green_gates ,
1316+ blue_gates = blue_gates ,
1317+ black_gates = black_gates ,
1318+ real_d = grp .attrs ["real_d" ],
1319+ normalization_factor = grp .attrs ["normalization_factor" ],
1320+ is_spiral_peps = is_spiral_peps ,
1321+ spiral_unitary_operator = spiral_unitary_operator ,
1322+ )
0 commit comments