33
44from jax .tree_util import register_pytree_node_class
55
6- from typing import TypeVar , Tuple , Any , Type
6+ from typing import TypeVar , Tuple , Any , Type , NoReturn
77
88T_VariPEPS_Config = TypeVar ("T_VariPEPS_Config" , bound = "VariPEPS_Config" )
99
@@ -177,7 +177,6 @@ class VariPEPS_Config:
177177 Constant used in Hager-Zhang line search method.
178178 line_search_hager_zhang_rho (:obj:`float`):
179179 Constant used in Hager-Zhang line search method.
180-
181180 basinhopping_niter (:obj:`int`):
182181 Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
183182 See this function for details.
@@ -264,6 +263,25 @@ class VariPEPS_Config:
264263 # Spiral PEPS
265264 spiral_wavevector_type : Wavevector_Type = Wavevector_Type .TWO_PI_POSITIVE_ONLY
266265
266+ def update (self , name : str , value : Any ) -> NoReturn :
267+ self .__setattr__ (name , value )
268+
269+ def __setattr__ (self , name : str , value : Any ) -> NoReturn :
270+ try :
271+ field = self .__dataclass_fields__ [name ]
272+ except KeyError as e :
273+ raise KeyError (f"Unknown config option '{ name } '." ) from e
274+
275+ if not type (value ) is field .type :
276+ if field .type is float and type (value ) is int :
277+ pass
278+ else :
279+ raise TypeError (
280+ f"Type mismatch for option '{ name } ', got '{ type (value )} ', expected '{ field .type } '."
281+ )
282+
283+ super ().__setattr__ (name , value )
284+
267285 def tree_flatten (self ) -> Tuple [Tuple [Any , ...], Tuple [Any , ...]]:
268286 aux_data = (
269287 {name : getattr (self , name ) for name in self .__dataclass_fields__ .keys ()},
@@ -283,3 +301,35 @@ def tree_unflatten(
283301
284302
285303config = VariPEPS_Config ()
304+
305+
306+ class ConfigModuleWrapper :
307+ __slots__ = {
308+ "Optimizing_Methods" ,
309+ "Line_Search_Methods" ,
310+ "Projector_Method" ,
311+ "Wavevector_Type" ,
312+ "VariPEPS_Config" ,
313+ "config" ,
314+ }
315+
316+ def __init__ (self ):
317+ for e in self .__slots__ :
318+ setattr (self , e , globals ()[e ])
319+
320+ def __getattr__ (self , name : str ) -> Any :
321+ if name .startswith ("__" ) or name in self .__slots__ :
322+ return super ().__getattr__ (name )
323+ else :
324+ return getattr (self .config , name )
325+
326+ def __setattr__ (self , name : str , value : Any ) -> NoReturn :
327+ if not name .startswith ("__" ) and name not in self .__slots__ :
328+ setattr (self .config , name , value )
329+ elif not hasattr (self , name ):
330+ super ().__setattr__ (name , value )
331+ else :
332+ raise AttributeError (f"Attribute '{ name } ' is write-protected." )
333+
334+
335+ wrapper = ConfigModuleWrapper ()
0 commit comments