1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from copy import deepcopy
15- from dataclasses import Field , dataclass , fields
15+ from dataclasses import MISSING , Field , dataclass , fields
1616from typing import Any , ClassVar
1717
1818import numpy as np
@@ -67,7 +67,16 @@ def sampling_state(self) -> DataClassState:
6767 state_class = self ._state_class
6868 kwargs = {}
6969 for field in fields (state_class ):
70- val = getattr (self , field .name )
70+ is_tensor_name = field .metadata .get ("tensor_name" , False )
71+ val : Any
72+ if is_tensor_name :
73+ val = [var .name for var in getattr (self , "vars" )]
74+ else :
75+ val = getattr (self , field .name , field .default )
76+ if val is MISSING :
77+ raise AttributeError (
78+ f"{ type (self ).__name__ !r} object has no attribute { field .name !r} "
79+ )
7180 _val : Any
7281 if isinstance (val , WithSamplingState ):
7382 _val = val .sampling_state
@@ -85,11 +94,17 @@ def sampling_state(self, state: DataClassState):
8594 state , state_class
8695 ), f"Encountered invalid state class '{ state .__class__ } '. State must be '{ state_class } '"
8796 for field in fields (state_class ):
97+ is_tensor_name = field .metadata .get ("tensor_name" , False )
8898 state_val = deepcopy (getattr (state , field .name ))
8999 if isinstance (state_val , RandomGeneratorState ):
90100 state_val = random_generator_from_state (state_val )
91- self_val = getattr (self , field .name )
92101 is_frozen = field .metadata .get ("frozen" , False )
102+ self_val : Any
103+ if is_tensor_name :
104+ self_val = [var .name for var in getattr (self , "vars" )]
105+ assert is_frozen
106+ else :
107+ self_val = getattr (self , field .name , field .default )
93108 if is_frozen :
94109 if not equal_dataclass_values (state_val , self_val ):
95110 raise ValueError (
0 commit comments