@@ -35,16 +35,18 @@ def __class_getitem__(cls, name: str) -> Checkpointing_Cache:
 _ncon_jitted = jax.jit(_jittable_ncon, static_argnums=(1, 2, 3, 4, 5), inline=True)
 
 
+@partial(
+    jax.jit, static_argnames=("name", "disable_identity_check", "custom_definition")
+)
 def apply_contraction(
     name: str,
     peps_tensors: Sequence[jnp.ndarray],
     peps_tensor_objs: Sequence[PEPS_Tensor],
     additional_tensors: Sequence[jnp.ndarray],
     *,
-    disable_identity_check: bool = False,
+    disable_identity_check: bool = True,
     custom_definition: Optional[Definition] = None,
     config: VariPEPS_Config = varipeps_config,
-    _jitable: bool = False,
 ) -> jnp.ndarray:
     """
     Apply a contraction to a list of tensors.
@@ -83,8 +85,7 @@ class definition for details.
     )
 
     if (
-        not _jitable
-        and not disable_identity_check
+        not disable_identity_check
         and not all(isinstance(t, jax.core.Tracer) for t in peps_tensors)
         and not all(isinstance(to.tensor, jax.core.Tracer) for to in peps_tensor_objs)
         and not all(
@@ -133,36 +134,19 @@ class definition for details.
 
     tensors += additional_tensors
 
-    if _jitable:
-        if config.checkpointing_ncon:
-            f = jax.checkpoint(_ncon_jitted, static_argnums=(1, 2, 3, 4, 5))
-        else:
-            f = _ncon_jitted
-
-        return f(
-            tensors,
-            contraction["ncon_flat_network"],
-            contraction["ncon_sizes"],
-            contraction["ncon_con_order"],
-            contraction["ncon_out_order"],
-            tn.backends.backend_factory.get_backend("jax"),
-        )
+    tensor_shapes = tuple(tuple(e.shape) for e in tensors)
 
-    if config.checkpointing_ncon:
-        f = jax.checkpoint(
-            partial(
-                tn.ncon, network_structure=contraction["ncon_network"], backend="jax"
-            )
-        )
-    else:
-        f = partial(
-            tn.ncon, network_structure=contraction["ncon_network"], backend="jax"
+    path = contraction["einsum_cache"].get(tensor_shapes)
+
+    if path is None:
+        path, _ = jnp.einsum_path(
+            contraction["einsum_network"],
+            *tensors,
+            optimize="optimal" if len(tensors) < 10 else "dp",
         )
+        contraction["einsum_cache"][tensor_shapes] = path
 
-    return f(tensors)
+    return jnp.einsum(contraction["einsum_network"], *tensors, optimize=path)
 
 
-apply_contraction_jitted = jax.jit(
-    partial(apply_contraction, _jitable=True),
-    static_argnames=("name", "disable_identity_check", "custom_definition"),
-)
+apply_contraction_jitted = apply_contraction
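
Note on the removed `_jitable` flag: since `apply_contraction` is now itself decorated with `jax.jit`, the function detects tracing directly. The identity check needs concrete tensor values, so it is skipped whenever the inputs are `jax.core.Tracer` objects (and `disable_identity_check` now defaults to `True`). A minimal sketch of that guard; the helper name `is_tracing` is hypothetical, not part of the repository:

import jax
import jax.numpy as jnp


def is_tracing(x) -> bool:
    # Under jax.jit, arrays are abstract Tracer objects without concrete
    # values, so value-dependent sanity checks must be skipped.
    return isinstance(x, jax.core.Tracer)


print(is_tracing(jnp.ones(3)))           # False: concrete array
print(jax.jit(is_tracing)(jnp.ones(3)))  # True: x is a Tracer while tracing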
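
The core change swaps the tensornetwork `ncon` call for `jnp.einsum` with a contraction path cached per combination of operand shapes, so the (potentially expensive) path search runs once per shape signature instead of on every call. Below is a self-contained sketch of the same pattern under the assumption of a module-level dict cache; `_path_cache` and `cached_einsum` are illustrative names, not the repository's API:

from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical module-level cache keyed by the operands' concrete shapes.
_path_cache: dict = {}


@partial(jax.jit, static_argnames=("network",))
def cached_einsum(network: str, *tensors: jnp.ndarray) -> jnp.ndarray:
    """Contract `tensors` according to `network`, reusing a cached path."""
    shapes = tuple(tuple(t.shape) for t in tensors)
    path = _path_cache.get(shapes)
    if path is None:
        # Exhaustive search is only feasible for small networks; fall back
        # to opt_einsum's dynamic-programming optimizer for larger ones.
        path, _ = jnp.einsum_path(
            network, *tensors, optimize="optimal" if len(tensors) < 10 else "dp"
        )
        _path_cache[shapes] = path
    return jnp.einsum(network, *tensors, optimize=path)


# Usage: a small matrix-chain contraction.
a = jnp.ones((4, 8))
b = jnp.ones((8, 16))
c = jnp.ones((16, 2))
result = cached_einsum("ij,jk,kl->il", a, b, c)  # shape (4, 2)

Because shapes are static during `jax.jit` tracing, both the dict lookup and `jnp.einsum_path` execute at trace time; the compiled computation contains only the `jnp.einsum` with a fixed path. This is also why `apply_contraction_jitted` can simply alias `apply_contraction`: the function is already jitted by its decorator.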