|
| 1 | +import math |
| 2 | +from typing import Union, Literal |
| 3 | + |
| 4 | +from keras import ops |
| 5 | + |
| 6 | +from bayesflow.types import Tensor |
| 7 | +from bayesflow.utils.serialization import deserialize, serializable |
| 8 | + |
| 9 | +from .noise_schedule import NoiseSchedule |
| 10 | + |
| 11 | + |
| 12 | +# disable module check, use potential module after moving from experimental |
| 13 | +@serializable("bayesflow.networks", disable_module_check=True) |
| 14 | +class CosineNoiseSchedule(NoiseSchedule): |
| 15 | + """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1]. |
| 16 | +
|
| 17 | + [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022) |
| 18 | + """ |
| 19 | + |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + min_log_snr: float = -15, |
| 23 | + max_log_snr: float = 15, |
| 24 | + shift: float = 0.0, |
| 25 | + weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid", |
| 26 | + ): |
| 27 | + """ |
| 28 | + Initialize the cosine noise schedule. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + min_log_snr : float, optional |
| 33 | + The minimum log signal-to-noise ratio (lambda). Default is -15. |
| 34 | + max_log_snr : float, optional |
| 35 | + The maximum log signal-to-noise ratio (lambda). Default is 15. |
| 36 | + shift : float, optional |
| 37 | + Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0. |
| 38 | + For images, use shift = log(base_resolution / d), where d is the used resolution of the image. |
| 39 | + weighting : Literal["sigmoid", "likelihood_weighting"], optional |
| 40 | + The type of weighting function to use for the noise schedule. Default is "sigmoid". |
| 41 | + """ |
| 42 | + super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting) |
| 43 | + self._shift = shift |
| 44 | + self._weighting = weighting |
| 45 | + self.log_snr_min = min_log_snr |
| 46 | + self.log_snr_max = max_log_snr |
| 47 | + |
| 48 | + self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True) |
| 49 | + self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True) |
| 50 | + |
| 51 | + def _truncated_t(self, t: Tensor) -> Tensor: |
| 52 | + return self._t_min + (self._t_max - self._t_min) * t |
| 53 | + |
| 54 | + def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor: |
| 55 | + """Get the log signal-to-noise ratio (lambda) for a given diffusion time.""" |
| 56 | + t_trunc = self._truncated_t(t) |
| 57 | + return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift |
| 58 | + |
| 59 | + def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor: |
| 60 | + """Get the diffusion time (t) from the log signal-to-noise ratio (lambda).""" |
| 61 | + # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2)) |
| 62 | + return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5)) |
| 63 | + |
| 64 | + def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor: |
| 65 | + """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE.""" |
| 66 | + t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training) |
| 67 | + |
| 68 | + # Compute the truncated time t_trunc |
| 69 | + t_trunc = self._truncated_t(t) |
| 70 | + dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc) |
| 71 | + |
| 72 | + # Using the chain rule on f(t) = log(1 + e^(-snr(t))): |
| 73 | + # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt |
| 74 | + dsnr_dt = dsnr_dx * (self._t_max - self._t_min) |
| 75 | + factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t)) |
| 76 | + return -factor * dsnr_dt |
| 77 | + |
| 78 | + def get_config(self): |
| 79 | + return dict( |
| 80 | + min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting |
| 81 | + ) |
| 82 | + |
| 83 | + @classmethod |
| 84 | + def from_config(cls, config, custom_objects=None): |
| 85 | + return cls(**deserialize(config, custom_objects=custom_objects)) |
0 commit comments