Skip to content

Commit 5595eab

Browse files
authored
Deprecate old serialization ((de)serialize_value_or_type) and add developer docs (#450)
* document and expose serialization module * add functools.wraps call to allow_kwargs decorator, as before it was breaking the autodoc functionality * restructure and update developer docs * move content to separate pages * update section on serialization * ci: update pip via python -m pip pip install -U pip setuptools wheel leads to an error: https://github.com/bayesflow-org/bayesflow/actions/runs/14692655483/job/41230057180?pr=449 * serializable: increase depth in sys._getframe The functools.wrap decorator adds a frame object to the call stack * deprecate (de)serialize_value_or_type - add deprecation warning, remove functionality - replace all occurences with the corresponding new functions
1 parent a322ff1 commit 5595eab

File tree

10 files changed

+169
-169
lines changed

10 files changed

+169
-169
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import keras
2-
from keras.saving import (
3-
deserialize_keras_object as deserialize,
4-
serialize_keras_object as serialize,
5-
register_keras_serializable as serializable,
6-
)
72

8-
from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
3+
from bayesflow.utils import model_kwargs, find_network
4+
from bayesflow.utils.serialization import deserialize, serializable, serialize
95
from bayesflow.types import Shape, Tensor
106
from bayesflow.scores import ScoringRule, ParametricDistributionScore
117
from bayesflow.utils.decorators import allow_batch_size
@@ -30,10 +26,10 @@ def __init__(
3026
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
3127

3228
self.config = {
29+
"subnet": serialize(subnet),
30+
"scores": serialize(scores),
3331
**kwargs,
3432
}
35-
self.config = serialize_value_or_type(self.config, "subnet", subnet)
36-
self.config["scores"] = serialize(self.scores)
3733

3834
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
3935
"""Builds all network components based on shapes of conditions and targets.
@@ -119,7 +115,7 @@ def get_config(self):
119115
def from_config(cls, config):
120116
config = config.copy()
121117
config["scores"] = deserialize(config["scores"])
122-
config = deserialize_value_or_type(config, "subnet")
118+
config["subnet"] = deserialize(config["subnet"])
123119
return cls(**config)
124120

125121
def call(

bayesflow/scores/scoring_rule.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import math
22

33
import keras
4-
from keras.saving import register_keras_serializable as serializable
54

65
from bayesflow.types import Shape, Tensor
7-
from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type
6+
from bayesflow.utils import find_network
7+
from bayesflow.utils.serialization import deserialize, serializable, serialize
88

99

1010
@serializable(package="bayesflow.scores")
@@ -51,23 +51,16 @@ def __init__(
5151
self.config = {"subnets_kwargs": self.subnets_kwargs}
5252

5353
def get_config(self):
54-
self.config["subnets"] = {
55-
key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items()
56-
}
57-
self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()}
54+
self.config["subnets"] = {key: serialize(subnet) for key, subnet in self.subnets.items()}
55+
self.config["links"] = {key: serialize(link) for key, link in self.links.items()}
5856

5957
return self.config
6058

6159
@classmethod
6260
def from_config(cls, config):
6361
config = config.copy()
64-
config["subnets"] = {
65-
key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"]
66-
for key, subnet_dict in config["subnets"].items()
67-
}
68-
config["links"] = {
69-
key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items()
70-
}
62+
config["subnets"] = {key: deserialize(subnet_dict) for key, subnet_dict in config["subnets"].items()}
63+
config["links"] = {key: deserialize(link_dict) for key, link_dict in config["links"].items()}
7164

7265
return cls(**config)
7366

bayesflow/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
keras_utils,
77
logging,
88
numpy_utils,
9+
serialization,
910
)
1011

1112
from .callbacks import detailed_loss_callback
@@ -104,4 +105,4 @@
104105

105106
from ._docs import _add_imports_to_all
106107

107-
_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"])
108+
_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])

bayesflow/utils/decorators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator:
1717
def wrapper(f: Fn) -> Fn: ...
1818
@overload
1919
def wrapper(*fargs: any, **fkwargs: any) -> Fn: ...
20+
@wraps(fn)
2021
def wrapper(*fargs: any, **fkwargs: any) -> Fn:
2122
if len(fargs) == 1 and not fkwargs and callable(fargs[0]):
2223
# called without arguments

bayesflow/utils/serialization.py

Lines changed: 83 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import keras
66
import numpy as np
77
import sys
8+
from warnings import warn
89

910
# this import needs to be exactly like this to work with monkey patching
1011
from keras.saving import deserialize_keras_object
@@ -19,109 +20,102 @@
1920

2021

2122
def serialize_value_or_type(config, name, obj):
22-
"""Serialize an object that can be either a value or a type
23-
and add it to a copy of the supplied dictionary.
23+
"""This function is deprecated."""
24+
warn(
25+
"This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.",
26+
DeprecationWarning,
27+
stacklevel=2,
28+
)
2429

25-
Parameters
26-
----------
27-
config : dict
28-
Dictionary to add the serialized object to. This function does not
29-
modify the dictionary in place, but returns a modified copy.
30-
name : str
31-
Name of the obj that should be stored. Required for later deserialization.
32-
obj : object or type
33-
The object to serialize. If `obj` is of type `type`, we use
34-
`keras.saving.get_registered_name` to obtain the registered type name.
35-
If it is not a type, we try to serialize it as a Keras object.
3630

37-
Returns
38-
-------
39-
updated_config : dict
40-
Updated dictionary with a new key `"_bayesflow_<name>_type"` or
41-
`"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
42-
the suffix indicates how the stored value has to be deserialized.
43-
44-
Notes
45-
-----
46-
We allow strings or `type` parameters at several places to instantiate objects
47-
of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
48-
be serialized, we have to distinguish the two cases for serialization and
49-
deserialization. This function is a helper function to standardize and
50-
simplify this.
51-
"""
52-
updated_config = config.copy()
53-
if isinstance(obj, type):
54-
updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
55-
else:
56-
updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
57-
return updated_config
31+
def deserialize_value_or_type(config, name):
32+
"""This function is deprecated."""
33+
warn(
34+
"This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.",
35+
DeprecationWarning,
36+
stacklevel=2,
37+
)
5838

5939

60-
def deserialize_value_or_type(config, name):
61-
"""Deserialize an object that can be either a value or a type and add
62-
it to the supplied dictionary.
40+
def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
41+
"""Deserialize an object serialized with :py:func:`serialize`.
42+
43+
Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of
44+
classes.
6345
6446
Parameters
6547
----------
66-
config : dict
67-
Dictionary containing the object to deserialize. If a type was
68-
serialized, it should contain the key `"_bayesflow_<name>_type"`.
69-
If an object was serialized, it should contain the key
70-
`"_bayesflow_<name>_val"`. In a copy of this dictionary,
71-
the item will be replaced with the key `name`.
72-
name : str
73-
Name of the object to deserialize.
48+
config : dict
49+
Python dict describing the object.
50+
custom_objects : dict, optional
51+
Python dict containing a mapping between custom object names and the corresponding
52+
classes or functions. Forwarded to `keras.saving.deserialize_keras_object`.
53+
safe_mode : bool, optional
54+
Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
55+
loading an object has the potential to trigger arbitrary code execution. This argument
56+
is only applicable to the Keras v3 model format. Defaults to True.
57+
Forwarded to `keras.saving.deserialize_keras_object`.
7458
7559
Returns
7660
-------
77-
updated_config : dict
78-
Updated dictionary with a new key `name`, with a value that is either
79-
a type or an object.
61+
obj :
62+
The object described by the config dictionary.
63+
64+
Raises
65+
------
66+
ValueError
67+
If a type in the config can not be deserialized.
8068
8169
See Also
8270
--------
83-
serialize_value_or_type
71+
serialize
8472
"""
85-
updated_config = config.copy()
86-
if f"{PREFIX}{name}_type" in config:
87-
updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
88-
del updated_config[f"{PREFIX}{name}_type"]
89-
elif f"{PREFIX}{name}_val" in config:
90-
updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
91-
del updated_config[f"{PREFIX}{name}_val"]
92-
return updated_config
93-
94-
95-
def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
9673
with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
97-
if isinstance(obj, str) and obj.startswith(_type_prefix):
74+
if isinstance(config, str) and config.startswith(_type_prefix):
9875
# we marked this as a type during serialization
99-
obj = obj[len(_type_prefix) :]
76+
config = config[len(_type_prefix) :]
10077
tp = keras.saving.get_registered_object(
10178
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
102-
obj,
79+
config,
10380
custom_objects=custom_objects,
10481
module_objects=np.__dict__ | builtins.__dict__,
10582
)
10683
if tp is None:
10784
raise ValueError(
108-
f"Could not deserialize type {obj!r}. Make sure it is registered with "
85+
f"Could not deserialize type {config!r}. Make sure it is registered with "
10986
f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
11087
)
11188
return tp
112-
if inspect.isclass(obj):
89+
if inspect.isclass(config):
11390
# add this base case since keras does not cover it
114-
return obj
91+
return config
11592

116-
obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
93+
obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
11794

11895
return obj
11996

12097

12198
@allow_args
122-
def serializable(cls, package=None, name=None):
99+
def serializable(cls, package: str | None = None, name: str | None = None):
100+
"""Register class as Keras serialize.
101+
102+
Wrapper function around `keras.saving.register_keras_serializable` to automatically
103+
set the `package` and `name` arguments.
104+
105+
Parameters
106+
----------
107+
cls : type
108+
The class to register.
109+
package : str, optional
110+
`package` argument forwarded to `keras.saving.register_keras_serializable`.
111+
If None is provided, the package is automatically inferred using the __name__
112+
attribute of the module the class resides in.
113+
name : str, optional
114+
`name` argument forwarded to `keras.saving.register_keras_serializable`.
115+
If None is provided, the classe's __name__ attribute is used.
116+
"""
123117
if package is None:
124-
frame = sys._getframe(1)
118+
frame = sys._getframe(2)
125119
g = frame.f_globals
126120
package = g.get("__name__", "bayesflow")
127121

@@ -133,6 +127,26 @@ def serializable(cls, package=None, name=None):
133127

134128

135129
def serialize(obj):
130+
"""Serialize an object using Keras.
131+
132+
Wrapper function around `keras.saving.serialize_keras_object`, which adds the
133+
ability to serialize classes.
134+
135+
Parameters
136+
----------
137+
object : Keras serializable object, or class
138+
The object to serialize
139+
140+
Returns
141+
-------
142+
config : dict
143+
A python dict that represents the object. The python dict can be deserialized via
144+
:py:func:`deserialize`.
145+
146+
See Also
147+
--------
148+
deserialize
149+
"""
136150
if isinstance(obj, (tuple, list, dict)):
137151
return keras.tree.map_structure(serialize, obj)
138152
elif inspect.isclass(obj):

0 commit comments

Comments
 (0)