Skip to content

Commit e37135a

Browse files
committed
Implement code to calculate overlap of two unit cells
1 parent 91417f2 commit e37135a

File tree

7 files changed

+337
-1
lines changed

7 files changed

+337
-1
lines changed

varipeps/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import expectation
1010
from . import mapping
1111
from . import optimization
12+
from . import overlap
1213
from . import peps
1314
from . import typing
1415
from . import utils

varipeps/contractions/definitions.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3277,5 +3277,107 @@ def _prepare_defs(cls):
32773277
],
32783278
}
32793279

3280+
overlap_one_site_only_corners: Definition = {
3281+
"tensors": [["C1", "C2", "C3", "C4"]],
3282+
"network": [
3283+
[
3284+
(1, 2), # C1
3285+
(2, 3), # C2
3286+
(4, 3), # C3
3287+
(4, 1), # C4
3288+
],
3289+
],
3290+
}
3291+
3292+
overlap_one_site_only_transfer_horizontal: Definition = {
3293+
"tensors": [["C1", "T1", "C2", "C3", "T3", "C4"]],
3294+
"network": [
3295+
[
3296+
(1, 2), # C1
3297+
(2, 4, 5, 6), # T1
3298+
(6, 7), # C2
3299+
(7, 8), # C3
3300+
(3, 8, 5, 4), # T3
3301+
(3, 1), # C4
3302+
],
3303+
],
3304+
}
3305+
3306+
overlap_one_site_only_transfer_vertical: Definition = {
3307+
"tensors": [["C1", "C2", "T2", "C3", "C4", "T4"]],
3308+
"network": [
3309+
[
3310+
(2, 1), # C1
3311+
(1, 3), # C2
3312+
(4, 5, 8, 3), # T2
3313+
(7, 8), # C3
3314+
(7, 6), # C4
3315+
(6, 5, 4, 2), # T4
3316+
],
3317+
],
3318+
}
3319+
3320+
overlap_four_sites_square_only_corners: Definition = {
3321+
"tensors": [["C1"], ["C2"], ["C3"], ["C4"]],
3322+
"network": [
3323+
[
3324+
(1, 2), # C1
3325+
],
3326+
[
3327+
(2, 3), # C2
3328+
],
3329+
[
3330+
(4, 3), # C3
3331+
],
3332+
[
3333+
(4, 1), # C4
3334+
],
3335+
],
3336+
}
3337+
3338+
overlap_four_sites_square_transfer_horizontal: Definition = {
3339+
"tensors": [["C1", "T1"], ["T1", "C2"], ["C3", "T3"], ["T3", "C4"]],
3340+
"network": [
3341+
[
3342+
(1, 2), # C1
3343+
(2, 4, 5, 11), # T1
3344+
],
3345+
[
3346+
(11, 9, 10, 7), # T1
3347+
(7, 6), # C2
3348+
],
3349+
[
3350+
(8, 6), # C3
3351+
(12, 8, 10, 9), # T3
3352+
],
3353+
[
3354+
(3, 12, 5, 4), # T3
3355+
(3, 1), # C4
3356+
],
3357+
],
3358+
}
3359+
3360+
overlap_four_sites_square_transfer_vertical: Definition = {
3361+
"tensors": [["T4", "C1"], ["C2", "T2"], ["T2", "C3"], ["C4", "T4"]],
3362+
"network": [
3363+
[
3364+
(11, 5, 4, 2), # T4
3365+
(2, 1), # C1
3366+
],
3367+
[
3368+
(1, 3), # C2
3369+
(4, 5, 12, 3), # T2
3370+
],
3371+
[
3372+
(9, 10, 7, 12), # T2
3373+
(6, 7), # C3
3374+
],
3375+
[
3376+
(6, 8), # C4
3377+
(8, 10, 9, 11), # T4
3378+
],
3379+
],
3380+
}
3381+
32803382

32813383
Definitions._prepare_defs()

varipeps/overlap/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from . import overlap
2+
from . import overlap_single_site
3+
from . import overlap_four_sites
4+
5+
from .overlap import calculate_overlap

