File tree Expand file tree Collapse file tree 7 files changed +56
-0
lines changed Expand file tree Collapse file tree 7 files changed +56
-0
lines changed Original file line number Diff line number Diff line change 66import jax .numpy as jnp
77from jax import jit
88
9+ from peps_ad import peps_ad_config
910from peps_ad .peps import PEPS_Tensor , PEPS_Unit_Cell
1011from peps_ad .contractions import apply_contraction
1112from .model import Expectation_Model
@@ -419,6 +420,12 @@ class Two_Sites_Expectation_Value(Expectation_Model):
419420 spiral_unitary_operator : Optional [jnp .ndarray ] = None
420421
421422 def __post_init__ (self ) -> None :
423+ if isinstance (self .horizontal_gates , jnp .ndarray ):
424+ self .horizontal_gates = (self .horizontal_gates ,)
425+
426+ if isinstance (self .vertical_gates , jnp .ndarray ):
427+ self .vertical_gates = (self .vertical_gates ,)
428+
422429 if (
423430 len (self .horizontal_gates ) > 0
424431 and len (self .horizontal_gates ) > 0
Original file line number Diff line number Diff line change @@ -775,6 +775,15 @@ class Florett_Pentagon_Expectation_Value(Expectation_Model):
775775 spiral_unitary_operator : Optional [jnp .ndarray ] = None
776776
777777 def __post_init__ (self ) -> None :
778+ if isinstance (self .black_gates , jnp .ndarray ):
779+ self .black_gates = (self .black_gates ,)
780+
781+ if isinstance (self .green_gates , jnp .ndarray ):
782+ self .green_gates = (self .green_gates ,)
783+
784+ if isinstance (self .blue_gates , jnp .ndarray ):
785+ self .blue_gates = (self .blue_gates ,)
786+
778787 if (len (self .green_gates ) != len (self .blue_gates )) or (
779788 len (self .green_gates ) != len (self .black_gates )
780789 ):
Original file line number Diff line number Diff line change @@ -55,6 +55,15 @@ class Honeycomb_Expectation_Value(Expectation_Model):
5555 spiral_unitary_operator : Optional [jnp .ndarray ] = None
5656
5757 def __post_init__ (self ) -> None :
58+ if isinstance (self .x_gates , jnp .ndarray ):
59+ self .x_gates = (self .x_gates ,)
60+
61+ if isinstance (self .y_gates , jnp .ndarray ):
62+ self .y_gates = (self .y_gates ,)
63+
64+ if isinstance (self .z_gates , jnp .ndarray ):
65+ self .z_gates = (self .z_gates ,)
66+
5867 if (
5968 (
6069 len (self .x_gates ) > 0
Original file line number Diff line number Diff line change @@ -53,6 +53,12 @@ class Kagome_PESS3_Expectation_Value(Expectation_Model):
5353 spiral_unitary_operator : Optional [jnp .ndarray ] = None
5454
5555 def __post_init__ (self ) -> None :
56+ if isinstance (self .upward_triangle_gates , jnp .ndarray ):
57+ self .upward_triangle_gates = (self .upward_triangle_gates ,)
58+
59+ if isinstance (self .downward_triangle_gates , jnp .ndarray ):
60+ self .downward_triangle_gates = (self .downward_triangle_gates ,)
61+
5662 if (
5763 len (self .upward_triangle_gates ) > 0
5864 and len (self .downward_triangle_gates ) > 0
Original file line number Diff line number Diff line change @@ -426,6 +426,15 @@ class Maple_Leaf_Expectation_Value(Expectation_Model):
426426 spiral_unitary_operator : Optional [jnp .ndarray ] = None
427427
428428 def __post_init__ (self ) -> None :
429+ if isinstance (self .green_gates , jnp .ndarray ):
430+ self .green_gates = (self .green_gates ,)
431+
432+ if isinstance (self .blue_gates , jnp .ndarray ):
433+ self .blue_gates = (self .blue_gates ,)
434+
435+ if isinstance (self .red_gates , jnp .ndarray ):
436+ self .red_gates = (self .red_gates ,)
437+
429438 if (len (self .green_gates ) != len (self .blue_gates )) or (
430439 len (self .green_gates ) != len (self .red_gates )
431440 ):
Original file line number Diff line number Diff line change @@ -535,6 +535,18 @@ class Square_Kagome_Expectation_Value(Expectation_Model):
535535 spiral_unitary_operator : Optional [jnp .ndarray ] = None
536536
537537 def __post_init__ (self ) -> None :
538+ if isinstance (self .triangle_gates , jnp .ndarray ):
539+ self .triangle_gates = (self .triangle_gates ,)
540+
541+ if isinstance (self .square_gates , jnp .ndarray ):
542+ self .square_gates = (self .square_gates ,)
543+
544+ if isinstance (self .plus_gates , jnp .ndarray ):
545+ self .plus_gates = (self .plus_gates ,)
546+
547+ if isinstance (self .cross_gates , jnp .ndarray ):
548+ self .cross_gates = (self .cross_gates ,)
549+
538550 if len (self .cross_gates ) > 0 :
539551 raise NotImplementedError ("Cross term calculation is not implemented yet." )
540552
Original file line number Diff line number Diff line change @@ -55,6 +55,10 @@ class Triangular_Expectation_Value(Expectation_Model):
5555 is_spiral_peps : bool = False
5656 spiral_unitary_operator : Optional [jnp .ndarray ] = None
5757
58+ def __post_init__ (self ) -> None :
59+ if isinstance (self .nearest_neighbor_gates , jnp .ndarray ):
60+ self .nearest_neighbor_gates = (self .nearest_neighbor_gates ,)
61+
5862 def __call__ (
5963 self ,
6064 peps_tensors : Sequence [jnp .ndarray ],
You can’t perform that action at this time.
0 commit comments