@@ -1987,6 +1987,144 @@ def __post_init__(self) -> None:
19871987 # "At least one transfer tensors mismatch bond dimensions of PEPS tensor."
19881988 # )
19891989
1990+ @classmethod
1991+ def from_tensor (
1992+ cls : Type [T_PEPS_Tensor ],
1993+ tensor : Tensor ,
1994+ d : int ,
1995+ D : Union [int , Sequence [int ]],
1996+ chi : int ,
1997+ interlayer_chi : Optional [int ] = None ,
1998+ max_chi : Optional [int ] = None ,
1999+ * ,
2000+ ctm_tensors_are_identities : bool = True ,
2001+ normalize : bool = True ,
2002+ seed : Optional [int ] = None ,
2003+ backend : str = "jax" ,
2004+ ) -> T_PEPS_Tensor :
2005+ """
2006+ Initialize a PEPS tensor object with a given tensor and new CTM tensors.
2007+
2008+ Args:
2009+ tensor (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
2010+ PEPS tensor to initialize the object with
2011+ d (:obj:`int`):
2012+ Physical dimension
2013+ D (:obj:`int` or :term:`sequence` of :obj:`int`):
2014+ Bond dimensions for the PEPS tensor
2015+ chi (:obj:`int`):
2016+ Bond dimension for the environment tensors
2017+ interlayer_chi (:obj:`int`):
2018+ Bond dimension for the interlayer bonds of the environment tensors
2019+ max_chi (:obj:`int`):
2020+ Maximal allowed bond dimension for the environment tensors
2021+ Keyword args:
2022+ ctm_tensors_are_identities (:obj:`bool`, optional):
2023+ Flag if the CTM tensors are initialized as identities. Otherwise,
2024+ they are initialized randomly. Defaults to True.
2025+ normalize (:obj:`bool`, optional):
2026+ Flag if the generated tensors are normalized. Defaults to True.
2027+ seed (:obj:`int`, optional):
2028+ Seed for the random number generator.
2029+ backend (:obj:`str`, optional):
2030+ Backend for the generated tensors (may be ``jax`` or ``numpy``).
2031+ Defaults to ``jax``.
2032+ Returns:
2033+ PEPS_Tensor:
2034+ Instance of PEPS_Tensor with the randomly initialized tensors.
2035+ """
2036+ if not is_tensor (tensor ):
2037+ raise ValueError ("Invalid argument for tensor." )
2038+
2039+ if isinstance (D , int ):
2040+ D = (D ,) * 4
2041+ elif isinstance (D , collections .abc .Sequence ) and not isinstance (D , tuple ):
2042+ D = tuple (D )
2043+
2044+ if not all (isinstance (i , int ) for i in D ) or not len (D ) == 4 :
2045+ raise ValueError ("Invalid argument for D." )
2046+
2047+ if (
2048+ tensor .shape [0 ] != D [0 ]
2049+ or tensor .shape [1 ] != D [1 ]
2050+ or tensor .shape [3 ] != D [2 ]
2051+ or tensor .shape [4 ] != D [3 ]
2052+ or tensor .shape [2 ] != d
2053+ ):
2054+ raise ValueError ("Tensor dimensions mismatch the dimension arguments." )
2055+
2056+ if interlayer_chi is None :
2057+ interlayer_chi = chi
2058+ if max_chi is None :
2059+ max_chi = chi
2060+
2061+ dtype = tensor .dtype
2062+
2063+ if ctm_tensors_are_identities :
2064+ C1 = jnp .ones ((1 , 1 ), dtype = dtype )
2065+ C2 = jnp .ones ((1 , 1 ), dtype = dtype )
2066+ C3 = jnp .ones ((1 , 1 ), dtype = dtype )
2067+ C4 = jnp .ones ((1 , 1 ), dtype = dtype )
2068+
2069+ T1 = jnp .eye (D [3 ], dtype = dtype ).reshape (1 , D [3 ], D [3 ], 1 )
2070+ T2 = jnp .eye (D [2 ], dtype = dtype ).reshape (D [2 ], D [2 ], 1 , 1 )
2071+ T3 = jnp .eye (D [1 ], dtype = dtype ).reshape (1 , 1 , D [1 ], D [1 ])
2072+ T4 = jnp .eye (D [0 ], dtype = dtype ).reshape (1 , D [0 ], D [0 ], 1 )
2073+
2074+ return cls (
2075+ tensor = tensor ,
2076+ C1 = C1 ,
2077+ C2 = C2 ,
2078+ C3 = C3 ,
2079+ C4 = C4 ,
2080+ T1 = T1 ,
2081+ T2 = T2 ,
2082+ T3 = T3 ,
2083+ T4 = T4 ,
2084+ d = d ,
2085+ D = D , # type: ignore
2086+ chi = chi ,
2087+ interlayer_chi = interlayer_chi ,
2088+ max_chi = max_chi ,
2089+ )
2090+ else :
2091+ rng = PEPS_Random_Number_Generator .get_generator (seed , backend = backend )
2092+
2093+ C1 = rng .block ((chi , chi ), dtype , normalize = normalize )
2094+ C2 = rng .block ((chi , chi ), dtype , normalize = normalize )
2095+ C3 = rng .block ((chi , chi ), dtype , normalize = normalize )
2096+ C4 = rng .block ((chi , chi ), dtype , normalize = normalize )
2097+
2098+ T1_ket = rng .block ((chi , D [3 ], interlayer_chi ), dtype , normalize = normalize )
2099+ T1_bra = rng .block ((interlayer_chi , D [3 ], chi ), dtype , normalize = normalize )
2100+ T2_ket = rng .block ((interlayer_chi , D [2 ], chi ), dtype , normalize = normalize )
2101+ T2_bra = rng .block ((chi , D [2 ], interlayer_chi ), dtype , normalize = normalize )
2102+ T3_ket = rng .block ((chi , D [1 ], interlayer_chi ), dtype , normalize = normalize )
2103+ T3_bra = rng .block ((interlayer_chi , D [1 ], chi ), dtype , normalize = normalize )
2104+ T4_ket = rng .block ((interlayer_chi , D [0 ], chi ), dtype , normalize = normalize )
2105+ T4_bra = rng .block ((chi , D [0 ], interlayer_chi ), dtype , normalize = normalize )
2106+
2107+ return cls (
2108+ tensor = tensor ,
2109+ C1 = C1 ,
2110+ C2 = C2 ,
2111+ C3 = C3 ,
2112+ C4 = C4 ,
2113+ T1_ket = T1_ket ,
2114+ T1_bra = T1_bra ,
2115+ T2_ket = T2_ket ,
2116+ T2_bra = T2_bra ,
2117+ T3_ket = T3_ket ,
2118+ T3_bra = T3_bra ,
2119+ T4_ket = T4_ket ,
2120+ T4_bra = T4_bra ,
2121+ d = d ,
2122+ D = D , # type: ignore
2123+ chi = chi ,
2124+ interlayer_chi = interlayer_chi ,
2125+ max_chi = max_chi ,
2126+ )
2127+
19902128 @property
19912129 def left_upper_transfer_shape (self ) -> Tensor :
19922130 return self .T4_ket .shape [2 ]
0 commit comments