varipeps/overlap/overlap.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import jax.numpy as jnp
2+
3+
from varipeps.ctmrg import calc_ctmrg_env
4+
from varipeps.peps import PEPS_Unit_Cell
5+
6+
from . import overlap_single_site
7+
from . import overlap_four_sites
8+
9+
overlap_mapping = {
10+
(1, 1): overlap_single_site.Overlap_Single_Site,
11+
(2, 2): overlap_four_sites.Overlap_Four_Sites_Square,
12+
}
13+
14+
15+
def calculate_overlap(unitcell_A, unitcell_B, chi, max_chi):
16+
structure_A = tuple(tuple(i) for i in unitcell_A.data.structure)
17+
structure_B = tuple(tuple(i) for i in unitcell_B.data.structure)
18+
19+
if structure_A != structure_B:
20+
raise ValueError("Structure of both unit cells have to be the same.")
21+
22+
size = unitcell_A.get_size()
23+
num_tensors = size[0] * size[1]
24+
overlap_func = overlap_mapping[size].calc_overlap
25+
26+
unitcell_A = unitcell_A.convert_to_full_transfer()
27+
unitcell_B = unitcell_B.convert_to_full_transfer()
28+
29+
norm_A = overlap_func(unitcell_A)
30+
norm_B = overlap_func(unitcell_B)
31+
32+
overlap_tensors = [
33+
type(t).from_tensor(
34+
t.tensor / norm_A ** (1 / (2 * num_tensors)), t.d, t.D, chi, max_chi=max_chi
35+
)
36+
for t in unitcell_A.get_unique_tensors()
37+
]
38+
39+
for i, e in enumerate(overlap_tensors):
40+
e.tensor_conj = unitcell_B.get_unique_tensors()[i].tensor.conj() / norm_B ** (
41+
1 / (2 * num_tensors)
42+
)
43+
44+
overlap_unitcell = PEPS_Unit_Cell.from_tensor_list(overlap_tensors, structure_A)
45+
46+
overlap_unitcell, _ = calc_ctmrg_env(
47+
[i.tensor for i in overlap_tensors], overlap_unitcell
48+
)
49+
50+
overlap_AB = overlap_func(overlap_unitcell)
51+
52+
return overlap_AB
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import numpy as np
2+
import jax.numpy as jnp
3+
4+
from jax import jit
5+
6+
from varipeps.contractions import apply_contraction_jitted
7+
8+
9+
class Overlap_Four_Sites_Square:
10+
@staticmethod
11+
@jit
12+
def calc_overlap(unitcell):
13+
top_left = apply_contraction_jitted(
14+
"ctmrg_top_left", [unitcell[0, 0][0][0].tensor], [unitcell[0, 0][0][0]], []
15+
)
16+
top_left = top_left.reshape(
17+
np.prod(top_left.shape[:3]), np.prod(top_left.shape[3:])
18+
)
19+
20+
top_right = apply_contraction_jitted(
21+
"ctmrg_top_right", [unitcell[0, 1][0][0].tensor], [unitcell[0, 1][0][0]], []
22+
)
23+
top_right = top_right.reshape(
24+
np.prod(top_right.shape[:3]), np.prod(top_right.shape[3:])
25+
)
26+
27+
bottom_left = apply_contraction_jitted(
28+
"ctmrg_bottom_left",
29+
[unitcell[1, 0][0][0].tensor],
30+
[unitcell[1, 0][0][0]],
31+
[],
32+
)
33+
bottom_left = bottom_left.reshape(
34+
np.prod(bottom_left.shape[:3]), np.prod(bottom_left.shape[3:])
35+
)
36+
37+
bottom_right = apply_contraction_jitted(
38+
"ctmrg_bottom_right",
39+
[unitcell[1, 1][0][0].tensor],
40+
[unitcell[1, 1][0][0]],
41+
[],
42+
)
43+
bottom_right = bottom_right.reshape(
44+
np.prod(bottom_right.shape[:3]), np.prod(bottom_right.shape[3:])
45+
)
46+
47+
norm_with_sites = jnp.trace(top_left @ top_right @ bottom_left @ bottom_right)
48+
49+
norm_corners = apply_contraction_jitted(
50+
"overlap_four_sites_square_only_corners",
51+
[
52+
unitcell[0, 0][0][0].tensor,
53+
unitcell[0, 1][0][0].tensor,
54+
unitcell[1, 1][0][0].tensor,
55+
unitcell[1, 0][0][0].tensor,
56+
],
57+
[
58+
unitcell[0, 0][0][0],
59+
unitcell[0, 1][0][0],
60+
unitcell[1, 1][0][0],
61+
unitcell[1, 0][0][0],
62+
],
63+
[],
64+
)
65+
66+
norm_horizontal = apply_contraction_jitted(
67+
"overlap_four_sites_square_transfer_horizontal",
68+
[
69+
unitcell[0, 0][0][0].tensor,
70+
unitcell[0, 1][0][0].tensor,
71+
unitcell[1, 1][0][0].tensor,
72+
unitcell[1, 0][0][0].tensor,
73+
],
74+
[
75+
unitcell[0, 0][0][0],
76+
unitcell[0, 1][0][0],
77+
unitcell[1, 1][0][0],
78+
unitcell[1, 0][0][0],
79+
],
80+
[],
81+
)
82+
83+
norm_vertical = apply_contraction_jitted(
84+
"overlap_four_sites_square_transfer_vertical",
85+
[
86+
unitcell[0, 0][0][0].tensor,
87+
unitcell[0, 1][0][0].tensor,
88+
unitcell[1, 1][0][0].tensor,
89+
unitcell[1, 0][0][0].tensor,
90+
],
91+
[
92+
unitcell[0, 0][0][0],
93+
unitcell[0, 1][0][0],
94+
unitcell[1, 1][0][0],
95+
unitcell[1, 0][0][0],
96+
],
97+
[],
98+
)
99+
100+
return (norm_with_sites * norm_corners) / (norm_horizontal * norm_vertical)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import jax.numpy as jnp
2+
3+
from jax import jit
4+
5+
from varipeps.contractions import apply_contraction_jitted
6+
7+
8+
class Overlap_Single_Site:
9+
@staticmethod
10+
@jit
11+
def calc_overlap(unitcell):
12+
density_matrix = apply_contraction_jitted(
13+
"density_matrix_one_site",
14+
[unitcell[0, 0][0][0].tensor],
15+
[unitcell[0, 0][0][0]],
16+
[],
17+
)
18+
19+
norm_with_site = jnp.trace(density_matrix)
20+
21+
norm_corners = apply_contraction_jitted(
22+
"overlap_one_site_only_corners",
23+
[unitcell[0, 0][0][0].tensor],
24+
[unitcell[0, 0][0][0]],
25+
[],
26+
)
27+
28+
norm_horizontal = apply_contraction_jitted(
29+
"overlap_one_site_only_transfer_horizontal",
30+
[unitcell[0, 0][0][0].tensor],
31+
[unitcell[0, 0][0][0]],
32+
[],
33+
)
34+
35+
norm_vertical = apply_contraction_jitted(
36+
"overlap_one_site_only_transfer_vertical",
37+
[unitcell[0, 0][0][0].tensor],
38+
[unitcell[0, 0][0][0]],
39+
[],
40+
)
41+
42+
return (norm_with_site * norm_corners) / (norm_horizontal * norm_vertical)

0 commit comments

Comments
 (0)