@@ -598,37 +598,37 @@ def _calc_onsite_gate(
598598 )
599599 blue_36 = blue_36 .reshape (d ** 9 , d ** 9 )
600600
601- green_12 = jnp .kron (g_e , Id_other_sites )
602- green_12 = green_12 .reshape (
601+ green_base = jnp .kron (g_e , Id_other_sites )
602+ green_base = green_base .reshape (
603603 d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d
604604 )
605605
606- green_24 = green_12 .transpose (
606+ green_24 = green_base .transpose (
607607 2 , 0 , 3 , 1 , 4 , 5 , 6 , 7 , 8 , 11 , 9 , 12 , 10 , 13 , 14 , 15 , 16 , 17
608608 )
609609 green_24 = green_24 .reshape (d ** 9 , d ** 9 )
610610
611- green_45 = green_12 .transpose (
611+ green_45 = green_base .transpose (
612612 2 , 3 , 4 , 0 , 1 , 5 , 6 , 7 , 8 , 11 , 12 , 13 , 9 , 10 , 14 , 15 , 16 , 17
613613 )
614614 green_45 = green_45 .reshape (d ** 9 , d ** 9 )
615615
616- green_46 = green_12 .transpose (
616+ green_46 = green_base .transpose (
617617 2 , 3 , 4 , 0 , 5 , 1 , 6 , 7 , 8 , 11 , 12 , 13 , 9 , 14 , 10 , 15 , 16 , 17
618618 )
619619 green_46 = green_46 .reshape (d ** 9 , d ** 9 )
620620
621- green_37 = green_12 .transpose (
621+ green_37 = green_base .transpose (
622622 2 , 3 , 0 , 4 , 5 , 6 , 1 , 7 , 8 , 11 , 12 , 9 , 13 , 14 , 15 , 10 , 16 , 17
623623 )
624624 green_37 = green_37 .reshape (d ** 9 , d ** 9 )
625625
626- green_78 = green_12 .transpose (
626+ green_78 = green_base .transpose (
627627 2 , 3 , 4 , 5 , 6 , 7 , 0 , 1 , 8 , 11 , 12 , 13 , 14 , 15 , 16 , 9 , 10 , 17
628628 )
629629 green_78 = green_78 .reshape (d ** 9 , d ** 9 )
630630
631- green_79 = green_12 .transpose (
631+ green_79 = green_base .transpose (
632632 2 , 3 , 4 , 5 , 6 , 7 , 0 , 8 , 1 , 11 , 12 , 13 , 14 , 15 , 16 , 9 , 17 , 10
633633 )
634634 green_79 = green_79 .reshape (d ** 9 , d ** 9 )
@@ -869,7 +869,9 @@ def __post_init__(self) -> None:
869869 )
870870
871871 if self .is_spiral_peps :
872- raise NotImplementedError
872+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
873+ self .spiral_unitary_operator
874+ )
873875
874876 def __call__ (
875877 self ,
@@ -891,16 +893,146 @@ def __call__(
891893 working_onsite_gates = tuple (
892894 o for e in self ._onsite_single_gates for o in e
893895 )
894- working_h_single_gates = tuple (
895- h for e in self ._right_single_gates for h in e
896+
897+ if self .is_spiral_peps :
898+ if isinstance (spiral_vectors , jnp .ndarray ):
899+ spiral_vectors = (
900+ spiral_vectors ,
901+ spiral_vectors ,
902+ spiral_vectors ,
903+ )
904+ if len (spiral_vectors ) == 1 :
905+ spiral_vectors = (
906+ spiral_vectors [0 ],
907+ spiral_vectors [0 ],
908+ None ,
909+ None ,
910+ None ,
911+ None ,
912+ None ,
913+ None ,
914+ spiral_vectors [0 ],
915+ )
916+ if len (spiral_vectors ) == 4 :
917+ spiral_vectors = (
918+ spiral_vectors [0 ],
919+ spiral_vectors [1 ],
920+ None ,
921+ None ,
922+ None ,
923+ None ,
924+ None ,
925+ None ,
926+ spiral_vectors [2 ],
927+ )
928+ if len (spiral_vectors ) != 9 :
929+ raise ValueError ("Length mismatch for spiral vectors!" )
930+
931+ working_h_gates = tuple (
932+ apply_unitary (
933+ h ,
934+ jnp .array ((0 , 1 )),
935+ spiral_vectors [0 :9 :8 ],
936+ self ._spiral_D ,
937+ self ._spiral_sigma ,
938+ self .real_d ,
939+ 3 ,
940+ (1 , 2 ),
941+ varipeps_config .spiral_wavevector_type ,
942+ )
943+ for h in self ._right_tuple
896944 )
897- working_v_single_gates = tuple (
898- v for e in self ._down_single_gates for v in e
945+ working_v_gates = tuple (
946+ apply_unitary (
947+ v ,
948+ jnp .array ((1 , 0 )),
949+ spiral_vectors [:2 ],
950+ self ._spiral_D ,
951+ self ._spiral_sigma ,
952+ self .real_d ,
953+ 4 ,
954+ (2 , 3 ),
955+ varipeps_config .spiral_wavevector_type ,
956+ )
957+ for v in self ._down_tuple
899958 )
900- working_d_single_gates = tuple (
901- d for e in self ._diagonal_single_gates for d in e
959+ working_d_gates = tuple (
960+ apply_unitary (
961+ d ,
962+ jnp .array ((1 , 1 )),
963+ spiral_vectors [:1 ],
964+ self ._spiral_D ,
965+ self ._spiral_sigma ,
966+ self .real_d ,
967+ 3 ,
968+ (2 ,),
969+ varipeps_config .spiral_wavevector_type ,
970+ )
971+ for d in self ._diagonal_tuple
902972 )
903973
974+ if return_single_gate_results :
975+ working_h_single_gates = tuple (
976+ apply_unitary (
977+ h ,
978+ jnp .array ((0 , 1 )),
979+ spiral_vectors [0 :9 :8 ],
980+ self ._spiral_D ,
981+ self ._spiral_sigma ,
982+ self .real_d ,
983+ 3 ,
984+ (1 , 2 ),
985+ varipeps_config .spiral_wavevector_type ,
986+ )
987+ for e in self ._right_single_gates
988+ for h in e
989+ )
990+ working_v_single_gates = tuple (
991+ apply_unitary (
992+ v ,
993+ jnp .array ((1 , 0 )),
994+ spiral_vectors [:2 ],
995+ self ._spiral_D ,
996+ self ._spiral_sigma ,
997+ self .real_d ,
998+ 4 ,
999+ (2 , 3 ),
1000+ varipeps_config .spiral_wavevector_type ,
1001+ )
1002+ for e in self ._down_single_gates
1003+ for v in e
1004+ )
1005+ working_d_single_gates = tuple (
1006+ apply_unitary (
1007+ d ,
1008+ jnp .array ((1 , 1 )),
1009+ spiral_vectors [:1 ],
1010+ self ._spiral_D ,
1011+ self ._spiral_sigma ,
1012+ self .real_d ,
1013+ 3 ,
1014+ (2 ,),
1015+ varipeps_config .spiral_wavevector_type ,
1016+ )
1017+ for e in self ._diagonal_single_gates
1018+ for d in e
1019+ )
1020+ else :
1021+ working_h_gates = self ._right_tuple
1022+ working_v_gates = self ._down_tuple
1023+ working_d_gates = self ._diagonal_tuple
1024+
1025+ if return_single_gate_results :
1026+ working_h_single_gates = tuple (
1027+ h for e in self ._right_single_gates for h in e
1028+ )
1029+ working_v_single_gates = tuple (
1030+ v for e in self ._down_single_gates for v in e
1031+ )
1032+ working_d_single_gates = tuple (
1033+ d for e in self ._diagonal_single_gates for d in e
1034+ )
1035+
9041036 for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
9051037 for y , view in iter_rows :
9061038 # On site term
@@ -937,14 +1069,14 @@ def __call__(
9371069 step_result_horizontal = _two_site_workhorse (
9381070 density_matrix_left ,
9391071 density_matrix_right ,
940- self . _right_tuple + working_h_single_gates ,
1072+ working_h_gates + working_h_single_gates ,
9411073 self ._result_type is jnp .float64 ,
9421074 )
9431075 else :
9441076 step_result_horizontal = _two_site_workhorse (
9451077 density_matrix_left ,
9461078 density_matrix_right ,
947- self . _right_tuple ,
1079+ working_h_gates ,
9481080 self ._result_type is jnp .float64 ,
9491081 )
9501082
@@ -964,14 +1096,14 @@ def __call__(
9641096 step_result_vertical = _two_site_workhorse (
9651097 density_matrix_top ,
9661098 density_matrix_bottom ,
967- self . _down_tuple + working_v_single_gates ,
1099+ working_v_gates + working_v_single_gates ,
9681100 self ._result_type is jnp .float64 ,
9691101 )
9701102 else :
9711103 step_result_vertical = _two_site_workhorse (
9721104 density_matrix_top ,
9731105 density_matrix_bottom ,
974- self . _down_tuple ,
1106+ working_v_gates ,
9751107 self ._result_type is jnp .float64 ,
9761108 )
9771109
@@ -1011,7 +1143,7 @@ def __call__(
10111143 density_matrix_bottom_right ,
10121144 traced_density_matrix_top_right ,
10131145 traced_density_matrix_bottom_left ,
1014- self . _diagonal_tuple + working_d_single_gates ,
1146+ working_d_gates + working_d_single_gates ,
10151147 self ._result_type is jnp .float64 ,
10161148 )
10171149 else :
@@ -1020,7 +1152,7 @@ def __call__(
10201152 density_matrix_bottom_right ,
10211153 traced_density_matrix_top_right ,
10221154 traced_density_matrix_bottom_left ,
1023- self . _diagonal_tuple ,
1155+ working_d_gates ,
10241156 self ._result_type is jnp .float64 ,
10251157 )
10261158
0 commit comments