
Commit acf453a

Use opt_einsum to optimize contraction orders

1 parent 0a23705 commit acf453a

File tree: 2 files changed, +46 −49 lines
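For context on the new approach: jnp.einsum_path computes a contraction order from the operand shapes alone, and the resulting path object can be passed back to jnp.einsum so the optimization is not redone on every call. A minimal sketch of that round trip, independent of the varipeps code (array names and shapes here are made up):

import jax.numpy as jnp

a = jnp.ones((8, 16))
b = jnp.ones((16, 4))

# einsum_path returns the chosen contraction order plus a readable report.
path, report = jnp.einsum_path("ij,jk->ik", a, b, optimize="optimal")

# Passing the precomputed path back to einsum skips re-optimization.
result = jnp.einsum("ij,jk->ik", a, b, optimize=path)
print(result.shape)  # (8, 4)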

varipeps/contractions/apply.py

Lines changed: 16 additions & 32 deletions

@@ -35,16 +35,18 @@ def __class_getitem__(cls, name: str) -> Checkpointing_Cache:
 _ncon_jitted = jax.jit(_jittable_ncon, static_argnums=(1, 2, 3, 4, 5), inline=True)
 
 
+@partial(
+    jax.jit, static_argnames=("name", "disable_identity_check", "custom_definition")
+)
 def apply_contraction(
     name: str,
     peps_tensors: Sequence[jnp.ndarray],
     peps_tensor_objs: Sequence[PEPS_Tensor],
     additional_tensors: Sequence[jnp.ndarray],
     *,
-    disable_identity_check: bool = False,
+    disable_identity_check: bool = True,
     custom_definition: Optional[Definition] = None,
     config: VariPEPS_Config = varipeps_config,
-    _jitable: bool = False,
 ) -> jnp.ndarray:
     """
     Apply a contraction to a list of tensors.
@@ -83,8 +85,7 @@ class definition for details.
         )
 
     if (
-        not _jitable
-        and not disable_identity_check
+        not disable_identity_check
         and not all(isinstance(t, jax.core.Tracer) for t in peps_tensors)
         and not all(isinstance(to.tensor, jax.core.Tracer) for to in peps_tensor_objs)
         and not all(
@@ -133,36 +134,19 @@ class definition for details.
 
     tensors += additional_tensors
 
-    if _jitable:
-        if config.checkpointing_ncon:
-            f = jax.checkpoint(_ncon_jitted, static_argnums=(1, 2, 3, 4, 5))
-        else:
-            f = _ncon_jitted
-
-        return f(
-            tensors,
-            contraction["ncon_flat_network"],
-            contraction["ncon_sizes"],
-            contraction["ncon_con_order"],
-            contraction["ncon_out_order"],
-            tn.backends.backend_factory.get_backend("jax"),
-        )
+    tensor_shapes = tuple(tuple(e.shape) for e in tensors)
 
-    if config.checkpointing_ncon:
-        f = jax.checkpoint(
-            partial(
-                tn.ncon, network_structure=contraction["ncon_network"], backend="jax"
-            )
-        )
-    else:
-        f = partial(
-            tn.ncon, network_structure=contraction["ncon_network"], backend="jax"
+    path = contraction["einsum_cache"].get(tensor_shapes)
+
+    if path is None:
+        path, _ = jnp.einsum_path(
+            contraction["einsum_network"],
+            *tensors,
+            optimize="optimal" if len(tensors) < 10 else "dp",
         )
+        contraction["einsum_cache"][tensor_shapes] = path
 
-    return f(tensors)
+    return jnp.einsum(contraction["einsum_network"], *tensors, optimize=path)
 
 
-apply_contraction_jitted = jax.jit(
-    partial(apply_contraction, _jitable=True),
-    static_argnames=("name", "disable_identity_check", "custom_definition"),
-)
+apply_contraction_jitted = apply_contraction
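The rewritten apply_contraction keys the precomputed path on the tuple of operand shapes, since the optimal contraction order depends only on shapes, not values. A standalone sketch of that caching pattern (the cache and function names here are mine, not from the repo):

import jax.numpy as jnp

_path_cache: dict = {}

def cached_einsum(network: str, *tensors):
    # The contraction path depends only on the operand shapes, so the
    # shape tuple is a sufficient cache key.
    shapes = tuple(tuple(t.shape) for t in tensors)
    path = _path_cache.get(shapes)
    if path is None:
        path, _ = jnp.einsum_path(network, *tensors, optimize="optimal")
        _path_cache[shapes] = path
    return jnp.einsum(network, *tensors, optimize=path)

a = jnp.ones((5, 3))
b = jnp.ones((3, 7))
out = cached_einsum("ij,jk->ik", a, b)  # later calls with these shapes reuse the path

Note that apply_contraction is now jitted directly, so the dict lookup and the jnp.einsum_path call run at trace time; JAX retraces once per new shape signature, which is exactly when a new path needs to be computed.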

varipeps/contractions/definitions.py

Lines changed: 30 additions & 17 deletions

@@ -4,6 +4,7 @@
 
 from collections import Counter
 
+import opt_einsum
 import tensornetwork as tn
 
 from typing import Dict, Optional, Union, List, Tuple, Sequence
@@ -116,6 +117,31 @@ def _create_filter_and_network(
             network_additional_tensors,
         )
 
+    @classmethod
+    def _convert_to_einsum(cls, network):
+        max_contracted = 0
+        min_open = 0
+
+        result = tuple([] for _ in range(len(network)))
+
+        for ti, t in enumerate(network):
+            for i in t:
+                if i < 0 and i < min_open:
+                    min_open = i
+                elif i > 0 and i > max_contracted:
+                    max_contracted = i
+                if max_contracted >= (min_open + 52):
+                    raise ValueError("Letters in conversion are overlapping.")
+
+                result[ti].append(opt_einsum.get_symbol(i))
+
+        result = ["".join(e) for e in result]
+        open_result = [opt_einsum.get_symbol(i) for i in range(-1, min_open - 1, -1)]
+
+        result = f"{','.join(result)}->{''.join(open_result)}"
+
+        return result
+
     @classmethod
     def _process_def(cls, e, name):
         (
@@ -129,6 +155,8 @@ def _process_def(cls, e, name):
             j for i in network_peps_tensors for j in i
         ] + network_additional_tensors
 
+        einsum_network = cls._convert_to_einsum(ncon_network)
+
         flatted_ncon_list = [j for i in ncon_network for j in i]
         counter_ncon_list = Counter(flatted_ncon_list)
         for ind, c in counter_ncon_list.items():
@@ -142,28 +170,13 @@ def _process_def(cls, e, name):
                 f'Non-monotonous indices in definition "{name}". Please check!'
             )
 
-        (
-            mapped_ncon_network,
-            mapping,
-        ) = tn.ncon_interface._canonicalize_network_structure(ncon_network)
-        flat_network = tuple(l for sublist in mapped_ncon_network for l in sublist)
-        unique_flat_network = list(set(flat_network))
-
-        out_order = tuple(
-            sorted([l for l in unique_flat_network if l < 0], reverse=True)
-        )
-        con_order = tuple(sorted([l for l in unique_flat_network if l > 0]))
-        sizes = tuple(len(l) for l in ncon_network)
-
         e["filter_peps_tensors"] = filter_peps_tensors
         e["filter_additional_tensors"] = filter_additional_tensors
         e["network_peps_tensors"] = network_peps_tensors
         e["network_additional_tensors"] = network_additional_tensors
         e["ncon_network"] = ncon_network
-        e["ncon_flat_network"] = flat_network
-        e["ncon_sizes"] = sizes
-        e["ncon_con_order"] = con_order
-        e["ncon_out_order"] = out_order
+        e["einsum_network"] = einsum_network
+        e["einsum_cache"] = dict()
 
     @classmethod
     def _prepare_defs(cls):

0 commit comments