Skip to content

Commit e2c8304

Browse files
vpratzarrjonLarsKue
authored
Add diffusion model implementation (#408)
This commit contains the following changes (see PR #408 for discussions) - DiffusionModel following the formalism in Kingma et al. (2023) [1] - Stochastic sampler to solve SDEs - Tests for the diffusion model [1] https://arxiv.org/abs/2303.00848 --------- Co-authored-by: arrjon <jonas.arruda@uni-bonn.de> Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com> Co-authored-by: LarsKue <lars@kuehmichel.de>
1 parent ccf9ca0 commit e2c8304

File tree

15 files changed

+1204
-7
lines changed

15 files changed

+1204
-7
lines changed

bayesflow/experimental/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from .cif import CIF
66
from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
7+
from .diffusion_model import DiffusionModel
78
from .free_form_flow import FreeFormFlow
89

910
from ..utils._docs import _add_imports_to_all
1011

11-
_add_imports_to_all(include_modules=[])
12+
_add_imports_to_all(include_modules=["diffusion_model"])
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .diffusion_model import DiffusionModel
2+
from .noise_schedule import NoiseSchedule
3+
from .cosine_noise_schedule import CosineNoiseSchedule
4+
from .edm_noise_schedule import EDMNoiseSchedule
5+
from .dispatch import find_noise_schedule
6+
7+
from ...utils._docs import _add_imports_to_all
8+
9+
_add_imports_to_all(include_modules=[])
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import math
2+
from typing import Union, Literal
3+
4+
from keras import ops
5+
6+
from bayesflow.types import Tensor
7+
from bayesflow.utils.serialization import deserialize, serializable
8+
9+
from .noise_schedule import NoiseSchedule
10+
11+
12+
# disable module check, use potential module after moving from experimental
13+
@serializable("bayesflow.networks", disable_module_check=True)
class CosineNoiseSchedule(NoiseSchedule):
    """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].

    The log signal-to-noise ratio (lambda) is computed as
    ``-2 * log(tan(pi * s / 2)) + 2 * shift``, where ``s = t_min + (t_max - t_min) * t`` is the
    diffusion time ``t`` mapped affinely onto ``[t_min, t_max]``. ``t_min`` and ``t_max`` are chosen
    such that ``t = 0`` yields ``max_log_snr`` and ``t = 1`` yields ``min_log_snr``.

    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)

    NOTE(review): the cosine schedule is commonly attributed to Nichol and Dhariwal,
    "Improved Denoising Diffusion Probabilistic Models" (2021) -- confirm the intended citation.
    """

    def __init__(
        self,
        min_log_snr: float = -15,
        max_log_snr: float = 15,
        shift: float = 0.0,
        weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
    ):
        """
        Initialize the cosine noise schedule.

        Parameters
        ----------
        min_log_snr : float, optional
            The minimum log signal-to-noise ratio (lambda). Default is -15.
        max_log_snr : float, optional
            The maximum log signal-to-noise ratio (lambda). Default is 15.
        shift : float, optional
            Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
            For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
        weighting : Literal["sigmoid", "likelihood_weighting"], optional
            The type of weighting function to use for the noise schedule. Default is "sigmoid".
        """
        super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
        self._shift = shift
        self._weighting = weighting
        self.log_snr_min = min_log_snr
        self.log_snr_max = max_log_snr

        # Positions on the raw cosine axis at which the log SNR hits its bounds.
        # The largest log SNR corresponds to the smallest time and vice versa.
        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)

    def _truncated_t(self, t: Tensor) -> Tensor:
        """Map a diffusion time ``t`` affinely onto the truncated interval ``[t_min, t_max]``."""
        return self._t_min + (self._t_max - self._t_min) * t

    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        t_trunc = self._truncated_t(t)
        return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift

    def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda).

        This inverts only the cosine formula; the affine map of ``_truncated_t`` is not undone, so
        the returned value is the position on the truncated axis (the argument of ``tan`` in
        ``get_log_snr``, up to the factor ``pi / 2``).
        """
        # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
        return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))

    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
        # get_t_from_log_snr already returns the position on the truncated axis, so it must not be
        # passed through _truncated_t a second time; doing so would evaluate the derivative at a
        # point that does not correspond to log_snr_t.
        t_trunc = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)

        # d/ds [-2 * log(tan(pi * s / 2))] = -2 * pi / sin(pi * s)
        dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)

        # Chain rule through s = t_min + (t_max - t_min) * t.
        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)

        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
        # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt = -sigmoid(-snr(t)) * dsnr_dt
        # sigmoid is used instead of the explicit ratio to avoid inf / inf for very negative log SNR.
        factor = ops.sigmoid(-log_snr_t)
        return -factor * dsnr_dt

    def get_config(self):
        """Return the constructor arguments needed to re-create this schedule."""
        return dict(
            min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting
        )

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Re-create a schedule from a config produced by ``get_config``."""
        return cls(**deserialize(config, custom_objects=custom_objects))

0 commit comments

Comments
 (0)