Skip to content

Commit 0a23705

Browse files
committed
Revert "Disable the gauge-fixing of the final CTM tensors for now due to terrible GPU performance"
This reverts commit faf6313.
1 parent a4097d9 commit 0a23705

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

varipeps/ctmrg/absorption.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,36 @@ def _get_ctmrg_1x2_structure(
9696

9797
def _post_process_CTM_tensors(a: jnp.ndarray, config: VariPEPS_Config) -> jnp.ndarray:
9898
a = a / jnp.linalg.norm(a)
99-
return a
99+
a_abs = jnp.abs(a)
100+
a_abs_max = jnp.max(a_abs)
101+
102+
def scan_max_element(carry, x):
103+
x_a, x_a_abs = x
104+
found, phase = carry
105+
106+
def new_phase(ph, curr_x, curr_x_abs):
107+
return cond(
108+
curr_x_abs >= (config.svd_sign_fix_eps * a_abs_max),
109+
lambda p, c_x, c_x_a: c_x / c_x_a,
110+
lambda p, c_x, c_x_a: p,
111+
ph,
112+
curr_x,
113+
curr_x_abs,
114+
)
115+
116+
phase = cond(
117+
found, lambda ph, curr_x, curr_x_abs: ph, new_phase, phase, x_a, x_a_abs
118+
)
119+
120+
return (jnp.logical_not(jnp.isnan(phase)), phase), None
121+
122+
(_, phase), _ = scan(
123+
scan_max_element,
124+
(jnp.array(False), jnp.array(jnp.nan, dtype=a.dtype)),
125+
(a.flatten(), a_abs.flatten()),
126+
)
127+
128+
return a * phase.conj()
100129

101130

102131
def do_left_absorption(

0 commit comments

Comments
 (0)