55import keras
66import numpy as np
77import sys
8+ from warnings import warn
89
910# this import needs to be exactly like this to work with monkey patching
1011from keras .saving import deserialize_keras_object
1920
2021
2122def serialize_value_or_type (config , name , obj ):
22- """Serialize an object that can be either a value or a type
23- and add it to a copy of the supplied dictionary.
23+ """This function is deprecated."""
24+ warn (
25+ "This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize." ,
26+ DeprecationWarning ,
27+ stacklevel = 2 ,
28+ )
2429
25- Parameters
26- ----------
27- config : dict
28- Dictionary to add the serialized object to. This function does not
29- modify the dictionary in place, but returns a modified copy.
30- name : str
31- Name of the obj that should be stored. Required for later deserialization.
32- obj : object or type
33- The object to serialize. If `obj` is of type `type`, we use
34- `keras.saving.get_registered_name` to obtain the registered type name.
35- If it is not a type, we try to serialize it as a Keras object.
3630
37- Returns
38- -------
39- updated_config : dict
40- Updated dictionary with a new key `"_bayesflow_<name>_type"` or
41- `"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
42- the suffix indicates how the stored value has to be deserialized.
43-
44- Notes
45- -----
46- We allow strings or `type` parameters at several places to instantiate objects
47- of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
48- be serialized, we have to distinguish the two cases for serialization and
49- deserialization. This function is a helper function to standardize and
50- simplify this.
51- """
52- updated_config = config .copy ()
53- if isinstance (obj , type ):
54- updated_config [f"{ PREFIX } { name } _type" ] = keras .saving .get_registered_name (obj )
55- else :
56- updated_config [f"{ PREFIX } { name } _val" ] = keras .saving .serialize_keras_object (obj )
57- return updated_config
31+ def deserialize_value_or_type (config , name ):
32+ """This function is deprecated."""
33+ warn (
34+ "This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize." ,
35+ DeprecationWarning ,
36+ stacklevel = 2 ,
37+ )
5838
5939
60- def deserialize_value_or_type (config , name ):
61- """Deserialize an object that can be either a value or a type and add
62- it to the supplied dictionary.
40+ def deserialize (config : dict , custom_objects = None , safe_mode = True , ** kwargs ):
41+ """Deserialize an object serialized with :py:func:`serialize`.
42+
43+ Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of
44+ classes.
6345
6446 Parameters
6547 ----------
66- config : dict
67- Dictionary containing the object to deserialize. If a type was
68- serialized, it should contain the key `"_bayesflow_<name>_type"`.
69- If an object was serialized, it should contain the key
70- `"_bayesflow_<name>_val"`. In a copy of this dictionary,
71- the item will be replaced with the key `name`.
72- name : str
73- Name of the object to deserialize.
48+ config : dict
49+ Python dict describing the object.
50+ custom_objects : dict, optional
51+ Python dict containing a mapping between custom object names and the corresponding
52+ classes or functions. Forwarded to `keras.saving.deserialize_keras_object`.
53+ safe_mode : bool, optional
54+ Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
55+ loading an object has the potential to trigger arbitrary code execution. This argument
56+ is only applicable to the Keras v3 model format. Defaults to True.
57+ Forwarded to `keras.saving.deserialize_keras_object`.
7458
7559 Returns
7660 -------
77- updated_config : dict
78- Updated dictionary with a new key `name`, with a value that is either
79- a type or an object.
61+ obj :
62+ The object described by the config dictionary.
63+
64+ Raises
65+ ------
66+ ValueError
67+ If a type in the config can not be deserialized.
8068
8169 See Also
8270 --------
83- serialize_value_or_type
71+ serialize
8472 """
85- updated_config = config .copy ()
86- if f"{ PREFIX } { name } _type" in config :
87- updated_config [name ] = keras .saving .get_registered_object (config [f"{ PREFIX } { name } _type" ])
88- del updated_config [f"{ PREFIX } { name } _type" ]
89- elif f"{ PREFIX } { name } _val" in config :
90- updated_config [name ] = keras .saving .deserialize_keras_object (config [f"{ PREFIX } { name } _val" ])
91- del updated_config [f"{ PREFIX } { name } _val" ]
92- return updated_config
93-
94-
95- def deserialize (obj , custom_objects = None , safe_mode = True , ** kwargs ):
9673 with monkey_patch (deserialize_keras_object , deserialize ) as original_deserialize :
97- if isinstance (obj , str ) and obj .startswith (_type_prefix ):
74+ if isinstance (config , str ) and config .startswith (_type_prefix ):
9875 # we marked this as a type during serialization
99- obj = obj [len (_type_prefix ) :]
76+ config = config [len (_type_prefix ) :]
10077 tp = keras .saving .get_registered_object (
10178 # TODO: can we pass module objects without overwriting numpy's dict with builtins?
102- obj ,
79+ config ,
10380 custom_objects = custom_objects ,
10481 module_objects = np .__dict__ | builtins .__dict__ ,
10582 )
10683 if tp is None :
10784 raise ValueError (
108- f"Could not deserialize type { obj !r} . Make sure it is registered with "
85+ f"Could not deserialize type { config !r} . Make sure it is registered with "
10986 f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
11087 )
11188 return tp
112- if inspect .isclass (obj ):
89+ if inspect .isclass (config ):
11390 # add this base case since keras does not cover it
114- return obj
91+ return config
11592
116- obj = original_deserialize (obj , custom_objects = custom_objects , safe_mode = safe_mode , ** kwargs )
93+ obj = original_deserialize (config , custom_objects = custom_objects , safe_mode = safe_mode , ** kwargs )
11794
11895 return obj
11996
12097
12198@allow_args
122- def serializable (cls , package = None , name = None ):
99+ def serializable (cls , package : str | None = None , name : str | None = None ):
100+ """Register class as Keras serialize.
101+
102+ Wrapper function around `keras.saving.register_keras_serializable` to automatically
103+ set the `package` and `name` arguments.
104+
105+ Parameters
106+ ----------
107+ cls : type
108+ The class to register.
109+ package : str, optional
110+ `package` argument forwarded to `keras.saving.register_keras_serializable`.
111+ If None is provided, the package is automatically inferred using the __name__
112+ attribute of the module the class resides in.
113+ name : str, optional
114+ `name` argument forwarded to `keras.saving.register_keras_serializable`.
115+ If None is provided, the classe's __name__ attribute is used.
116+ """
123117 if package is None :
124- frame = sys ._getframe (1 )
118+ frame = sys ._getframe (2 )
125119 g = frame .f_globals
126120 package = g .get ("__name__" , "bayesflow" )
127121
@@ -133,6 +127,26 @@ def serializable(cls, package=None, name=None):
133127
134128
135129def serialize (obj ):
130+ """Serialize an object using Keras.
131+
132+ Wrapper function around `keras.saving.serialize_keras_object`, which adds the
133+ ability to serialize classes.
134+
135+ Parameters
136+ ----------
137+ object : Keras serializable object, or class
138+ The object to serialize
139+
140+ Returns
141+ -------
142+ config : dict
143+ A python dict that represents the object. The python dict can be deserialized via
144+ :py:func:`deserialize`.
145+
146+ See Also
147+ --------
148+ deserialize
149+ """
136150 if isinstance (obj , (tuple , list , dict )):
137151 return keras .tree .map_structure (serialize , obj )
138152 elif inspect .isclass (obj ):
0 commit comments