@@ -256,6 +256,7 @@ def get_onsite_gates(g_e, b_e, r_e, d):
256256 red_45 ,
257257 )
258258
259+
259260def get_onsite_gates_hexagon (b_e , d ):
260261 Id_other_sites = jnp .eye (d ** 4 )
261262
@@ -343,6 +344,7 @@ def _calc_onsite_gate(
343344
344345 return result , single_gates
345346
347+
346348@partial (jit , static_argnums = (1 , 2 ))
347349def _calc_onsite_gate_hexagon (
348350 blue_gates : Sequence [jnp .ndarray ],
@@ -353,9 +355,7 @@ def _calc_onsite_gate_hexagon(
353355
354356 single_gates = [None ] * result_length
355357
356- for i , (b_e , ) in enumerate (
357- zip (blue_gates , strict = True )
358- ):
358+ for i , (b_e ,) in enumerate (zip (blue_gates , strict = True )):
359359 (
360360 blue_12 ,
361361 blue_23 ,
@@ -365,14 +365,7 @@ def _calc_onsite_gate_hexagon(
365365 blue_61 ,
366366 ) = get_onsite_gates_hexagon (b_e , d )
367367
368- result [i ] = (
369- blue_12 +
370- blue_23 +
371- blue_34 +
372- blue_45 +
373- blue_56 +
374- blue_61
375- )
368+ result [i ] = blue_12 + blue_23 + blue_34 + blue_45 + blue_56 + blue_61
376369
377370 single_gates [i ] = (
378371 blue_12 ,
@@ -398,6 +391,7 @@ def get_right_gates(b_e, r_e, d):
398391
399392 return red_61 , blue_62
400393
394+
401395def get_right_gates_hexagon (r_e , g_e , d ):
402396 Id_other_site = jnp .eye (d ** 2 )
403397
@@ -439,6 +433,7 @@ def _calc_right_gate(
439433
440434 return result , single_gates
441435
436+
442437@partial (jit , static_argnums = (2 , 3 ))
443438def _calc_right_gate_hexagon (
444439 red_gates : Sequence [jnp .ndarray ],
@@ -468,6 +463,7 @@ def get_down_gates(b_e, r_e, d):
468463
469464 return blue_35 , red_36
470465
466+
471467def get_down_gates_hexagon (r_e , g_e , d ):
472468 Id_other_site = jnp .eye (d ** 2 )
473469
@@ -509,6 +505,7 @@ def _calc_down_gate(
509505
510506 return result , single_gates
511507
508+
512509@partial (jit , static_argnums = (2 , 3 ))
513510def _calc_down_gate_hexagon (
514511 red_gates : Sequence [jnp .ndarray ],
@@ -538,6 +535,7 @@ def get_diagonal_gates(b_e, r_e, d):
538535
539536 return blue_41 , red_31
540537
538+
541539def get_diagonal_gates_hexagon (r_e , g_e , d ):
542540 Id_other_site = jnp .eye (d ** 2 )
543541
@@ -579,6 +577,7 @@ def _calc_diagonal_gate(
579577
580578 return result , single_gates
581579
580+
582581@partial (jit , static_argnums = (2 , 3 ))
583582def _calc_diagonal_gate_hexagon (
584583 red_gates : Sequence [jnp .ndarray ],
@@ -1881,7 +1880,11 @@ def __call__(
18811880 density_matrix_top_left ,
18821881 density_matrix_bottom_right ,
18831882 ) = partially_traced_diagonal_two_site_density_matrices_triangular (
1884- diagonal_tensors , diagonal_tensor_objs , 2 , 6 , ((3 , 4 ), (1 ,)),
1883+ diagonal_tensors ,
1884+ diagonal_tensor_objs ,
1885+ 2 ,
1886+ 6 ,
1887+ ((3 , 4 ), (1 ,)),
18851888 )
18861889
18871890 if return_single_gate_results :
@@ -2064,6 +2067,7 @@ def load_from_group(cls, grp: h5py.Group):
20642067 spiral_unitary_operator = spiral_unitary_operator ,
20652068 )
20662069
2070+
20672071@dataclass
20682072class Maple_Leaf_Hexagon_Triangular_CTMRG_Expectation_Value (Expectation_Model ):
20692073 """
@@ -2316,7 +2320,6 @@ def __call__(
23162320 self ._full_onsite_tuple ,
23172321 )
23182322
2319-
23202323 vertical_tensors_i = view .get_indices ((slice (0 , 2 , None ), 0 ))
23212324 vertical_tensors = [
23222325 peps_tensors [i ] for j in vertical_tensors_i for i in j
@@ -2344,7 +2347,6 @@ def __call__(
23442347 self ._result_type is jnp .float64 ,
23452348 )
23462349
2347-
23482350 horizontal_tensors_i = view .get_indices ((0 , slice (0 , 2 , None )))
23492351 horizontal_tensors = [
23502352 peps_tensors [i ] for j in horizontal_tensors_i for i in j
@@ -2354,7 +2356,11 @@ def __call__(
23542356 density_matrix_left ,
23552357 density_matrix_right ,
23562358 ) = partially_traced_horizontal_two_site_density_matrices_triangular (
2357- horizontal_tensors , horizontal_tensor_objs , 2 , 6 , ((2 , 3 ), (5 , 6 ))
2359+ horizontal_tensors ,
2360+ horizontal_tensor_objs ,
2361+ 2 ,
2362+ 6 ,
2363+ ((2 , 3 ), (5 , 6 )),
23582364 )
23592365
23602366 if return_single_gate_results :
@@ -2372,7 +2378,6 @@ def __call__(
23722378 self ._result_type is jnp .float64 ,
23732379 )
23742380
2375-
23762381 diagonal_tensors_i = view .get_indices (
23772382 (slice (0 , 2 , None ), slice (0 , 2 , None ))
23782383 )
@@ -2566,4 +2571,4 @@ def load_from_group(cls, grp: h5py.Group):
25662571 normalization_factor = grp .attrs ["normalization_factor" ],
25672572 is_spiral_peps = is_spiral_peps ,
25682573 spiral_unitary_operator = spiral_unitary_operator ,
2569- )
2574+ )
0 commit comments