Skip to content

Commit 8f35813

Browse files
committed
Add feature to make life eaiser for users of the expectation function classes
1 parent ab730c6 commit 8f35813

File tree

7 files changed

+56
-0
lines changed

7 files changed

+56
-0
lines changed

peps_ad/expectation/two_sites.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import jax.numpy as jnp
77
from jax import jit
88

9+
from peps_ad import peps_ad_config
910
from peps_ad.peps import PEPS_Tensor, PEPS_Unit_Cell
1011
from peps_ad.contractions import apply_contraction
1112
from .model import Expectation_Model
@@ -419,6 +420,12 @@ class Two_Sites_Expectation_Value(Expectation_Model):
419420
spiral_unitary_operator: Optional[jnp.ndarray] = None
420421

421422
def __post_init__(self) -> None:
423+
if isinstance(self.horizontal_gates, jnp.ndarray):
424+
self.horizontal_gates = (self.horizontal_gates,)
425+
426+
if isinstance(self.vertical_gates, jnp.ndarray):
427+
self.vertical_gates = (self.vertical_gates,)
428+
422429
if (
423430
len(self.horizontal_gates) > 0
424431
and len(self.horizontal_gates) > 0

peps_ad/mapping/florett_pentagon.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,15 @@ class Florett_Pentagon_Expectation_Value(Expectation_Model):
775775
spiral_unitary_operator: Optional[jnp.ndarray] = None
776776

777777
def __post_init__(self) -> None:
778+
if isinstance(self.black_gates, jnp.ndarray):
779+
self.black_gates = (self.black_gates,)
780+
781+
if isinstance(self.green_gates, jnp.ndarray):
782+
self.green_gates = (self.green_gates,)
783+
784+
if isinstance(self.blue_gates, jnp.ndarray):
785+
self.blue_gates = (self.blue_gates,)
786+
778787
if (len(self.green_gates) != len(self.blue_gates)) or (
779788
len(self.green_gates) != len(self.black_gates)
780789
):

peps_ad/mapping/honeycomb.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ class Honeycomb_Expectation_Value(Expectation_Model):
5555
spiral_unitary_operator: Optional[jnp.ndarray] = None
5656

5757
def __post_init__(self) -> None:
58+
if isinstance(self.x_gates, jnp.ndarray):
59+
self.x_gates = (self.x_gates,)
60+
61+
if isinstance(self.y_gates, jnp.ndarray):
62+
self.y_gates = (self.y_gates,)
63+
64+
if isinstance(self.z_gates, jnp.ndarray):
65+
self.z_gates = (self.z_gates,)
66+
5867
if (
5968
(
6069
len(self.x_gates) > 0

peps_ad/mapping/kagome.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ class Kagome_PESS3_Expectation_Value(Expectation_Model):
5353
spiral_unitary_operator: Optional[jnp.ndarray] = None
5454

5555
def __post_init__(self) -> None:
56+
if isinstance(self.upward_triangle_gates, jnp.ndarray):
57+
self.upward_triangle_gates = (self.upward_triangle_gates,)
58+
59+
if isinstance(self.downward_triangle_gates, jnp.ndarray):
60+
self.downward_triangle_gates = (self.downward_triangle_gates,)
61+
5662
if (
5763
len(self.upward_triangle_gates) > 0
5864
and len(self.downward_triangle_gates) > 0

peps_ad/mapping/maple_leaf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,15 @@ class Maple_Leaf_Expectation_Value(Expectation_Model):
426426
spiral_unitary_operator: Optional[jnp.ndarray] = None
427427

428428
def __post_init__(self) -> None:
429+
if isinstance(self.green_gates, jnp.ndarray):
430+
self.green_gates = (self.green_gates,)
431+
432+
if isinstance(self.blue_gates, jnp.ndarray):
433+
self.blue_gates = (self.blue_gates,)
434+
435+
if isinstance(self.red_gates, jnp.ndarray):
436+
self.red_gates = (self.red_gates,)
437+
429438
if (len(self.green_gates) != len(self.blue_gates)) or (
430439
len(self.green_gates) != len(self.red_gates)
431440
):

peps_ad/mapping/square_kagome.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,18 @@ class Square_Kagome_Expectation_Value(Expectation_Model):
535535
spiral_unitary_operator: Optional[jnp.ndarray] = None
536536

537537
def __post_init__(self) -> None:
538+
if isinstance(self.triangle_gates, jnp.ndarray):
539+
self.triangle_gates = (self.triangle_gates,)
540+
541+
if isinstance(self.square_gates, jnp.ndarray):
542+
self.square_gates = (self.square_gates,)
543+
544+
if isinstance(self.plus_gates, jnp.ndarray):
545+
self.plus_gates = (self.plus_gates,)
546+
547+
if isinstance(self.cross_gates, jnp.ndarray):
548+
self.cross_gates = (self.cross_gates,)
549+
538550
if len(self.cross_gates) > 0:
539551
raise NotImplementedError("Cross term calculation is not implemented yet.")
540552

peps_ad/mapping/triangular.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class Triangular_Expectation_Value(Expectation_Model):
5555
is_spiral_peps: bool = False
5656
spiral_unitary_operator: Optional[jnp.ndarray] = None
5757

58+
def __post_init__(self) -> None:
59+
if isinstance(self.nearest_neighbor_gates, jnp.ndarray):
60+
self.nearest_neighbor_gates = (self.nearest_neighbor_gates,)
61+
5862
def __call__(
5963
self,
6064
peps_tensors: Sequence[jnp.ndarray],

0 commit comments

Comments
 (0)