From 82b3ab4088a82b3bc951175a0e9d5f9ee971a77a Mon Sep 17 00:00:00 2001
From: arrjon
Date: Sat, 6 Sep 2025 11:31:27 +0200
Subject: [PATCH 01/61] allow tensor in DiagonalNormal dimension

---
 bayesflow/distributions/diagonal_normal.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py
index 6b64445c7..25a7797df 100644
--- a/bayesflow/distributions/diagonal_normal.py
+++ b/bayesflow/distributions/diagonal_normal.py
@@ -57,7 +57,7 @@ def __init__(
        self.trainable_parameters = trainable_parameters
        self.seed_generator = seed_generator or keras.random.SeedGenerator()

-       self.dim = None
+       self.dims = None
        self._mean = None
        self._std = None

@@ -65,10 +65,10 @@ def build(self, input_shape: Shape) -> None:
        if self.built:
            return

-       self.dim = int(input_shape[-1])
+       self.dims = input_shape[1:]

-       self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
-       self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")
+       self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
+       self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")

        if self.trainable_parameters:
            self._mean = self.add_weight(
@@ -91,14 +91,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

        if normalize:
-           log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
+           log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
+               ops.log(self._std)
+           )
            result += log_normalization_constant

        return result

    @allow_batch_size
    def sample(self, batch_shape: Shape) -> Tensor:
-       return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
+       return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator)

    def get_config(self):
        base_config = super().get_config()

From 8fbf7374ca95826e6f7f954e237c0c8d1a955b95 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Sun, 7 Sep 2025 15:04:54 +0200
Subject: [PATCH 02/61] fix sum dims

---
 bayesflow/distributions/diagonal_normal.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py
index 25a7797df..9cf068137 100644
--- a/bayesflow/distributions/diagonal_normal.py
+++ b/bayesflow/distributions/diagonal_normal.py
@@ -91,7 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

        if normalize:
-           log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
+           log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
               ops.log(self._std)
           )
           result += log_normalization_constant

From 5c27246b3aa48f9c8ba400596200c20ada7251ae Mon Sep 17 00:00:00 2001
From: arrjon
Date: Sun, 7 Sep 2025 15:08:16 +0200
Subject: [PATCH 03/61] fix batch_shape for sample

---
 bayesflow/approximators/continuous_approximator.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index fb2e95a56..f27a612f0 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -535,7 +535,10 @@ def _sample(
            inference_conditions = keras.ops.broadcast_to(
                inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
            )
-           batch_shape = keras.ops.shape(inference_conditions)[:-1]
+           batch_shape = (
+               batch_size,
+               num_samples,
+           )
        else:
            batch_shape = (num_samples,)

From c684bcace2add939945da94af16f3de8b0f9cdc9 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Sun, 7 Sep 2025 17:55:10 +0200
Subject: [PATCH 04/61] dims to tuple

---
 bayesflow/distributions/diagonal_normal.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py
index 9cf068137..83b3e556b 100644
--- a/bayesflow/distributions/diagonal_normal.py
+++ b/bayesflow/distributions/diagonal_normal.py
@@ -65,7 +65,7 @@ def build(self, input_shape: Shape) -> None:
        if self.built:
            return

-       self.dims = input_shape[1:]
+       self.dims = tuple(input_shape[1:])

        self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
        self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")
@@ -91,9 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
        result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

        if normalize:
-           log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
-               ops.log(self._std)
-           )
+           log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
           result += log_normalization_constant

        return result

From 06976344cb70d9e268392437d0240268cf3d4778 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Mon, 8 Sep 2025 13:04:34 +0200
Subject: [PATCH 05/61] first draft compositional

---
 .../approximators/continuous_approximator.py | 164 +++++++++++
 .../diffusion_model/diffusion_model.py       | 261 ++++++++++++++++++
 bayesflow/networks/inference_network.py      |  17 ++
 3 files changed, 442 insertions(+)

diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index f27a612f0..5a183922f 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -638,3 +638,167 @@ def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
        inference variables as present.
        """
        return keras.ops.shape(data["inference_variables"])[0]
+
+    def compositional_sample(
+        self,
+        *,
+        num_samples: int,
+        conditions: Mapping[str, np.ndarray],
+        split: bool = False,
+        **kwargs,
+    ) -> dict[str, np.ndarray]:
+        """
+        Generates compositional samples from the approximator given input conditions.
+        The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...).
+        This method handles the extra compositional dimension appropriately.
+
+        Parameters
+        ----------
+        num_samples : int
+            Number of samples to generate.
+        conditions : dict[str, np.ndarray]
+            Dictionary of conditioning variables as NumPy arrays with shape
+            (n_datasets, n_compositional_conditions, ...).
+        split : bool, default=False
+            Whether to split the output arrays along the last axis and return one column vector per target variable
+            samples.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and sampling process.
+
+        Returns
+        -------
+        dict[str, np.ndarray]
+            Dictionary containing generated samples with compositional structure preserved.
+        """
+        original_shapes = {}
+        flattened_conditions = {}
+        for key, value in conditions.items():  # Flatten compositional dimensions
+            original_shapes[key] = value.shape
+            n_datasets, n_comp = value.shape[:2]
+            flattened_shape = (n_datasets * n_comp,) + value.shape[2:]
+            flattened_conditions[key] = value.reshape(flattened_shape)
+        n_datasets, n_comp = original_shapes[next(iter(original_shapes))][:2]
+
+        # Prepare data using existing method (handles adaptation and standardization)
+        prepared_conditions = self._prepare_data(flattened_conditions, **kwargs)
+
+        # Remove any superfluous keys, just retain actual conditions
+        prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS}
+
+        # Sample using compositional sampling
+        samples = self._compositional_sample(
+            num_samples=num_samples, n_datasets=n_datasets, n_compositional=n_comp, **prepared_conditions, **kwargs
+        )
+
+        if "inference_variables" in self.standardize:
+            samples = self.standardize_layers["inference_variables"](samples, forward=False)
+
+        samples = {"inference_variables": samples}
+        samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
+
+        # Back-transform quantities and samples
+        samples = self._back_transform_compositional(samples, original_shapes, **kwargs)
+
+        if split:
+            samples = split_arrays(samples, axis=-1)
+        return samples
+
+    def _compositional_sample(
+        self,
+        num_samples: int,
+        n_datasets: int,
+        n_compositional: int,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+        **kwargs,
+    ) -> Tensor:
+        """
+        Internal method for compositional sampling.
+        """
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+        if self.summary_network is not None:
+            summary_outputs = self.summary_network(
+                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
+            )
+            inference_conditions = concatenate_valid([inference_conditions, summary_outputs], axis=-1)
+
+        if inference_conditions is not None:
+            # Reshape conditions for compositional sampling
+            # From (n_datasets * n_comp, dims) to (n_datasets, n_comp, dims)
+            condition_dims = keras.ops.shape(inference_conditions)[-1]
+            inference_conditions = keras.ops.reshape(
+                inference_conditions, (n_datasets, n_compositional, condition_dims)
+            )
+
+            # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims)
+            inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2)
+            inference_conditions = keras.ops.broadcast_to(
+                inference_conditions, (n_datasets, n_compositional, num_samples, condition_dims)
+            )
+
+            batch_shape = (n_datasets, n_compositional, num_samples)
+        else:
+            raise ValueError("Cannot perform compositional sampling without inference conditions.")
+
+        return self.inference_network.sample(
+            batch_shape,
+            conditions=inference_conditions,
+            compositional=True,
+            **filter_kwargs(kwargs, self.inference_network.sample),
+        )
+
+    def _back_transform_compositional(
+        self, samples: dict[str, np.ndarray], original_shapes: dict[str, tuple], **kwargs
+    ) -> dict[str, np.ndarray]:
+        """
+        Back-transform compositional samples, handling the extra compositional dimension.
+ """ + # Get the sample shape to understand the compositional structure + inference_samples = samples["inference_variables"] + sample_shape = inference_samples.shape + + # Determine compositional dimensions from original shapes + # Assuming all condition keys have the same compositional structure + first_key = next(iter(original_shapes.keys())) + n_datasets, n_compositional = original_shapes[first_key][:2] + + # Reshape samples to match compositional structure if needed + if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, dims) + num_samples, dims = sample_shape[1], sample_shape[2] + inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, dims) + samples["inference_variables"] = inference_samples + + # For back-transformation, we might need to flatten again temporarily + # depending on how the adapter expects the data + flattened_samples = {} + for key, value in samples.items(): + if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, dims) + n_d, n_c, n_s, dims = value.shape + flattened_samples[key] = value.reshape(n_d * n_c, n_s, dims) + else: + flattened_samples[key] = value + + # Apply inverse transformation + transformed = self.adapter(flattened_samples, inverse=True, strict=False, **kwargs) + + # Reshape back to compositional structure + final_samples = {} + for key, value in transformed.items(): + if key in original_shapes: + # Reshape to include compositional dimension + if len(value.shape) >= 2: + num_samples = value.shape[1] + remaining_dims = value.shape[2:] if len(value.shape) > 2 else () + final_samples[key] = value.reshape(n_datasets, n_compositional, num_samples, *remaining_dims) + else: + final_samples[key] = value + else: + final_samples[key] = value + + return final_samples diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..e815b89db 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -362,6 +362,7 @@ def _forward( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} @@ -412,6 +413,7 @@ def _inverse( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} @@ -541,3 +543,262 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_datasets, n_compositional = ops.shape(xz)[0], ops.shape(xz)[1] + print(xz.shape, n_datasets, n_compositional) + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + # Compute individual dataset scores + individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) + + # Compute prior score component + prior_score = self.compute_prior_score(xz) + + # Combine scores using compositional formula + # s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + n = ops.cast(n_compositional, dtype=ops.dtype(time)) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + + # Sum individual scores across compositional dimension + summed_individual_scores = ops.sum(individual_scores, axis=1, keepdims=True) + + # Prior contribution: (1-n)(1-t) * prior_score + prior_weight = (1.0 - n) * (1.0 - time_tensor) + weighted_prior = prior_weight * prior_score + + # Combined score + compositional_score = weighted_prior + summed_individual_scores + + # Broadcast back to full compositional shape + compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + print(velocity.shape) + return velocity + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + alpha_t: Tensor, + sigma_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...) 
+ """ + # Apply subnet to each compositional condition separately + transformed_log_snr = self._transform_log_snr(log_snr_t) + + # Reshape for processing: flatten compositional dimension temporarily + original_shape = ops.shape(xz) + n_datasets, n_comp = original_shape[0], original_shape[1] + remaining_dims = original_shape[2:] + + # Flatten for subnet application + xz_flat = ops.reshape(xz, (n_datasets * n_comp,) + remaining_dims) + log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) + conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) + alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) + sigma_flat = ops.reshape(sigma_t, (n_datasets * n_comp,) + ops.shape(sigma_t)[2:]) + + # Apply subnet + subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) + pred = self.output_projector(subnet_out, training=training) + + # Convert prediction to x + x_pred = self.convert_prediction_to_x( + pred=pred, z=xz_flat, alpha_t=alpha_flat, sigma_t=sigma_flat, log_snr_t=log_snr_flat + ) + + # Compute score: (α_t * x_pred - z) / σ_t² + score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) + + # Reshape back to compositional structure + score = ops.reshape(score, original_shape) + + return score + + def _compositional_forward( + self, + x: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Forward pass for compositional diffusion. + """ + integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for forward integration.") + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + x = x / ops.sqrt(ops.cast(ops.shape(x)[1], dtype=ops.dtype(x))) + + if density: + + def deltas(time, xz): + v = self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + # For density, we need trace but compositional trace is complex + # Simplified version - could be extended + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": x, + "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), + } + state = integrate(deltas, state, **integrate_kwargs) + + z = state["xz"] + # Simplified density computation + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) + ops.squeeze(state["trace"], axis=-1) + return z, log_density + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + } + + state = {"xz": x} + state = integrate(deltas, state, **integrate_kwargs) + z = state["xz"] + return z + + def _compositional_inverse( + self, + z: Tensor, + conditions: Tensor = None, + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion (sampling). 
+ """ + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=True, conditions=conditions, training=training + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, time=time, stochastic_solver=False, conditions=conditions, training=training + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x + + @staticmethod + def compute_prior_score(xz: Tensor) -> Tensor: + return ops.ones_like(xz) # todo: Placeholder implementation + # raise NotImplementedError('Please implement the prior score computation method.') diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..f2e5c512f 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -27,11 +27,18 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, + compositional: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: + if compositional: + return self._inverse_compositional( + xz, conditions=conditions, density=density, training=training, **kwargs + ) return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) + if compositional: + return self._forward_compositional(xz, conditions=conditions, density=density, training=training, **kwargs) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( @@ -44,6 +51,16 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError + def _forward_compositional( + self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + + def _inverse_compositional( + self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + @allow_batch_size def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample(batch_shape) From b8e849e4393f3a5199c39beb9db6a7ed710a4063 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 13:12:52 +0200 Subject: [PATCH 06/61] 
first draft compositional --- bayesflow/workflows/basic_workflow.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 34fa03794..f1941ac3a 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -286,6 +286,37 @@ def sample( """ return self.approximator.sample(num_samples=num_samples, conditions=conditions, **kwargs) + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Draws `num_samples` samples from the approximator given specified composition conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + + Parameters + ---------- + num_samples : int + The number of samples to generate. + conditions : dict[str, np.ndarray] + A dictionary where keys represent variable names and values are + NumPy arrays containing the adapted simulated variables. Keys used as summary or inference + conditions during training should be present. + Should have shape (n_datasets, n_compositional_conditions, ...). + **kwargs : dict, optional + Additional keyword arguments passed to the approximator's sampling function. + + Returns + ------- + dict[str, np.ndarray] + A dictionary where keys correspond to variable names and + values are arrays containing the generated samples. + """ + return self.approximator.compositional_sample(num_samples=num_samples, conditions=conditions, **kwargs) + def estimate( self, *, From a280af32c5dd6ea9c657079b8e08116f7cde374b Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 13:15:16 +0200 Subject: [PATCH 07/61] first draft compositional --- bayesflow/networks/diffusion_model/diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index e815b89db..fa32b2bc6 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -672,7 +672,7 @@ def _compute_individual_scores( return score - def _compositional_forward( + def _forward_compositional( self, x: Tensor, conditions: Tensor = None, @@ -727,7 +727,7 @@ def deltas(time, xz): z = state["xz"] return z - def _compositional_inverse( + def _inverse_compositional( self, z: Tensor, conditions: Tensor = None, From b9faf31104821473f76c9dfa773cb403f7f82190 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:03:01 +0200 Subject: [PATCH 08/61] first draft compositional --- .../approximators/continuous_approximator.py | 23 ++++++++++--------- .../diffusion_model/diffusion_model.py | 6 ++--- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 5a183922f..c28424d6f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -730,19 +730,19 @@ def _compositional_sample( if inference_conditions is not None: # Reshape conditions for compositional sampling - # From (n_datasets * n_comp, dims) to (n_datasets, n_comp, dims) - condition_dims = keras.ops.shape(inference_conditions)[-1] + # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) + condition_dims = keras.ops.shape(inference_conditions)[2:] inference_conditions = 
keras.ops.reshape( - inference_conditions, (n_datasets, n_compositional, condition_dims) + inference_conditions, (n_datasets, n_compositional, *condition_dims) ) # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims) inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2) inference_conditions = keras.ops.broadcast_to( - inference_conditions, (n_datasets, n_compositional, num_samples, condition_dims) + inference_conditions, (n_datasets, n_compositional, num_samples, *condition_dims) ) - batch_shape = (n_datasets, n_compositional, num_samples) + batch_shape = (n_datasets, num_samples) else: raise ValueError("Cannot perform compositional sampling without inference conditions.") @@ -769,18 +769,19 @@ def _back_transform_compositional( n_datasets, n_compositional = original_shapes[first_key][:2] # Reshape samples to match compositional structure if needed - if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, dims) - num_samples, dims = sample_shape[1], sample_shape[2] - inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, dims) + if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, ..., dims) + num_samples, dims = sample_shape[1], sample_shape[2:] + inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, *dims) samples["inference_variables"] = inference_samples # For back-transformation, we might need to flatten again temporarily # depending on how the adapter expects the data flattened_samples = {} for key, value in samples.items(): - if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, dims) - n_d, n_c, n_s, dims = value.shape - flattened_samples[key] = value.reshape(n_d * n_c, n_s, dims) + if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, ..., dims) + n_d, n_c, n_s = value.shape[:3] + dims = value.shape[3:] + flattened_samples[key] = value.reshape(n_d * n_c, n_s, *dims) else: flattened_samples[key] = value diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index fa32b2bc6..157bc360c 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -578,8 +578,8 @@ def compositional_velocity( raise ValueError("Conditions are required for compositional sampling") # Get shapes for compositional structure - n_datasets, n_compositional = ops.shape(xz)[0], ops.shape(xz)[1] - print(xz.shape, n_datasets, n_compositional) + n_compositional = ops.shape(conditions)[1] + print(ops.shape(xz), ops.shape(conditions)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -620,7 +620,7 @@ def compositional_velocity( # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt velocity = f - 0.5 * g_squared * compositional_score - print(velocity.shape) + print(velocity.shape, velocity) return velocity def _compute_individual_scores( From 9b7eb1696ee04bbff2b72e75950e51984de2e611 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:11:20 +0200 Subject: [PATCH 09/61] fix shapes --- bayesflow/approximators/continuous_approximator.py | 2 +- bayesflow/networks/diffusion_model/diffusion_model.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c28424d6f..12a0878af 100644 --- 
a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -731,7 +731,7 @@ def _compositional_sample( if inference_conditions is not None: # Reshape conditions for compositional sampling # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) - condition_dims = keras.ops.shape(inference_conditions)[2:] + condition_dims = keras.ops.shape(inference_conditions)[1:] inference_conditions = keras.ops.reshape( inference_conditions, (n_datasets, n_compositional, *condition_dims) ) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 157bc360c..132cf2f80 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -579,7 +579,7 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - print(ops.shape(xz), ops.shape(conditions)) + print(ops.shape(xz), ops.shape(conditions)) # (1, 100, 2), (1, 2, 100, 2) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -644,12 +644,13 @@ def _compute_individual_scores( transformed_log_snr = self._transform_log_snr(log_snr_t) # Reshape for processing: flatten compositional dimension temporarily - original_shape = ops.shape(xz) + original_shape = ops.shape(conditions) n_datasets, n_comp = original_shape[0], original_shape[1] - remaining_dims = original_shape[2:] # Flatten for subnet application - xz_flat = ops.reshape(xz, (n_datasets * n_comp,) + remaining_dims) + xz_flat = ops.expand_dims(xz, axis=1) # (n_datasets, 1, ...) + xz_flat = ops.broadcast_to(xz_flat, (n_datasets, n_comp) + ops.shape(xz)[1:]) + xz_flat = ops.reshape(xz_flat, (n_datasets * n_comp,) + ops.shape(xz)[1:]) log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) From e79aac11555d37b192d0cbf26da092f2814a4f12 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:12:51 +0200 Subject: [PATCH 10/61] fix shapes --- bayesflow/networks/diffusion_model/diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 132cf2f80..7233a5ce5 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -598,7 +598,7 @@ def compositional_velocity( time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Sum individual scores across compositional dimension - summed_individual_scores = ops.sum(individual_scores, axis=1, keepdims=True) + summed_individual_scores = ops.sum(individual_scores, axis=1) # Prior contribution: (1-n)(1-t) * prior_score prior_weight = (1.0 - n) * (1.0 - time_tensor) @@ -608,7 +608,7 @@ def compositional_velocity( compositional_score = weighted_prior + summed_individual_scores # Broadcast back to full compositional shape - compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + # compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) # Compute velocity using standard drift-diffusion formulation f, g_squared = 
self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) From 8a802409e649012d61f9571268c5fb291a138d86 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:29:51 +0200 Subject: [PATCH 11/61] fix shapes --- .../diffusion_model/diffusion_model.py | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 7233a5ce5..666e1986b 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -583,10 +583,12 @@ def compositional_velocity( # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores + print(xz.shape, log_snr_t.shape, alpha_t.shape, sigma_t.shape, conditions.shape) + # (1, 100, 2) (1, 100, 1) (1, 100, 1) (1, 100, 1) (1, 2, 100, 2) individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) # Compute prior score component @@ -643,18 +645,34 @@ def _compute_individual_scores( # Apply subnet to each compositional condition separately transformed_log_snr = self._transform_log_snr(log_snr_t) - # Reshape for processing: flatten compositional dimension temporarily - original_shape = ops.shape(conditions) - n_datasets, n_comp = original_shape[0], original_shape[1] + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand noise schedule components to match compositional structure + log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples) + dims) - # Flatten for subnet application - xz_flat = ops.expand_dims(xz, axis=1) # (n_datasets, 1, ...) 
- xz_flat = ops.broadcast_to(xz_flat, (n_datasets, n_comp) + ops.shape(xz)[1:]) - xz_flat = ops.reshape(xz_flat, (n_datasets * n_comp,) + ops.shape(xz)[1:]) - log_snr_flat = ops.reshape(transformed_log_snr, (n_datasets * n_comp,) + ops.shape(transformed_log_snr)[2:]) - conditions_flat = ops.reshape(conditions, (n_datasets * n_comp,) + ops.shape(conditions)[2:]) - alpha_flat = ops.reshape(alpha_t, (n_datasets * n_comp,) + ops.shape(alpha_t)[2:]) - sigma_flat = ops.reshape(sigma_t, (n_datasets * n_comp,) + ops.shape(sigma_t)[2:]) + alpha_expanded = ops.expand_dims(alpha_t, axis=1) + alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples) + dims) + + sigma_expanded = ops.expand_dims(sigma_t, axis=1) + sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples) + dims) + alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples) + dims) + sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples) + dims) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) # Apply subnet subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) @@ -669,8 +687,7 @@ def _compute_individual_scores( score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) # Reshape back to compositional structure - score = ops.reshape(score, original_shape) - + score = ops.reshape(score, (n_datasets, n_compositional, num_samples)) return score def _forward_compositional( From 00fbc619626db2e5b180a9e9c0fc0a2025600fe8 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:31:30 +0200 Subject: [PATCH 12/61] fix shapes --- bayesflow/networks/diffusion_model/diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 666e1986b..3943b8986 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -583,7 +583,7 @@ def compositional_velocity( # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:1] + (1,)) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores From e6158e7c76b418b5c37a35579b16ddc8495e8fab Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:33:37 +0200 Subject: [PATCH 13/61] fix shapes --- .../networks/diffusion_model/diffusion_model.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3943b8986..94f278b7d 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -579,7 +579,6 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - print(ops.shape(xz), 
ops.shape(conditions)) # (1, 100, 2), (1, 2, 100, 2) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -587,8 +586,6 @@ def compositional_velocity( alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores - print(xz.shape, log_snr_t.shape, alpha_t.shape, sigma_t.shape, conditions.shape) - # (1, 100, 2) (1, 100, 1) (1, 100, 1) (1, 100, 1) (1, 2, 100, 2) individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) # Compute prior score component @@ -659,19 +656,19 @@ def _compute_individual_scores( # Expand noise schedule components to match compositional structure log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) - log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples) + dims) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) alpha_expanded = ops.expand_dims(alpha_t, axis=1) - alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples) + dims) + alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples, 1)) sigma_expanded = ops.expand_dims(sigma_t, axis=1) - sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples) + dims) + sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples, 1)) # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) - log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples) + dims) - alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples) + dims) - sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples, 1)) + sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples, 1)) conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) # Apply subnet From 1ac39b2521212647cdcc41de5010510a26615c8b Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:38:34 +0200 Subject: [PATCH 14/61] fix shapes --- .../approximators/continuous_approximator.py | 53 +------------------ .../diffusion_model/diffusion_model.py | 3 +- 2 files changed, 2 insertions(+), 54 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 12a0878af..c8bd77a57 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -697,7 +697,7 @@ def compositional_sample( samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples) # Back-transform quantities and samples - samples = self._back_transform_compositional(samples, original_shapes, **kwargs) + samples = self.adapter(samples, inverse=True, strict=False, **kwargs) if split: samples = split_arrays(samples, axis=-1) @@ -752,54 +752,3 @@ def _compositional_sample( compositional=True, **filter_kwargs(kwargs, self.inference_network.sample), ) - - def _back_transform_compositional( - self, samples: dict[str, 
np.ndarray], original_shapes: dict[str, tuple], **kwargs - ) -> dict[str, np.ndarray]: - """ - Back-transform compositional samples, handling the extra compositional dimension. - """ - # Get the sample shape to understand the compositional structure - inference_samples = samples["inference_variables"] - sample_shape = inference_samples.shape - - # Determine compositional dimensions from original shapes - # Assuming all condition keys have the same compositional structure - first_key = next(iter(original_shapes.keys())) - n_datasets, n_compositional = original_shapes[first_key][:2] - - # Reshape samples to match compositional structure if needed - if len(sample_shape) == 3: # (n_datasets * n_comp, num_samples, ..., dims) - num_samples, dims = sample_shape[1], sample_shape[2:] - inference_samples = inference_samples.reshape(n_datasets, n_compositional, num_samples, *dims) - samples["inference_variables"] = inference_samples - - # For back-transformation, we might need to flatten again temporarily - # depending on how the adapter expects the data - flattened_samples = {} - for key, value in samples.items(): - if len(value.shape) == 4: # (n_datasets, n_comp, num_samples, ..., dims) - n_d, n_c, n_s = value.shape[:3] - dims = value.shape[3:] - flattened_samples[key] = value.reshape(n_d * n_c, n_s, *dims) - else: - flattened_samples[key] = value - - # Apply inverse transformation - transformed = self.adapter(flattened_samples, inverse=True, strict=False, **kwargs) - - # Reshape back to compositional structure - final_samples = {} - for key, value in transformed.items(): - if key in original_shapes: - # Reshape to include compositional dimension - if len(value.shape) >= 2: - num_samples = value.shape[1] - remaining_dims = value.shape[2:] if len(value.shape) > 2 else () - final_samples[key] = value.reshape(n_datasets, n_compositional, num_samples, *remaining_dims) - else: - final_samples[key] = value - else: - final_samples[key] = value - - return final_samples diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 94f278b7d..69e6db619 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -619,7 +619,6 @@ def compositional_velocity( # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt velocity = f - 0.5 * g_squared * compositional_score - print(velocity.shape, velocity) return velocity def _compute_individual_scores( @@ -684,7 +683,7 @@ def _compute_individual_scores( score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) # Reshape back to compositional structure - score = ops.reshape(score, (n_datasets, n_compositional, num_samples)) + score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) return score def _forward_compositional( From 9fd9cf887d09bc29c03b40f7a9f802a55fe2971e Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 14:52:17 +0200 Subject: [PATCH 15/61] add minibatch --- .../diffusion_model/diffusion_model.py | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 69e6db619..1a302c5e4 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -550,6 +550,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, + mini_batch_size: int | None, 
training: bool = False, ) -> Tensor: """ @@ -566,6 +567,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + mini_batch_size : int or None + Size of mini-batches for processing compositional conditions to save memory. training : bool, optional Whether in training mode @@ -579,35 +582,35 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] + n = ops.cast(n_compositional, dtype=ops.dtype(time)) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + conditions_batch = conditions[:, idx[:mini_batch_size]] + else: + conditions_batch = conditions + # Compute individual dataset scores - individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions, training) + individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component prior_score = self.compute_prior_score(xz) - # Combine scores using compositional formula - # s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - n = ops.cast(n_compositional, dtype=ops.dtype(time)) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) - - # Sum individual scores across compositional dimension - summed_individual_scores = ops.sum(individual_scores, axis=1) + # Combine scores using compositional formula, mean over individual scores and scale with n to get sum + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - # Prior contribution: (1-n)(1-t) * prior_score - prior_weight = (1.0 - n) * (1.0 - time_tensor) - weighted_prior = prior_weight * prior_score + # Prior contribution + weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score # Combined score - compositional_score = weighted_prior + summed_individual_scores - - # Broadcast back to full compositional shape - # compositional_score = ops.broadcast_to(compositional_score, ops.shape(xz)) + compositional_score = weighted_prior_score + summed_individual_scores # Compute velocity using standard drift-diffusion formulation f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -700,6 +703,7 @@ def _forward_compositional( integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs + mini_batch_size = integrate_kwargs.get("mini_batch_size", None) if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for forward integration.") @@ -711,7 +715,12 @@ def _forward_compositional( def deltas(time, xz): v = self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) # For density, we need trace but 
compositional trace is complex # Simplified version - could be extended @@ -732,7 +741,12 @@ def deltas(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } @@ -755,6 +769,7 @@ def _inverse_compositional( integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs + mini_batch_size = integrate_kwargs.get("mini_batch_size", None) if density: if integrate_kwargs["method"] == "euler_maruyama": @@ -762,7 +777,12 @@ def _inverse_compositional( def deltas(time, xz): v = self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) return {"xz": v, "trace": trace} @@ -784,7 +804,12 @@ def deltas(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=True, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } @@ -803,7 +828,12 @@ def diffusion(time, xz): def deltas(time, xz): return { "xz": self.compositional_velocity( - xz, time=time, stochastic_solver=False, conditions=conditions, training=training + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + mini_batch_size=mini_batch_size, + training=training, ) } From 830e9295b9bfa03c6ed714564baba2fb3e7c7860 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:11:35 +0200 Subject: [PATCH 16/61] add compositional_bridge --- .../diffusion_model/diffusion_model.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 1a302c5e4..3ebb0dd43 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -544,6 +544,24 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} + @staticmethod + def compositional_bridge(time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. 
+ + """ + return ops.ones_like(time) + def compositional_velocity( self, xz: Tensor, @@ -610,7 +628,7 @@ def compositional_velocity( weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score # Combined score - compositional_score = weighted_prior_score + summed_individual_scores + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) # Compute velocity using standard drift-diffusion formulation f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -703,13 +721,14 @@ def _forward_compositional( integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.get("mini_batch_size", None) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for forward integration.") # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - x = x / ops.sqrt(ops.cast(ops.shape(x)[1], dtype=ops.dtype(x))) + scale_latent = ops.shape(conditions)[1] * self.compositional_bridge(ops.ones(1)) + x = x / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(x))) if density: @@ -769,7 +788,7 @@ def _inverse_compositional( integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.get("mini_batch_size", None) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) if density: if integrate_kwargs["method"] == "euler_maruyama": @@ -844,5 +863,4 @@ def deltas(time, xz): @staticmethod def compute_prior_score(xz: Tensor) -> Tensor: - return ops.ones_like(xz) # todo: Placeholder implementation - # raise NotImplementedError('Please implement the prior score computation method.') + raise NotImplementedError("Please implement the prior score computation method.") From f97594b522a598ebb2a353ba8a11aefa192006da Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:31:26 +0200 Subject: [PATCH 17/61] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 99 ++++--------------- 1 file changed, 20 insertions(+), 79 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3ebb0dd43..9dd3dc899 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -568,7 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, - mini_batch_size: int | None, + mini_batch_idx: Sequence | None, training: bool = False, ) -> Tensor: """ @@ -585,8 +585,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - mini_batch_size : int or None - Size of mini-batches for processing compositional conditions to save memory. + mini_batch_idx : Sequence + Indices for mini-batch selection along the compositional axis. 
training : bool, optional Whether in training mode @@ -608,14 +608,11 @@ def compositional_velocity( log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - conditions_batch = conditions[:, idx[:mini_batch_size]] + # Compute individual dataset scores + if mini_batch_idx is not None: + conditions_batch = conditions[:, mini_batch_idx] else: conditions_batch = conditions - - # Compute individual dataset scores individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component @@ -707,73 +704,6 @@ def _compute_individual_scores( score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) return score - def _forward_compositional( - self, - x: Tensor, - conditions: Tensor = None, - density: bool = False, - training: bool = False, - **kwargs, - ) -> Tensor | tuple[Tensor, Tensor]: - """ - Forward pass for compositional diffusion. - """ - integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} - integrate_kwargs = integrate_kwargs | self.integrate_kwargs - integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - - if integrate_kwargs["method"] == "euler_maruyama": - raise ValueError("Stochastic methods are not supported for forward integration.") - - # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - scale_latent = ops.shape(conditions)[1] * self.compositional_bridge(ops.ones(1)) - x = x / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(x))) - - if density: - - def deltas(time, xz): - v = self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - mini_batch_size=mini_batch_size, - training=training, - ) - # For density, we need trace but compositional trace is complex - # Simplified version - could be extended - trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) - return {"xz": v, "trace": trace} - - state = { - "xz": x, - "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)), - } - state = integrate(deltas, state, **integrate_kwargs) - - z = state["xz"] - # Simplified density computation - log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) + ops.squeeze(state["trace"], axis=-1) - return z, log_density - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = {"xz": x} - state = integrate(deltas, state, **integrate_kwargs) - z = state["xz"] - return z - def _inverse_compositional( self, z: Tensor, @@ -790,6 +720,17 @@ def _inverse_compositional( integrate_kwargs = integrate_kwargs | kwargs mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + n_compositional = ops.shape(conditions)[1] + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = 
keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + else: + mini_batch_idx = None + if density: if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for density computation.") @@ -800,7 +741,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) @@ -827,7 +768,7 @@ def deltas(time, xz): time=time, stochastic_solver=True, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) } @@ -851,7 +792,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, - mini_batch_size=mini_batch_size, + mini_batch_idx=mini_batch_idx, training=training, ) } From 7219a71aac81c72bf686acd02e3a5d0ade316fe6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:37:34 +0200 Subject: [PATCH 18/61] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 9dd3dc899..8cf58c4f8 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -725,17 +725,18 @@ def _inverse_compositional( scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - else: - mini_batch_idx = None - if density: if integrate_kwargs["method"] == "euler_maruyama": raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + v = self.compositional_velocity( xz, time=time, @@ -762,6 +763,13 @@ def deltas(time, xz): if integrate_kwargs["method"] == "euler_maruyama": def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + return { "xz": self.compositional_velocity( xz, @@ -786,6 +794,13 @@ def diffusion(time, xz): else: def deltas(time, xz): + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + else: + mini_batch_idx = None + return { "xz": self.compositional_velocity( xz, From a10026a00e7fa04eae7bccffb062b77d9c8e1ce9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 15:45:41 +0200 Subject: [PATCH 19/61] fix mini batch randomness --- .../diffusion_model/diffusion_model.py | 49 ++++++++----------- 1 file 
changed, 20 insertions(+), 29 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 8cf58c4f8..37fc9033b 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -568,7 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, - mini_batch_idx: Sequence | None, + mini_batch_size: int | None = None, training: bool = False, ) -> Tensor: """ @@ -585,8 +585,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - mini_batch_idx : Sequence - Indices for mini-batch selection along the compositional axis. + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. training : bool, optional Whether in training mode @@ -609,7 +609,10 @@ def compositional_velocity( alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) # Compute individual dataset scores - if mini_batch_idx is not None: + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] conditions_batch = conditions[:, mini_batch_idx] else: conditions_batch = conditions @@ -720,6 +723,14 @@ def _inverse_compositional( integrate_kwargs = integrate_kwargs | kwargs mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + # if backend is jax, mini batching does not work + if ops.__name__ == "jax": + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional n_compositional = ops.shape(conditions)[1] scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) @@ -730,19 +741,12 @@ def _inverse_compositional( raise ValueError("Stochastic methods are not supported for density computation.") def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - v = self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) @@ -763,20 +767,13 @@ def deltas(time, xz): if integrate_kwargs["method"] == "euler_maruyama": def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=True, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) } @@ -794,20 +791,13 @@ def diffusion(time, xz): else: def deltas(time, xz): - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - else: - mini_batch_idx = None - return { "xz": self.compositional_velocity( xz, time=time, stochastic_solver=False, conditions=conditions, - mini_batch_idx=mini_batch_idx, + mini_batch_size=mini_batch_size, training=training, ) } @@ -819,4 +809,5 @@ def deltas(time, xz): @staticmethod def compute_prior_score(xz: Tensor) -> Tensor: - raise NotImplementedError("Please implement the prior score computation method.") + return ops.ones_like(xz) + # raise NotImplementedError("Please implement the prior score computation method.") From 457eb5d6a7b7d3ec82ab7811cb4e7dbcd7f976ff Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:12:32 +0200 Subject: [PATCH 20/61] add prior score --- .../approximators/continuous_approximator.py | 32 ++++++++++++++++++- .../diffusion_model/diffusion_model.py | 22 +++++++------ bayesflow/networks/inference_network.py | 17 ++++++++-- bayesflow/workflows/basic_workflow.py | 7 +++- 4 files changed, 64 insertions(+), 14 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c8bd77a57..0a046d57e 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -644,6 +644,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], split: bool = False, **kwargs, ) -> dict[str, np.ndarray]: @@ -659,6 +660,8 @@ def compositional_sample( conditions : dict[str, np.ndarray] Dictionary of conditioning variables as NumPy arrays with shape (n_datasets, n_compositional_conditions, ...). 
+ compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the log probability of samples under the prior distribution. split : bool, default=False Whether to split the output arrays along the last axis and return one column vector per target variable samples. @@ -685,9 +688,34 @@ def compositional_sample( # Remove any superfluous keys, just retain actual conditions prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS} + # Prepare prior scores to handle adapter + def compute_prior_score_pre(_samples: Tensor) -> Tensor: + if "inference_variables" in self.standardize: + _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( + _samples, forward=False, log_det_jac=True + ) + else: + log_det_jac_standardize = 0 + _samples = {"inference_variables": _samples} + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, _samples) + adapted_samples, log_det_jac = self.adapter( + _samples, inverse=True, strict=False, log_det_jac=True, **kwargs + ) + prior_score = keras.ops.convert_to_tensor(compute_prior_score(adapted_samples)) + if log_det_jac is not None: + prior_score += keras.ops.convert_to_tensor(log_det_jac) + if log_det_jac_standardize is not None: + prior_score += keras.ops.convert_to_tensor(log_det_jac_standardize) + return prior_score + # Sample using compositional sampling samples = self._compositional_sample( - num_samples=num_samples, n_datasets=n_datasets, n_compositional=n_comp, **prepared_conditions, **kwargs + num_samples=num_samples, + n_datasets=n_datasets, + n_compositional=n_comp, + compute_prior_score=compute_prior_score_pre, + **prepared_conditions, + **kwargs, ) if "inference_variables" in self.standardize: @@ -708,6 +736,7 @@ def _compositional_sample( num_samples: int, n_datasets: int, n_compositional: int, + compute_prior_score: Callable[[Tensor], Tensor], inference_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs, @@ -750,5 +779,6 @@ def _compositional_sample( batch_shape, conditions=inference_conditions, compositional=True, + compute_prior_score=compute_prior_score, **filter_kwargs(kwargs, self.inference_network.sample), ) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 37fc9033b..b65c61d01 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal +from typing import Literal, Callable import keras from keras import ops @@ -568,6 +568,7 @@ def compositional_velocity( time: float | Tensor, stochastic_solver: bool, conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], mini_batch_size: int | None = None, training: bool = False, ) -> Tensor: @@ -585,6 +586,8 @@ def compositional_velocity( Whether to use stochastic (SDE) or deterministic (ODE) formulation conditions : Tensor Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). mini_batch_size : int or None Mini batch size for computing individual scores. If None, use all conditions. 
training : bool, optional @@ -619,7 +622,7 @@ def compositional_velocity( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - prior_score = self.compute_prior_score(xz) + prior_score = compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) @@ -710,13 +713,14 @@ def _compute_individual_scores( def _inverse_compositional( self, z: Tensor, - conditions: Tensor = None, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], density: bool = False, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: """ - Inverse pass for compositional diffusion (sampling). + Inverse pass for compositional diffusion sampling. """ integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs @@ -725,7 +729,7 @@ def _inverse_compositional( if mini_batch_size is not None: # if backend is jax, mini batching does not work - if ops.__name__ == "jax": + if keras.backend.backend() == "jax": raise ValueError( "Mini batching is not supported with JAX backend. Set mini_batch_size to None " "or use another backend." @@ -746,6 +750,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -773,6 +778,7 @@ def deltas(time, xz): time=time, stochastic_solver=True, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -797,6 +803,7 @@ def deltas(time, xz): time=time, stochastic_solver=False, conditions=conditions, + compute_prior_score=compute_prior_score, mini_batch_size=mini_batch_size, training=training, ) @@ -806,8 +813,3 @@ def deltas(time, xz): x = state["xz"] return x - - @staticmethod - def compute_prior_score(xz: Tensor) -> Tensor: - return ops.ones_like(xz) - # raise NotImplementedError("Please implement the prior score computation method.") diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index f2e5c512f..250c93b22 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,3 +1,4 @@ +from typing import Callable import keras from bayesflow.types import Shape, Tensor @@ -52,12 +53,24 @@ def _inverse( raise NotImplementedError def _forward_compositional( - self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + self, + x: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError def _inverse_compositional( - self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index f1941ac3a..2ef326dae 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,6 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: 
Mapping[str, np.ndarray], + prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -306,6 +307,8 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). + prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function. @@ -315,7 +318,9 @@ def compositional_sample( A dictionary where keys correspond to variable names and values are arrays containing the generated samples. """ - return self.approximator.compositional_sample(num_samples=num_samples, conditions=conditions, **kwargs) + return self.approximator.compositional_sample( + num_samples=num_samples, conditions=conditions, prior_score=prior_score, **kwargs + ) def estimate( self, From 7de473649f066fce40dd93c74512c6894ac1be74 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:14:15 +0200 Subject: [PATCH 21/61] add prior score --- bayesflow/workflows/basic_workflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 2ef326dae..cfa63545b 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -291,7 +291,7 @@ def compositional_sample( *, num_samples: int, conditions: Mapping[str, np.ndarray], - prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], **kwargs, ) -> dict[str, np.ndarray]: """ @@ -307,7 +307,7 @@ def compositional_sample( NumPy arrays containing the adapted simulated variables. Keys used as summary or inference conditions during training should be present. Should have shape (n_datasets, n_compositional_conditions, ...). - prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] A function that computes the log probability of samples under the prior distribution. **kwargs : dict, optional Additional keyword arguments passed to the approximator's sampling function. @@ -319,7 +319,7 @@ def compositional_sample( values are arrays containing the generated samples. 
""" return self.approximator.compositional_sample( - num_samples=num_samples, conditions=conditions, prior_score=prior_score, **kwargs + num_samples=num_samples, conditions=conditions, compute_prior_score=compute_prior_score, **kwargs ) def estimate( From 1ee0e785086b4e28d92e037b3c84c62ddd44c1f4 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:40:20 +0200 Subject: [PATCH 22/61] add prior score draft --- .../approximators/continuous_approximator.py | 28 ++++++++++++++----- bayesflow/networks/inference_network.py | 24 ++++++++++++---- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 0a046d57e..cbc7d30c3 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -701,12 +701,27 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) - prior_score = keras.ops.convert_to_tensor(compute_prior_score(adapted_samples)) - if log_det_jac is not None: - prior_score += keras.ops.convert_to_tensor(log_det_jac) - if log_det_jac_standardize is not None: - prior_score += keras.ops.convert_to_tensor(log_det_jac_standardize) - return prior_score + prior_score = compute_prior_score(adapted_samples) + prior_score_final = {} + for i, key in enumerate(adapted_samples): # todo: assumes same order, might incorrect + prior_score_final[key] = prior_score[key] + if len(log_det_jac_standardize) > 0: + prior_score_final[key] += log_det_jac_standardize[:, i] + if len(log_det_jac) > 0: + prior_score_final[key] += log_det_jac[:, i] + prior_score_final[key] = keras.ops.convert_to_tensor(prior_score_final[key]) + # make a tensor + out = keras.ops.concatenate(list(prior_score_final.values()), axis=-1) + return out + + # Test prior score function, useful for debugging + test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) + test = compute_prior_score_pre(test) + if test.shape[:2] != (n_datasets, num_samples): + raise ValueError( + "The provided compute_prior_score function does not return the correct shape. " + f"Expected ({n_datasets}, {num_samples}, ...), got {test.shape}." 
+ ) # Sample using compositional sampling samples = self._compositional_sample( @@ -778,7 +793,6 @@ def _compositional_sample( return self.inference_network.sample( batch_shape, conditions=inference_conditions, - compositional=True, compute_prior_score=compute_prior_score, **filter_kwargs(kwargs, self.inference_network.sample), ) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 250c93b22..4fd22e468 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -28,18 +28,32 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, - compositional: bool = False, + compute_prior_score: Callable[[Tensor], Tensor] = None, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: - if compositional: + if compute_prior_score is not None: return self._inverse_compositional( xz, conditions=conditions, density=density, training=training, **kwargs ) - return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) - if compositional: - return self._forward_compositional(xz, conditions=conditions, density=density, training=training, **kwargs) + return self._inverse( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) + if compute_prior_score is not None: + return self._forward_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( From f71359bbdb0cd95b042876df38766bcd150ff4dd Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 16:41:46 +0200 Subject: [PATCH 23/61] add prior score draft --- bayesflow/networks/inference_network.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index 4fd22e468..9488f644d 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -35,16 +35,14 @@ def call( if inverse: if compute_prior_score is not None: return self._inverse_compositional( - xz, conditions=conditions, density=density, training=training, **kwargs + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, ) - return self._inverse( - xz, - conditions=conditions, - compute_prior_score=compute_prior_score, - density=density, - training=training, - **kwargs, - ) + return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) if compute_prior_score is not None: return self._forward_compositional( xz, From 6210c07ade674e5cbf0ef5a2def2b21b1125d7e7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:01:14 +0200 Subject: [PATCH 24/61] add prior score draft --- .../approximators/continuous_approximator.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index cbc7d30c3..b4e543495 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -14,6 +14,7 @@ squeeze_inner_estimates_dict, concatenate_valid, concatenate_valid_shapes, + expand_right_as, ) from bayesflow.utils.serialization import serialize, 
deserialize, serializable @@ -690,29 +691,28 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: + return keras.ops.zeros_like(_samples) if "inference_variables" in self.standardize: _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( _samples, forward=False, log_det_jac=True ) else: log_det_jac_standardize = 0 - _samples = {"inference_variables": _samples} - _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, _samples) + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) prior_score = compute_prior_score(adapted_samples) - prior_score_final = {} - for i, key in enumerate(adapted_samples): # todo: assumes same order, might incorrect - prior_score_final[key] = prior_score[key] - if len(log_det_jac_standardize) > 0: - prior_score_final[key] += log_det_jac_standardize[:, i] + for key in adapted_samples: + prior_score[key] = prior_score[key] if len(log_det_jac) > 0: - prior_score_final[key] += log_det_jac[:, i] - prior_score_final[key] = keras.ops.convert_to_tensor(prior_score_final[key]) + prior_score[key] += log_det_jac[key] + prior_score[key] = keras.ops.convert_to_tensor(prior_score[key]) # make a tensor - out = keras.ops.concatenate(list(prior_score_final.values()), axis=-1) - return out + out = keras.ops.concatenate( + list(prior_score.values()), axis=-1 + ) # todo: assumes same order, might be incorrect + return out + expand_right_as(log_det_jac_standardize, out) # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) From bcb9f60a63ddd2ade564121a6f15e9933824a132 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:19:10 +0200 Subject: [PATCH 25/61] add prior score draft --- bayesflow/approximators/continuous_approximator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index b4e543495..5c727eac0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -691,7 +691,6 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: - return keras.ops.zeros_like(_samples) if "inference_variables" in self.standardize: _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( _samples, forward=False, log_det_jac=True @@ -707,7 +706,8 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: prior_score[key] = prior_score[key] if len(log_det_jac) > 0: prior_score[key] += log_det_jac[key] - prior_score[key] = keras.ops.convert_to_tensor(prior_score[key]) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) # make a tensor out = keras.ops.concatenate( list(prior_score.values()), axis=-1 From 455f03c2336549ea8c9cd8161af8b16f1af5150d Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 8 Sep 2025 17:25:22 +0200 Subject: [PATCH 26/61] fix dtype --- bayesflow/networks/diffusion_model/diffusion_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index b65c61d01..3099788bd 100644 --- 
a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -603,8 +603,6 @@ def compositional_velocity( # Get shapes for compositional structure n_compositional = ops.shape(conditions)[1] - n = ops.cast(n_compositional, dtype=ops.dtype(time)) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) # Calculate standard noise schedule components log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) @@ -628,9 +626,10 @@ def compositional_velocity( summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) # Prior contribution - weighted_prior_score = (1.0 - n) * (1.0 - time_tensor) * prior_score + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score # Combined score + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) # Compute velocity using standard drift-diffusion formulation From 89523a98494d40046d75581d146e466a6e295b49 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:16:10 +0200 Subject: [PATCH 27/61] fix docstring --- bayesflow/approximators/continuous_approximator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 5c727eac0..2cd4225c6 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -662,7 +662,7 @@ def compositional_sample( Dictionary of conditioning variables as NumPy arrays with shape (n_datasets, n_compositional_conditions, ...). compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] - A function that computes the log probability of samples under the prior distribution. + A function that computes the score of the log prior distribution. split : bool, default=False Whether to split the output arrays along the last axis and return one column vector per target variable samples. 
From e55631dcec299cecc0e1ce27ebf0cd668ab960ba Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:30:35 +0200 Subject: [PATCH 28/61] fix batch_shape in sample --- bayesflow/approximators/continuous_approximator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index f27a612f0..b60f4e4bd 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,10 +535,9 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = ( - batch_size, - num_samples, - ) + + target_dim = self.inference_network.base_distribution.dims + batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] else: batch_shape = (num_samples,) From 3eaff24a1d314435040a688559e3c85d5235f9c6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:35:56 +0200 Subject: [PATCH 29/61] fix batch_shape for point approximator --- bayesflow/approximators/continuous_approximator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index b60f4e4bd..13ba32cb9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -536,8 +536,12 @@ def _sample( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - target_dim = self.inference_network.base_distribution.dims - batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] + if hasattr(self.inference_network, "base_distribution"): + target_shape_len = len(self.inference_network.base_distribution.dims) + else: + # point approximator has no base_distribution + target_shape_len = 1 + batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len] else: batch_shape = (num_samples,) From e97e375f458eadd41ed19b0a574ad0fa0daf9b59 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Sep 2025 18:11:31 +0200 Subject: [PATCH 30/61] fix docstring --- bayesflow/simulators/sequential_simulator.py | 4 ++-- bayesflow/simulators/simulator.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py index 21e1542e6..96ab0ead3 100644 --- a/bayesflow/simulators/sequential_simulator.py +++ b/bayesflow/simulators/sequential_simulator.py @@ -88,7 +88,7 @@ def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]: return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs) def sample_parallel( - self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs ) -> dict[str, np.ndarray]: """ Sample in parallel from the sequential simulator. @@ -101,7 +101,7 @@ def sample_parallel( n_jobs : int, optional Number of parallel jobs. -1 uses all available cores. Default is -1. verbose : int, optional - Verbosity level for joblib. Default is 0 (no output). + Verbosity level for joblib. Default is 1 (minimal output). **kwargs Additional keyword arguments passed to each simulator. These may include previously sampled outputs used as inputs for subsequent simulators. 
diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py index 00d3d84f3..53d54e455 100644 --- a/bayesflow/simulators/simulator.py +++ b/bayesflow/simulators/simulator.py @@ -95,3 +95,8 @@ def accept_all_predicate(x): return np.full((sample_size,), True) return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs) + + def sample_parallel( + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs + ) -> dict[str, np.ndarray]: + raise NotImplementedError From caa2d67ec4f934984d9a932515472ec9e432b4e8 Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 10 Sep 2025 18:54:15 +0200 Subject: [PATCH 31/61] fix float32 --- bayesflow/approximators/continuous_approximator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 546de64ed..7a0d757d9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -706,7 +706,8 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: ) prior_score = compute_prior_score(adapted_samples) for key in adapted_samples: - prior_score[key] = prior_score[key] + if isinstance(prior_score[key], np.ndarray): + prior_score[key] = prior_score[key].astype("float32") if len(log_det_jac) > 0: prior_score[key] += log_det_jac[key] From 1ac9bff056cb539a45a76391e2733afff6fb99ca Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 12:34:57 +0200 Subject: [PATCH 32/61] reorganize --- .../diffusion_model/diffusion_model.py | 71 +++++++++++++++---- 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 3099788bd..332165a5e 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -593,6 +593,64 @@ def compositional_velocity( training : bool, optional Whether in training mode + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) 
+ time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. + training : bool, optional + Whether in training mode + Returns ------- Tensor @@ -631,18 +689,7 @@ def compositional_velocity( # Combined score time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) - - # Compute velocity using standard drift-diffusion formulation - f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) - - if stochastic_solver: - # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW - velocity = f - g_squared * compositional_score - else: - # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt - velocity = f - 0.5 * g_squared * compositional_score - - return velocity + return compositional_score def _compute_individual_scores( self, From df23f892a5d9b7d23eec0c9ff2ffe61a266e8698 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 14:02:09 +0200 Subject: [PATCH 33/61] add annealed_langevin --- .../diffusion_model/diffusion_model.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 332165a5e..c3a625e06 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,12 +16,15 @@ integrate_stochastic, logging, tensor_utils, + filter_kwargs, ) from bayesflow.utils.serialization import serialize, deserialize, serializable from .schedules.noise_schedule import NoiseSchedule from .dispatch import find_noise_schedule +ArrayLike = int | float | Tensor + # disable module check, use potential module after moving from experimental @serializable("bayesflow.networks", disable_module_check=True) @@ -840,6 +843,26 @@ def diffusion(time, xz): seed=self.seed_generator, **integrate_kwargs, ) + elif integrate_kwargs["method"] == "langevin": + + def scores(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = annealed_langevin( + score_fn=scores, + state=state, + seed=self.seed_generator, + **filter_kwargs(integrate_kwargs, annealed_langevin), + ) else: def deltas(time, xz): @@ -859,3 +882,50 @@ def deltas(time, xz): x = state["xz"] return x + + +def annealed_langevin( + score_fn: Callable, + state: dict[str, ArrayLike], + steps: int, + seed: keras.random.SeedGenerator, + L: int = 5, + start_time: ArrayLike = None, + stop_time: ArrayLike = None, + eps: float = 0.01, +) -> dict[str, ArrayLike]: + """ + Annealed Langevin dynamics for diffusion sampling. 
+ + for t = T-1,...,1: + for s = 1,...,L: + eta ~ N(0, I) + theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta + """ + ratio = keras.ops.convert_to_tensor( + (stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values()))) + ) + + T = steps + # main loops + for t_T in range(T - 1, 0, -1): + t = t_T / T + dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * ( + ratio ** (stop_time - t) + ) + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) + # inner L Langevin steps at level t + for _ in range(L): + # score + drift = score_fn(t, **filter_kwargs(state, score_fn)) + # noise + eta = { + k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) + for k, v in state.items() + } + + # update + for k, d in drift.items(): + state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k] + return state From 0a87694f654be4039f17e02060e357fdc9e07c70 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 14:44:42 +0200 Subject: [PATCH 34/61] fix annealed_langevin --- .../diffusion_model/diffusion_model.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c3a625e06..08118a220 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -859,6 +859,7 @@ def scores(time, xz): state = annealed_langevin( score_fn=scores, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **filter_kwargs(integrate_kwargs, annealed_langevin), @@ -886,13 +887,14 @@ def deltas(time, xz): def annealed_langevin( score_fn: Callable, + noise_schedule: Callable, state: dict[str, ArrayLike], steps: int, seed: keras.random.SeedGenerator, - L: int = 5, start_time: ArrayLike = None, stop_time: ArrayLike = None, - eps: float = 0.01, + langevin_corrector_steps: int = 5, + step_size_factor: float = 0.1, ) -> dict[str, ArrayLike]: """ Annealed Langevin dynamics for diffusion sampling. 
@@ -902,30 +904,25 @@ def annealed_langevin( eta ~ N(0, I) theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta """ - ratio = keras.ops.convert_to_tensor( - (stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values()))) - ) + log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False) + _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - T = steps # main loops - for t_T in range(T - 1, 0, -1): - t = t_T / T - dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * ( - ratio ** (stop_time - t) - ) - - sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt)) - # inner L Langevin steps at level t - for _ in range(L): - # score + for step in range(steps - 1, 0, -1): + t = step / steps + log_snr_t = noise_schedule.get_log_snr(t=t, training=False) + _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t) + + sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size)) + for _ in range(langevin_corrector_steps): drift = score_fn(t, **filter_kwargs(state, score_fn)) - # noise - eta = { + noise = { k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) for k, v in state.items() } # update for k, d in drift.items(): - state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k] + state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k] return state From 64d43735dc9136b622605bb9b17ea372629cbee9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:00:39 +0200 Subject: [PATCH 35/61] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 17 ++++++++ bayesflow/utils/integrate.py | 41 +++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 08118a220..fd1415616 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -836,9 +836,26 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + scores = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def scores(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=scores, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b197ea975..be269ebaa 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -401,11 +401,17 @@ def integrate_stochastic( steps: int, seed: keras.random.SeedGenerator, method: str = "euler_maruyama", + score_fn: Callable = None, + corrector_steps: int = 0, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ Integrates a stochastic differential equation from start_time to stop_time. + When score_fn is provided, performs predictor-corrector sampling where: + - Predictor: reverse diffusion SDE solver + - Corrector: annealed Langevin dynamics with step size e = sqrt(dim) + Args: drift_fn: Function that computes the drift term. diffusion_fn: Function that computes the diffusion term. 
@@ -415,11 +421,13 @@ def integrate_stochastic( steps: Number of integration steps. seed: Random seed for noise generation. method: Integration method to use, e.g., 'euler_maruyama'. + score_fn: Optional score function for predictor-corrector sampling. + Should take (time, **state) and return score dict. + corrector_steps: Number of corrector steps to take after each predictor step. **kwargs: Additional arguments to pass to the step function. Returns: - If return_noise is False, returns the final state dictionary. - If return_noise is True, returns a tuple of (final_state, noise_history). + Final state dictionary after integration. """ if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -438,17 +446,44 @@ def integrate_stochastic( step_size = (stop_time - start_time) / steps sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) - # Pre-generate noise history: shape = (steps, *state_shape) + # Pre-generate noise history for predictor: shape = (steps, *state_shape) noise_history = {} for key, val in state.items(): noise_history[key] = ( keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt ) + # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) + corrector_noise_history = {} + if score_fn is not None and corrector_steps > 0: + for key, val in state.items(): + corrector_noise_history[key] = keras.random.normal( + (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed + ) + def body(_loop_var, _loop_state): _current_state, _current_time = _loop_state _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} + + # Predictor step new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) + + # Corrector steps: annealed Langevin dynamics if score_fn is provided + if score_fn is not None: + first_key = next(iter(new_state.keys())) + dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key])) + e = keras.ops.sqrt(dim) + sqrt_2e = keras.ops.sqrt(2.0 * e) + + for corrector_step in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + for k in new_state.keys(): + if k in score: + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] + return new_state, new_time final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) From 5b4236862a6c76682a1b03c54eebd278314549fe Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:13:32 +0200 Subject: [PATCH 36/61] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 81 ++++++++++++++++--- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index fd1415616..69dae59ac 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -246,6 +246,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the 
score of the target or latent variable `xz`.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The current state of the latent variable `z`, typically of shape (..., D),
+            where D is the dimensionality of the latent space.
+        time : float or Tensor
+            Scalar or tensor representing the time (or noise level) at which the score
+            should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided.
+        log_snr_t : Tensor
+            The log signal-to-noise ratio at time `t`. If None, time must be provided.
+        conditions : Tensor, optional
+            Conditional inputs to the network, such as conditioning variables
+            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
+        training : bool, optional
+            Whether the model is in training mode. Affects behavior of dropout, batch norm,
+            or other stochastic layers. Default is False.
+
+        Returns
+        -------
+        Tensor
+            The score tensor of the same shape as `xz`, representing the score of the
+            diffusion target evaluated at the given `time` (or `log_snr_t`).
+        """
+        if log_snr_t is None:
+            log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
+            log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
+        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
+
+        subnet_out = self._apply_subnet(
+            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
+
+        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
+
+        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
+        return score
+
     def velocity(
         self,
         xz: Tensor,
@@ -282,19 +331,10 @@ def velocity(
         The velocity tensor of the same shape as `xz`, representing the right-hand
         side of the SDE or ODE at the given `time`.
""" - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -452,9 +492,24 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, state=state, seed=self.seed_generator, **integrate_kwargs, @@ -836,11 +891,11 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} - scores = None + score_fn = None if "corrector_steps" in integrate_kwargs: if integrate_kwargs["corrector_steps"] > 0: - def scores(time, xz): + def score_fn(time, xz): return { "xz": self.compositional_score( xz, @@ -855,7 +910,7 @@ def scores(time, xz): state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, - score_fn=scores, + score_fn=score_fn, state=state, seed=self.seed_generator, **integrate_kwargs, From 94029414a6c58ec56aeeb2c5d75741fde015e283 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:27:44 +0200 Subject: [PATCH 37/61] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 2 ++ bayesflow/utils/integrate.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 69dae59ac..e303e961d 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -510,6 +510,7 @@ def score_fn(time, xz): drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, @@ -911,6 +912,7 @@ def score_fn(time, xz): drift_fn=deltas, diffusion_fn=diffusion, score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, **integrate_kwargs, diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index be269ebaa..b3127a737 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -403,6 +403,7 @@ def integrate_stochastic( method: str = "euler_maruyama", score_fn: Callable = None, corrector_steps: int = 0, + noise_schedule=None, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -424,6 +425,7 @@ def integrate_stochastic( score_fn: Optional score function for predictor-corrector 
sampling. Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. + noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. **kwargs: Additional arguments to pass to the step function. Returns: @@ -455,7 +457,10 @@ def integrate_stochastic( # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) corrector_noise_history = {} - if score_fn is not None and corrector_steps > 0: + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + for key, val in state.items(): corrector_noise_history[key] = keras.random.normal( (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed @@ -469,19 +474,29 @@ def body(_loop_var, _loop_state): new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) # Corrector steps: annealed Langevin dynamics if score_fn is provided - if score_fn is not None: - first_key = next(iter(new_state.keys())) - dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key])) - e = keras.ops.sqrt(dim) - sqrt_2e = keras.ops.sqrt(2.0 * e) - + if corrector_steps > 0: for corrector_step in range(corrector_steps): score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + # Compute noise schedule components for corrector step size + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + # where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: + z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + + # Prevent division by zero + score_norm = keras.ops.maximum(score_norm, 1e-8) + + e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2 + sqrt_2e = keras.ops.sqrt(2.0 * e) + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] return new_state, new_time From e0b3bd5dfdfc6320cb35daa9429cbb4b816e8f69 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:32:27 +0200 Subject: [PATCH 38/61] add predictor corrector sampling --- bayesflow/utils/integrate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index b3127a737..a46d3e78a 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -404,6 +404,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, + r: float = 0.1, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -426,6 +427,7 @@ def integrate_stochastic( Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. + r: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. 
Returns: @@ -482,10 +484,9 @@ def body(_loop_var, _loop_state): # Compute noise schedule components for corrector step size log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector - # where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2 + # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) @@ -494,7 +495,7 @@ def body(_loop_var, _loop_state): # Prevent division by zero score_norm = keras.ops.maximum(score_norm, 1e-8) - e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2 + e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2 sqrt_2e = keras.ops.sqrt(2.0 * e) new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] From 89361f75282dfa367136b2658c3fd70453c6ea93 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:42:13 +0200 Subject: [PATCH 39/61] add predictor corrector sampling --- .../diffusion_model/diffusion_model.py | 67 ------------------- bayesflow/utils/integrate.py | 8 +-- 2 files changed, 4 insertions(+), 71 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index e303e961d..81c64bfbd 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -16,15 +16,12 @@ integrate_stochastic, logging, tensor_utils, - filter_kwargs, ) from bayesflow.utils.serialization import serialize, deserialize, serializable from .schedules.noise_schedule import NoiseSchedule from .dispatch import find_noise_schedule -ArrayLike = int | float | Tensor - # disable module check, use potential module after moving from experimental @serializable("bayesflow.networks", disable_module_check=True) @@ -917,27 +914,6 @@ def score_fn(time, xz): seed=self.seed_generator, **integrate_kwargs, ) - elif integrate_kwargs["method"] == "langevin": - - def scores(time, xz): - return { - "xz": self.compositional_score( - xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = annealed_langevin( - score_fn=scores, - noise_schedule=self.noise_schedule, - state=state, - seed=self.seed_generator, - **filter_kwargs(integrate_kwargs, annealed_langevin), - ) else: def deltas(time, xz): @@ -957,46 +933,3 @@ def deltas(time, xz): x = state["xz"] return x - - -def annealed_langevin( - score_fn: Callable, - noise_schedule: Callable, - state: dict[str, ArrayLike], - steps: int, - seed: keras.random.SeedGenerator, - start_time: ArrayLike = None, - stop_time: ArrayLike = None, - langevin_corrector_steps: int = 5, - step_size_factor: float = 0.1, -) -> dict[str, ArrayLike]: - """ - Annealed Langevin dynamics for diffusion sampling. 
- - for t = T-1,...,1: - for s = 1,...,L: - eta ~ N(0, I) - theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta - """ - log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False) - _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - # main loops - for step in range(steps - 1, 0, -1): - t = step / steps - log_snr_t = noise_schedule.get_log_snr(t=t, training=False) - _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t) - - sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size)) - for _ in range(langevin_corrector_steps): - drift = score_fn(t, **filter_kwargs(state, score_fn)) - noise = { - k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed) - for k, v in state.items() - } - - # update - for k, d in drift.items(): - state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k] - return state diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a46d3e78a..961015b8f 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -404,7 +404,7 @@ def integrate_stochastic( score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, - r: float = 0.1, + step_size_factor: float = 0.1, **kwargs, ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]: """ @@ -427,7 +427,7 @@ def integrate_stochastic( Should take (time, **state) and return score dict. corrector_steps: Number of corrector steps to take after each predictor step. noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector. - r: Scaling factor for corrector step size. + step_size_factor: Scaling factor for corrector step size. **kwargs: Additional arguments to pass to the step function. 
Returns: @@ -489,13 +489,13 @@ def body(_loop_var, _loop_state): # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 for k in new_state.keys(): if k in score: - z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True) + z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True) score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) # Prevent division by zero score_norm = keras.ops.maximum(score_norm, 1e-8) - e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2 + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 sqrt_2e = keras.ops.sqrt(2.0 * e) new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] From 5969bd380514acd883ed3820b32ecdf514003059 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 16:59:52 +0200 Subject: [PATCH 40/61] robust mean scores --- bayesflow/networks/diffusion_model/diffusion_model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 81c64bfbd..f56655c05 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -734,13 +734,11 @@ def compositional_score( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - time) * compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - - # Prior contribution - weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + weighted_individual_scores = individual_scores - weighted_prior_score + summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) # Combined score time_tensor = ops.cast(time, dtype=ops.dtype(xz)) From e983cf7a746b13eae0122bd1d6f57f7c20095cc1 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 17:10:49 +0200 Subject: [PATCH 41/61] add some tests --- .../diffusion_model/diffusion_model.py | 2 +- .../test_compositional_sampling.py | 178 ++++++++++++++++++ 2 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 tests/test_networks/test_diffusion_model/test_compositional_sampling.py diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index f56655c05..25a6b4c7c 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -737,7 +737,7 @@ def compositional_score( weighted_prior_score = (1.0 - time) * compute_prior_score(xz) # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - weighted_individual_scores = individual_scores - weighted_prior_score + weighted_individual_scores = individual_scores - keras.ops.expand_dims(weighted_prior_score, axis=1) summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) # Combined score diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py new file mode 100644 index 000000000..4fa0ebf59 --- /dev/null +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -0,0 +1,178 @@ +import keras +import 
pytest + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn + + +def test_compositional_score_shape( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test that compositional score returns correct shapes.""" + # Build the model + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + time = 0.5 + + score = simple_diffusion_model.compositional_score( + xz=compositional_state, + time=time, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_compositional_score_no_conditions_raises_error(simple_diffusion_model, compositional_state, mock_prior_score): + """Test that compositional score raises error when conditions is None.""" + simple_diffusion_model.build(keras.ops.shape(compositional_state), None) + + with pytest.raises(ValueError, match="Conditions are required for compositional sampling"): + simple_diffusion_model.compositional_score( + xz=compositional_state, time=0.5, conditions=None, compute_prior_score=mock_prior_score, training=False + ) + + +def test_inverse_compositional_basic( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test basic compositional inverse sampling.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + # Test inverse sampling with ODE method + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler", + steps=5, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_inverse_compositional_euler_maruyama_with_corrector( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional 
inverse sampling with Euler-Maruyama and corrector steps.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler_maruyama", + steps=5, + corrector_steps=2, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +@pytest.mark.parametrize("noise_schedule_name", ["cosine", "edm"]) +def test_compositional_sampling_with_different_schedules( + noise_schedule_name, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional sampling with different noise schedules.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + diffusion_model = DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule=noise_schedule_name, + prediction_type="noise", + loss_type="noise", + ) + + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + diffusion_model.build(state_shape, conditions_shape) + + score = diffusion_model.compositional_score( + xz=compositional_state, + time=0.5, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) From eac9aaf562eda96326f48473097e324223abf0cd Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 17:30:57 +0200 Subject: [PATCH 42/61] minor fixes --- bayesflow/networks/diffusion_model/diffusion_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 25a6b4c7c..a6ff78510 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -734,13 +734,13 @@ def compositional_score( individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training) # Compute prior score component - weighted_prior_score = (1.0 - time) * compute_prior_score(xz) + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score - # Combine scores using compositional formula, mean over individual scores and scale with n to get sum - weighted_individual_scores = individual_scores - keras.ops.expand_dims(weighted_prior_score, axis=1) - summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1) + # Sum individual scores across compositional dimensiont + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - # Combined score + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) time_tensor = ops.cast(time, dtype=ops.dtype(xz)) compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + 
summed_individual_scores)
         return compositional_score

From 2a9b0e100c2f9fa24bf3ae53cfdcc6a3c0044024 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Fri, 12 Sep 2025 17:41:25 +0200
Subject: [PATCH 43/61] minor fixes

---
 .../diffusion_model/diffusion_model.py       | 41 ++++---------------
 1 file changed, 9 insertions(+), 32 deletions(-)

diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py
index a6ff78510..7a75b0a9b 100644
--- a/bayesflow/networks/diffusion_model/diffusion_model.py
+++ b/bayesflow/networks/diffusion_model/diffusion_model.py
@@ -721,7 +721,6 @@ def compositional_score(
         # Calculate standard noise schedule components
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
-        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
         # Compute individual dataset scores
         if mini_batch_size is not None and mini_batch_size < n_compositional:
@@ -731,13 +730,13 @@ def compositional_score(
             conditions_batch = conditions[:, mini_batch_idx]
         else:
             conditions_batch = conditions
-        individual_scores = self._compute_individual_scores(xz, log_snr_t, alpha_t, sigma_t, conditions_batch, training)
+        individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training)
 
         # Compute prior score component
         prior_score = compute_prior_score(xz)
         weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score
 
-        # Sum individual scores across compositional dimensiont
+        # Sum individual scores across compositional dimensions
@@ -749,8 +748,6 @@ def _compute_individual_scores(
         self,
         xz: Tensor,
         log_snr_t: Tensor,
-        alpha_t: Tensor,
-        sigma_t: Tensor,
         conditions: Tensor,
         training: bool,
     ) -> Tensor:
@@ -762,9 +759,6 @@
         Tensor
             Individual scores with shape (n_datasets, n_compositional, ...)
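In equation form, the score assembled by compositional_score is

\nabla_\theta \log p_t(\theta \mid y_{1:n}) \approx b(t)\left[(1-n)(1-t)\,\nabla_\theta \log p(\theta) + \sum_{i=1}^{n} s_\psi(\theta, t, y_i)\right],

where b(t) denotes the compositional bridge factor and n the number of compositional conditions. When a mini-batch of m < n datasets is drawn, the sum is estimated as n times the mean of the m per-dataset scores, which is what the n_compositional * ops.mean(...) line computes.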
""" - # Apply subnet to each compositional condition separately - transformed_log_snr = self._transform_log_snr(log_snr_t) - # Get shapes xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) @@ -777,38 +771,21 @@ def _compute_individual_scores( xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) - # Expand noise schedule components to match compositional structure - log_snr_expanded = ops.expand_dims(transformed_log_snr, axis=1) + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) - alpha_expanded = ops.expand_dims(alpha_t, axis=1) - alpha_expanded = ops.broadcast_to(alpha_expanded, (n_datasets, n_compositional, num_samples, 1)) - - sigma_expanded = ops.expand_dims(sigma_t, axis=1) - sigma_expanded = ops.broadcast_to(sigma_expanded, (n_datasets, n_compositional, num_samples, 1)) - - # Flatten for subnet application: (n_datasets * n_compositional, num_samples, ..., dims) + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) - alpha_flat = ops.reshape(alpha_expanded, (n_datasets * n_compositional, num_samples, 1)) - sigma_flat = ops.reshape(sigma_expanded, (n_datasets * n_compositional, num_samples, 1)) conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) - # Apply subnet - subnet_out = self._apply_subnet(xz_flat, log_snr_flat, conditions=conditions_flat, training=training) - pred = self.output_projector(subnet_out, training=training) - - # Convert prediction to x - x_pred = self.convert_prediction_to_x( - pred=pred, z=xz_flat, alpha_t=alpha_flat, sigma_t=sigma_flat, log_snr_t=log_snr_flat - ) - - # Compute score: (α_t * x_pred - z) / σ_t² - score = (alpha_flat * x_pred - xz_flat) / ops.square(sigma_flat) + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) # Reshape back to compositional structure - score = ops.reshape(score, (n_datasets, n_compositional, num_samples) + dims) - return score + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores def _inverse_compositional( self, From 9a1ba32dc6e28b49b97cdb87ad0e41d8bbe518bd Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 20:17:00 +0200 Subject: [PATCH 44/61] add test for compute_prior_score_pre --- .../approximators/continuous_approximator.py | 13 +-- tests/test_approximators/conftest.py | 53 +++++++++ .../test_compositional_prior_score.py | 109 ++++++++++++++++++ .../test_diffusion_model/conftest.py | 47 ++++++++ .../test_compositional_sampling.py | 46 -------- 5 files changed, 214 insertions(+), 54 deletions(-) create mode 100644 tests/test_approximators/test_compositional_prior_score.py diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 7a0d757d9..075358c5f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py 
@@ -699,7 +699,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: _samples, forward=False, log_det_jac=True ) else: - log_det_jac_standardize = 0 + log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs @@ -708,15 +708,12 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: for key in adapted_samples: if isinstance(prior_score[key], np.ndarray): prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0: - prior_score[key] += log_det_jac[key] + if len(log_det_jac) > 0 and key in log_det_jac: + prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - # make a tensor - out = keras.ops.concatenate( - list(prior_score.values()), axis=-1 - ) # todo: assumes same order, might be incorrect - return out + expand_right_as(log_det_jac_standardize, out) + out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index a56802a3e..528e7969b 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -220,3 +220,56 @@ def approximator_with_summaries(request): ) case _: raise ValueError("Invalid param for approximator class.") + + +@pytest.fixture +def simple_log_simulator(): + """Create a simple simulator for testing.""" + import numpy as np + from bayesflow.simulators import Simulator + from bayesflow.utils.decorators import allow_batch_size + from bayesflow.types import Shape, Tensor + + class SimpleSimulator(Simulator): + """Simple simulator that generates mean and scale parameters.""" + + @allow_batch_size + def sample(self, batch_shape: Shape) -> dict[str, Tensor]: + # Generate parameters in original space + loc = np.random.normal(0.0, 1.0, size=batch_shape + (2,)) # location parameters + scale = np.random.lognormal(0.0, 0.5, size=batch_shape + (2,)) # scale parameters > 0 + + # Generate some dummy conditions + conditions = np.random.normal(0.0, 1.0, size=batch_shape + (3,)) + + return dict( + loc=loc.astype("float32"), scale=scale.astype("float32"), conditions=conditions.astype("float32") + ) + + return SimpleSimulator() + + +@pytest.fixture +def transforming_adapter(): + """Create an adapter that applies log transformation to scale parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + # Apply log transformation to scale parameters (to make them unbounded) + adapter.log(["scale"]) + + adapter.concatenate(["loc", "scale"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def diffusion_network(): + """Create a diffusion network for compositional sampling.""" + from bayesflow.networks import DiffusionModel, MLP + + return DiffusionModel(subnet=MLP(widths=[32, 32])) diff --git a/tests/test_approximators/test_compositional_prior_score.py 
b/tests/test_approximators/test_compositional_prior_score.py new file mode 100644 index 000000000..cd4b81413 --- /dev/null +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -0,0 +1,109 @@ +"""Tests for compositional sampling and prior score computation with adapters.""" + +import numpy as np +import keras + +from bayesflow import ContinuousApproximator +from bayesflow.utils import expand_right_as + + +def mock_prior_score_original_space(data_dict): + """Mock prior score function that expects data in original (loc, scale) space.""" + # The function receives data in the same format the compute_prior_score_pre creates + # after running the inverse adapter + loc = data_dict["loc"] + scale = data_dict["scale"] + + # Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale + loc_score = -loc + scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale) + + return {"loc": loc_score, "scale": scale_score} + + +def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network): + """Test that prior scores work correctly with transforming adapter (log transformation).""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=transforming_adapter, + inference_network=diffusion_network, + ) + + # Generate test data and adapt it + data = simple_log_simulator.sample((2,)) + adapted_data = transforming_adapter(data) + + # Build approximator + approximator.build_from_data(adapted_data) + + # Test compositional sampling + n_datasets, n_compositional = 3, 5 + conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} + + # This should work - the compute_prior_score_pre function should handle the inverse transformation + samples = approximator.compositional_sample( + num_samples=10, + conditions=conditions, + compute_prior_score=mock_prior_score_original_space, + ) + + assert "loc" in samples + assert "scale" in samples + assert samples["loc"].shape == (n_datasets, 10, 2) + assert samples["scale"].shape == (n_datasets, 10, 2) + + +def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network): + """Test that Jacobian correction is applied correctly in compute_prior_score_pre.""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=transforming_adapter, inference_network=diffusion_network, standardize=[] + ) + + # Build with dummy data + dummy_data_dict = simple_log_simulator.sample((1,)) + adapted_dummy_data = transforming_adapter(dummy_data_dict) + approximator.build_from_data(adapted_dummy_data) + + # Get the internal compute_prior_score_pre function + def get_compute_prior_score_pre(): + def compute_prior_score_pre(_samples): + if "inference_variables" in approximator.standardize: + _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"]( + _samples, forward=False, log_det_jac=True + ) + else: + log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") + + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) + adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True) + + prior_score = mock_prior_score_original_space(adapted_samples) + for key in adapted_samples: + if isinstance(prior_score[key], np.ndarray): + prior_score[key] = prior_score[key].astype("float32") + if len(log_det_jac) > 0 and key in log_det_jac: + prior_score[key] 
-= expand_right_as(log_det_jac[key], prior_score[key]) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) + out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) + + return compute_prior_score_pre + + compute_prior_score_pre = get_compute_prior_score_pre() + + # Test with a known transformation + y_samples = adapted_dummy_data["inference_variables"] + scores = compute_prior_score_pre(y_samples) + scores_np = keras.ops.convert_to_numpy(scores)[0] # Remove batch dimension + + # With Jacobian correction: score_transformed = score_original - log|J| + old_scores = mock_prior_score_original_space(dummy_data_dict) + det_jac_scale = y_samples[0, 2:].sum() + expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten() + + # Check that scores are reasonably close + np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) diff --git a/tests/test_networks/test_diffusion_model/conftest.py b/tests/test_networks/test_diffusion_model/conftest.py index b1ee915ae..581b4abde 100644 --- a/tests/test_networks/test_diffusion_model/conftest.py +++ b/tests/test_networks/test_diffusion_model/conftest.py @@ -1,4 +1,5 @@ import pytest +import keras @pytest.fixture() @@ -21,3 +22,49 @@ def edm_noise_schedule(): ) def noise_schedule(request): return request.getfixturevalue(request.param) + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py index 4fa0ebf59..2757bd28a 100644 --- a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -2,52 +2,6 @@ import pytest -@pytest.fixture -def simple_diffusion_model(): - """Create a simple diffusion model for testing compositional sampling.""" - from bayesflow.networks.diffusion_model import DiffusionModel - from bayesflow.networks import MLP - - return DiffusionModel( - subnet=MLP(widths=[32, 32]), - noise_schedule="cosine", - prediction_type="noise", - loss_type="noise", - ) - - -@pytest.fixture -def compositional_conditions(): - """Create test conditions for compositional sampling.""" - batch_size = 2 - n_compositional = 3 - n_samples = 4 - condition_dim = 5 - - return keras.random.normal((batch_size, n_compositional, n_samples, 
condition_dim)) - - -@pytest.fixture -def compositional_state(): - """Create test state for compositional sampling.""" - batch_size = 2 - n_samples = 4 - param_dim = 3 - - return keras.random.normal((batch_size, n_samples, param_dim)) - - -@pytest.fixture -def mock_prior_score(): - """Create a mock prior score function for testing.""" - - def prior_score_fn(theta): - # Simple quadratic prior: -0.5 * ||theta||^2 - return -theta - - return prior_score_fn - - def test_compositional_score_shape( simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score ): From 93b59ba0da9e3b556ff0c7b718219688139709d9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 12 Sep 2025 20:44:56 +0200 Subject: [PATCH 45/61] fix order of prior scores --- bayesflow/approximators/continuous_approximator.py | 2 +- tests/test_approximators/conftest.py | 2 +- tests/test_approximators/test_compositional_prior_score.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 075358c5f..e4e4f09c2 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -712,7 +712,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor: prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) # Test prior score function, useful for debugging diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 528e7969b..5587901b5 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -261,7 +261,7 @@ def transforming_adapter(): # Apply log transformation to scale parameters (to make them unbounded) adapter.log(["scale"]) - adapter.concatenate(["loc", "scale"], into="inference_variables") + adapter.concatenate(["scale", "loc"], into="inference_variables") adapter.concatenate(["conditions"], into="inference_conditions") adapter.keep(["inference_variables", "inference_conditions"]) return adapter diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py index cd4b81413..96ac7d29e 100644 --- a/tests/test_approximators/test_compositional_prior_score.py +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -88,7 +88,7 @@ def compute_prior_score_pre(_samples): prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate(list(prior_score.values()), axis=-1) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) return compute_prior_score_pre @@ -102,8 +102,9 @@ def compute_prior_score_pre(_samples): # With Jacobian correction: score_transformed = score_original - log|J| old_scores = mock_prior_score_original_space(dummy_data_dict) - det_jac_scale = y_samples[0, 2:].sum() - expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten() + # order of parameters is flipped due to concatenation in adapter + 
det_jac_scale = y_samples[0, :2].sum() + expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten() # Check that scores are reasonably close np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) From 922040d4cd36caf9f0a0baa2964ac9bc7cefe7b2 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:29:34 +0200 Subject: [PATCH 46/61] fix prior scores standardize --- .../approximators/continuous_approximator.py | 48 ++++++++--- tests/test_approximators/conftest.py | 15 ++++ .../test_compositional_prior_score.py | 79 ++----------------- 3 files changed, 58 insertions(+), 84 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index e4e4f09c2..24b02f145 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -14,7 +14,6 @@ squeeze_inner_estimates_dict, concatenate_valid, concatenate_valid_shapes, - expand_right_as, ) from bayesflow.utils.serialization import serialize, deserialize, serializable @@ -695,25 +694,52 @@ def compositional_sample( # Prepare prior scores to handle adapter def compute_prior_score_pre(_samples: Tensor) -> Tensor: if "inference_variables" in self.standardize: - _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"]( - _samples, forward=False, log_det_jac=True - ) - else: - log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") + _samples = self.standardize_layers["inference_variables"](_samples, forward=False) _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) adapted_samples, log_det_jac = self.adapter( _samples, inverse=True, strict=False, log_det_jac=True, **kwargs ) + + if len(log_det_jac) > 0: + problematic_keys = [key for key in log_det_jac if log_det_jac[key] != 0.0] + raise NotImplementedError( + f"Cannot use compositional sampling with adapters " + f"that have non-zero log_det_jac. Problematic keys: {problematic_keys}" + ) + prior_score = compute_prior_score(adapted_samples) for key in adapted_samples: - if isinstance(prior_score[key], np.ndarray): - prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0 and key in log_det_jac: - prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) + prior_score[key] = prior_score[key].astype(np.float32) prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) - return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) + + if "inference_variables" in self.standardize: + # Apply jacobian correction from standardization + # For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal + # The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. 
z + # But we need to transform the score: score_z = score_x * std where x = T^{-1}(z) + standardize_layer = self.standardize_layers["inference_variables"] + + # Compute the correct standard deviation for all components + std_components = [] + for idx in range(len(standardize_layer.moving_mean)): + std_val = standardize_layer.moving_std(idx) + std_components.append(std_val) + + # Concatenate std components to match the shape of out + if len(std_components) == 1: + std = std_components[0] + else: + std = keras.ops.concatenate(std_components, axis=-1) + + # Expand std to match batch dimension of out + std_expanded = keras.ops.expand_dims(std, (0, 1)) # Add batch, sample dimensions + std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1]) + + # Apply the jacobian: score_z = score_x * std + out = out * std_expanded + return out # Test prior score function, useful for debugging test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index 5587901b5..befc0da06 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -249,6 +249,21 @@ def sample(self, batch_shape: Shape) -> dict[str, Tensor]: return SimpleSimulator() +@pytest.fixture +def identity_adapter(): + """Create an adapter that applies no transformation to the parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + adapter.concatenate(["loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + @pytest.fixture def transforming_adapter(): """Create an adapter that applies log transformation to scale parameters.""" diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py index 96ac7d29e..02be46c00 100644 --- a/tests/test_approximators/test_compositional_prior_score.py +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -1,38 +1,31 @@ """Tests for compositional sampling and prior score computation with adapters.""" import numpy as np -import keras from bayesflow import ContinuousApproximator -from bayesflow.utils import expand_right_as def mock_prior_score_original_space(data_dict): - """Mock prior score function that expects data in original (loc, scale) space.""" - # The function receives data in the same format the compute_prior_score_pre creates - # after running the inverse adapter + """Mock prior score function that expects data in original space.""" loc = data_dict["loc"] - scale = data_dict["scale"] - # Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale + # Simple prior: N(0,1) for loc loc_score = -loc - scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale) + return {"loc": loc_score} - return {"loc": loc_score, "scale": scale_score} - -def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network): +def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network): """Test that prior scores work correctly with transforming adapter (log transformation).""" # Create approximator with transforming adapter approximator = ContinuousApproximator( - adapter=transforming_adapter, + adapter=identity_adapter, inference_network=diffusion_network, ) # Generate test data and adapt it 
data = simple_log_simulator.sample((2,)) - adapted_data = transforming_adapter(data) + adapted_data = identity_adapter(data) # Build approximator approximator.build_from_data(adapted_data) @@ -40,8 +33,6 @@ def test_prior_score_transforming_adapter(simple_log_simulator, transforming_ada # Test compositional sampling n_datasets, n_compositional = 3, 5 conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} - - # This should work - the compute_prior_score_pre function should handle the inverse transformation samples = approximator.compositional_sample( num_samples=10, conditions=conditions, @@ -49,62 +40,4 @@ def test_prior_score_transforming_adapter(simple_log_simulator, transforming_ada ) assert "loc" in samples - assert "scale" in samples assert samples["loc"].shape == (n_datasets, 10, 2) - assert samples["scale"].shape == (n_datasets, 10, 2) - - -def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network): - """Test that Jacobian correction is applied correctly in compute_prior_score_pre.""" - - # Create approximator with transforming adapter - approximator = ContinuousApproximator( - adapter=transforming_adapter, inference_network=diffusion_network, standardize=[] - ) - - # Build with dummy data - dummy_data_dict = simple_log_simulator.sample((1,)) - adapted_dummy_data = transforming_adapter(dummy_data_dict) - approximator.build_from_data(adapted_dummy_data) - - # Get the internal compute_prior_score_pre function - def get_compute_prior_score_pre(): - def compute_prior_score_pre(_samples): - if "inference_variables" in approximator.standardize: - _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"]( - _samples, forward=False, log_det_jac=True - ) - else: - log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32") - - _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) - adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True) - - prior_score = mock_prior_score_original_space(adapted_samples) - for key in adapted_samples: - if isinstance(prior_score[key], np.ndarray): - prior_score[key] = prior_score[key].astype("float32") - if len(log_det_jac) > 0 and key in log_det_jac: - prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key]) - - prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) - out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) - return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1) - - return compute_prior_score_pre - - compute_prior_score_pre = get_compute_prior_score_pre() - - # Test with a known transformation - y_samples = adapted_dummy_data["inference_variables"] - scores = compute_prior_score_pre(y_samples) - scores_np = keras.ops.convert_to_numpy(scores)[0] # Remove batch dimension - - # With Jacobian correction: score_transformed = score_original - log|J| - old_scores = mock_prior_score_original_space(dummy_data_dict) - # order of parameters is flipped due to concatenation in adapter - det_jac_scale = y_samples[0, :2].sum() - expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten() - - # Check that scores are reasonably close - np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6) From b2991d177bb24deab200cf6419ad0c823781febb Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 
2025 13:51:37 +0200 Subject: [PATCH 47/61] better standard values for compositional --- .../networks/diffusion_model/diffusion_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 7a75b0a9b..b026d5ea9 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -799,21 +799,21 @@ def _inverse_compositional( """ Inverse pass for compositional diffusion sampling. """ - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - - if mini_batch_size is not None: - # if backend is jax, mini batching does not work - if keras.backend.backend() == "jax": + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: raise ValueError( "Mini batching is not supported with JAX backend. Set mini_batch_size to None " "or use another backend." ) + else: + mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - n_compositional = ops.shape(conditions)[1] scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) From d2a36a8349bbe95acee2e31ba4f065958205a751 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:57:32 +0200 Subject: [PATCH 48/61] better compositional_bridge --- bayesflow/networks/diffusion_model/diffusion_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index b026d5ea9..c9e2a2271 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -3,6 +3,7 @@ import keras from keras import ops +import numpy as np from ..inference_network import InferenceNetwork from bayesflow.types import Tensor, Shape @@ -600,10 +601,10 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} - @staticmethod - def compositional_bridge(time: Tensor) -> Tensor: + def compositional_bridge(self, time: Tensor) -> Tensor: """ - Bridge function for compositional diffusion. In the simplest case, this is just 1. + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. + Otherwise, it can be used to scale the compositional score over time. Parameters ---------- @@ -616,7 +617,7 @@ def compositional_bridge(time: Tensor) -> Tensor: Bridge function value with same shape as time. 
""" - return ops.ones_like(time) + return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time) def compositional_velocity( self, @@ -812,6 +813,8 @@ def _inverse_compositional( ) else: mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) + self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) + self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) From 0ff960f8dfd80237830a368f3333e4f9e91196c7 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 13:59:01 +0200 Subject: [PATCH 49/61] fix integrate_kwargs --- bayesflow/networks/diffusion_model/diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index c9e2a2271..36063b577 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -812,7 +812,7 @@ def _inverse_compositional( "or use another backend." ) else: - mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1)) + mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) From b2ef75522268811558899b123f63cd8ec2ae38ff Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 13 Sep 2025 14:06:07 +0200 Subject: [PATCH 50/61] fix integrate_kwargs --- bayesflow/networks/diffusion_model/diffusion_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 36063b577..99b712203 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -617,7 +617,7 @@ def compositional_bridge(self, time: Tensor) -> Tensor: Bridge function value with same shape as time. 
""" - return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time) + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) def compositional_velocity( self, @@ -813,8 +813,8 @@ def _inverse_compositional( ) else: mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) - self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0)) - self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0)) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) # x is sampled from a normal distribution, must be scaled with var 1/n_compositional scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) @@ -893,6 +893,7 @@ def score_fn(time, xz): **integrate_kwargs, ) else: + integrate_kwargs.pop("corrector_steps", None) def deltas(time, xz): return { From ca7f3bdaf9700d2fd6268c844c4abf6fb5963139 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 12:21:13 +0200 Subject: [PATCH 51/61] fix kwargs in sample --- bayesflow/networks/transformers/mab.py | 4 +- .../networks/transformers/set_transformer.py | 10 ++- tests/test_approximators/test_sample.py | 90 +++++++++++++++++++ 3 files changed, 99 insertions(+), 5 deletions(-) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 5bd7c9dff..eddb8cf09 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, filter_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs + query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index d0d748067..94690f3ef 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -1,7 +1,7 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same +from bayesflow.utils import check_lengths_same, filter_kwargs from bayesflow.utils.serialization import serializable from ..summary_network import SummaryNetwork @@ -147,7 +147,11 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: out : Tensor Output of shape (batch_size, set_size, output_dim) """ - summary = self.attention_blocks(input_set, training=training, **kwargs) - summary = self.pooling_by_attention(summary, training=training, **kwargs) + summary = self.attention_blocks( + input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call) + ) + summary = self.pooling_by_attention( + summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call) + ) summary = self.output_projector(summary) return summary diff --git a/tests/test_approximators/test_sample.py 
b/tests/test_approximators/test_sample.py index c62ffc581..e76b72a40 100644 --- a/tests/test_approximators/test_sample.py +++ b/tests/test_approximators/test_sample.py @@ -1,3 +1,4 @@ +import pytest import keras from tests.utils import check_combination_simulator_adapter @@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter): samples = approximator.sample(num_samples=2, conditions=data) assert isinstance(samples, dict) + + +@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"]) +@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"]) +@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"]) +def test_approximator_sample_with_integration_methods( + inference_network_type, summary_network_type, method, simulator, adapter +): + """Test approximator sampling with different integration methods and summary networks. + + Tests flow matching and diffusion models with different ODE/SDE solvers: + - euler, rk45: Available for both flow matching and diffusion models + - euler_maruyama: Only for diffusion models (stochastic) + + Also tests with different summary network types. + """ + batch_size = 8 # Use smaller batch size for faster tests + check_combination_simulator_adapter(simulator, adapter) + + # Skip euler_maruyama for flow matching (deterministic model) + if inference_network_type == "flow_matching" and method == "euler_maruyama": + pytest.skip("euler_maruyama is only available for diffusion models") + + # Create inference network based on type + if inference_network_type == "flow_matching": + from bayesflow.networks import FlowMatching, MLP + + inference_network = FlowMatching( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + elif inference_network_type == "diffusion_model": + from bayesflow.networks import DiffusionModel, MLP + + inference_network = DiffusionModel( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + else: + pytest.skip(f"Unsupported inference network type: {inference_network_type}") + + # Create summary network based on type + summary_network = None + if summary_network_type != "none": + if summary_network_type == "deep_set": + from bayesflow.networks import DeepSet, MLP + + summary_network = DeepSet(subnet=MLP(widths=[16, 16])) + elif summary_network_type == "set_transformer": + from bayesflow.networks import SetTransformer + + summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16]) + elif summary_network_type == "time_series": + from bayesflow.networks import TimeSeriesNetwork + + summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm") + else: + pytest.skip(f"Unsupported summary network type: {summary_network_type}") + + # Update adapter to include summary variables if summary network is present + from bayesflow import ContinuousApproximator + + adapter = ContinuousApproximator.build_adapter( + inference_variables=["mean", "std"], + summary_variables=["x"], # Use x as summary variable for testing + ) + + # Create approximator + from bayesflow import ContinuousApproximator + + approximator = ContinuousApproximator( + adapter=adapter, inference_network=inference_network, summary_network=summary_network + ) + + # Generate test data + num_batches = 2 # Use fewer batches for faster tests + data = simulator.sample((num_batches * batch_size,)) + + # Build approximator + batch = 
adapter(data) + batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch) + batch_shapes = keras.tree.map_structure(keras.ops.shape, batch) + approximator.build(batch_shapes) + + # Test sampling with the specified method + samples = approximator.sample(num_samples=2, conditions=data, method=method) + + # Verify results + assert isinstance(samples, dict) From 2c161c6d85f675eac9347d036cbc93b76524b323 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:32:20 +0200 Subject: [PATCH 52/61] fix kwargs in set transformer --- bayesflow/networks/transformers/isab.py | 1 + bayesflow/networks/transformers/mab.py | 4 ++-- bayesflow/networks/transformers/set_transformer.py | 10 +++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index 03f15a561..1b763c2b3 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -107,5 +107,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: batch_size = keras.ops.shape(input_set)[0] inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0) inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1]) + print(kwargs) h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs) return self.mab1(input_set, h, training=training, **kwargs) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index eddb8cf09..5bd7c9dff 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs, filter_kwargs +from bayesflow.utils import layer_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) + query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 94690f3ef..7e9da76ea 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -1,7 +1,7 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same, filter_kwargs +from bayesflow.utils import check_lengths_same from bayesflow.utils.serialization import serializable from ..summary_network import SummaryNetwork @@ -147,11 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: out : Tensor Output of shape (batch_size, set_size, output_dim) """ - summary = self.attention_blocks( - input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call) - ) - summary = self.pooling_by_attention( - summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call) - ) + summary = self.attention_blocks(input_set, training=training) + summary = self.pooling_by_attention(summary, training=training) summary = self.output_projector(summary) return summary From 9d4c1a1c605c7e226ea72f97321e1a55c7e718ac Mon Sep 17 
00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:37:38 +0200 Subject: [PATCH 53/61] fix kwargs in set transformer --- bayesflow/networks/transformers/mab.py | 4 ++-- bayesflow/networks/transformers/set_transformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 5bd7c9dff..eddb8cf09 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, filter_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs + query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index 7e9da76ea..bd8290272 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -148,6 +148,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: Output of shape (batch_size, set_size, output_dim) """ summary = self.attention_blocks(input_set, training=training) - summary = self.pooling_by_attention(summary, training=training) + summary = self.pooling_by_attention(summary, training=training, **kwargs) summary = self.output_projector(summary) return summary From ea0659d14962e4b423a42e5bbf53dd79f1797eb9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 16 Sep 2025 15:38:40 +0200 Subject: [PATCH 54/61] remove print --- bayesflow/networks/transformers/isab.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py index 1b763c2b3..03f15a561 100644 --- a/bayesflow/networks/transformers/isab.py +++ b/bayesflow/networks/transformers/isab.py @@ -107,6 +107,5 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: batch_size = keras.ops.shape(input_set)[0] inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0) inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1]) - print(kwargs) h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs) return self.mab1(input_set, h, training=training, **kwargs) From 9220816a1662dc42ddd2014ed4d7098b8230ac6d Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 22 Sep 2025 10:31:49 +0200 Subject: [PATCH 55/61] new class for compositional diffusion --- .../networks/diffusion_model/__init__.py | 1 + .../compositional_diffusion_model.py | 412 ++++++++++++++++++ .../diffusion_model/diffusion_model.py | 315 +------------ 3 files changed, 414 insertions(+), 314 deletions(-) create mode 100644 bayesflow/networks/diffusion_model/compositional_diffusion_model.py diff --git a/bayesflow/networks/diffusion_model/__init__.py b/bayesflow/networks/diffusion_model/__init__.py index 341c84c62..ca8aa19be 100644 --- a/bayesflow/networks/diffusion_model/__init__.py +++ b/bayesflow/networks/diffusion_model/__init__.py @@ -1,4 +1,5 @@ from 
.diffusion_model import DiffusionModel +from .compositional_diffusion_model import CompositionalDiffusionModel from .schedules import CosineNoiseSchedule from .schedules import EDMNoiseSchedule from .schedules import NoiseSchedule diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py new file mode 100644 index 000000000..cde1290ed --- /dev/null +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -0,0 +1,412 @@ +from typing import Literal, Callable + +import keras +import numpy as np +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils import ( + expand_right_as, + integrate, + integrate_stochastic, +) +from bayesflow.utils.serialization import serializable +from diffusion_model import DiffusionModel +from .schedules.noise_schedule import NoiseSchedule + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class CompositionalDiffusionModel(DiffusionModel): + """Compositional Diffusion Model for Amortized Bayesian Inference. Allows to learn a single + diffusion model one single i.i.d simulations that can perform inference for multiple simulations by leveraging a + compositional score function as in [2]. + + [1] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021) + [2] Compositional Score Modeling for Simulation-Based Inference: Geffner et al. (2023) + [3] Compositional amortized inference for large-scale hierarchical Bayesian models: Arruda et al. (2025) + """ + + MLP_DEFAULT_CONFIG = { + "widths": (256, 256, 256, 256, 256), + "activation": "mish", + "kernel_initializer": "he_normal", + "residual": True, + "dropout": 0.0, + "spectral_normalization": False, + } + + INTEGRATE_DEFAULT_CONFIG = { + "method": "euler_maruyama", + "corrector_steps": 1, + "steps": 100, + } + + def __init__( + self, + *, + subnet: str | type | keras.Layer = "mlp", + noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm", + prediction_type: Literal["velocity", "noise", "F", "x"] = "F", + loss_type: Literal["velocity", "noise", "F"] = "noise", + subnet_kwargs: dict[str, any] = None, + schedule_kwargs: dict[str, any] = None, + integrate_kwargs: dict[str, any] = None, + **kwargs, + ): + """ + Initializes a diffusion model with configurable subnet architecture, noise schedule, + and prediction/loss types for amortized Bayesian inference. + + Note, that score-based diffusion is the most sluggish of all available samplers, + so expect slower inference times than flow matching and much slower than normalizing flows. + + Parameters + ---------- + subnet : str, type or keras.Layer, optional + Architecture for the transformation network. Can be "mlp", a custom network class, or + a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp". + noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional + Noise schedule controlling the diffusion dynamics. Can be a string identifier, + a schedule class, or a pre-initialized schedule instance. Default is "edm". + prediction_type : {'velocity', 'noise', 'F', 'x'}, optional + Output format of the model's prediction. Default is "F". + loss_type : {'velocity', 'noise', 'F'}, optional + Loss function used to train the model. Default is "noise". + subnet_kwargs : dict[str, any], optional + Additional keyword arguments passed to the subnet constructor. 
Default is None. + schedule_kwargs : dict[str, any], optional + Additional keyword arguments passed to the noise schedule constructor. Default is None. + integrate_kwargs : dict[str, any], optional + Configuration dictionary for integration during training or inference. Default is None. + concatenate_subnet_input: bool, optional + Flag for advanced users to control whether all inputs to the subnet should be concatenated + into a single vector or passed as separate arguments. If set to False, the subnet + must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio), + and optional 'conditions'. Default is True. + + **kwargs + Additional keyword arguments passed to the base class and internal components. + """ + super().__init__( + subnet=subnet, + noise_schedule=noise_schedule, + prediction_type=prediction_type, + loss_type=loss_type, + subnet_kwargs=subnet_kwargs, + schedule_kwargs=schedule_kwargs, + integrate_kwargs=integrate_kwargs, + **kwargs, + ) + + def compositional_bridge(self, time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. + Otherwise, it can be used to scale the compositional score over time. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. + + """ + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
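# Illustrative sketch, not part of the diff: the compositional bridge defined above,
# exp(-log(d0 / d1) * t), equals (d1 / d0) ** t, so it is identically 1 for the default
# d0 == d1 and otherwise rescales the compositional score smoothly over time.
# NumPy stand-in, assuming only the d0/d1 defaults shown in _inverse_compositional:
import numpy as np

def compositional_bridge(time, d0=1.0, d1=1.0):
    # Same expression as in the patch, written with plain NumPy.
    return np.exp(-np.log(d0 / d1) * time)

t = np.linspace(0.0, 1.0, 5)
print(compositional_bridge(t))                   # all ones when d0 == d1
print(compositional_bridge(t, d0=2.0, d1=1.0))   # (1/2) ** t: 1.0 at t=0, 0.5 at t=1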
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
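# Illustrative sketch, not part of the diff: the velocity assembled in compositional_velocity
# follows the standard score-based relations from [1] — reverse-SDE drift f - g^2 * score,
# probability-flow ODE drift f - 0.5 * g^2 * score. Placeholder arrays stand in for the
# noise-schedule outputs and the learned score:
import numpy as np

def velocity_from_score(f, g_squared, score, stochastic_solver):
    if stochastic_solver:
        return f - g_squared * score        # SDE: dz = [f - g^2 * score] dt + g dW
    return f - 0.5 * g_squared * score      # ODE: dz = [f - 0.5 * g^2 * score] dt

f = np.zeros((4, 2))                 # drift term (placeholder values)
g_squared = np.full((4, 1), 2.0)     # squared diffusion coefficient (placeholder)
score = np.ones((4, 2))              # compositional score (placeholder)
print(velocity_from_score(f, g_squared, score, stochastic_solver=False))  # -1.0 everywhere
print(velocity_from_score(f, g_squared, score, stochastic_solver=True))   # -2.0 everywhere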
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_compositional = ops.shape(conditions)[1] + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + # Compute individual dataset scores + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + conditions_batch = conditions[:, mini_batch_idx] + else: + conditions_batch = conditions + individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) + + # Compute prior score component + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + + # Sum individual scores across compositional dimensions + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) + + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) + return compositional_score + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...) 
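# Illustrative sketch, not part of the diff: the combination used in compositional_score can be
# checked in isolation. With n conditions, the prior term is weighted by (1 - n)(1 - t), and
# n * mean(individual, axis=1) recovers the sum over conditions (which keeps the mini-batched
# estimate unbiased). All arrays below are made-up placeholders:
import numpy as np

def combine_scores(individual_scores, prior_score, time, bridge=1.0):
    # individual_scores: (n_datasets, n_compositional, num_samples, dims)
    # prior_score:       (n_datasets, num_samples, dims)
    n_comp = individual_scores.shape[1]
    weighted_prior = (1.0 - n_comp) * (1.0 - time) * prior_score
    summed_individual = n_comp * individual_scores.mean(axis=1)
    return bridge * (weighted_prior + summed_individual)

rng = np.random.default_rng(0)
individual = rng.normal(size=(1, 3, 5, 2))   # 3 compositional conditions
prior = rng.normal(size=(1, 5, 2))
out = combine_scores(individual, prior, time=0.5)
expected = (1 - 3) * (1 - 0.5) * prior + individual.sum(axis=1)
print(np.allclose(out, expected))            # True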
+ """ + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) + + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) + + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) + + # Reshape back to compositional structure + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion sampling. + """ + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + else: + mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + integrate_kwargs.pop("corrector_steps", None) + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index 99b712203..9955c4abc 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -1,9 +1,8 @@ from collections.abc import Sequence -from typing import Literal, Callable +from typing import Literal import keras from keras import ops -import numpy as np from ..inference_network import InferenceNetwork from bayesflow.types import Tensor, Shape @@ -600,315 +599,3 @@ def compute_metrics( base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage) return base_metrics | {"loss": loss} - - def compositional_bridge(self, time: Tensor) -> Tensor: - """ - Bridge function for compositional diffusion. 
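# Illustrative sketch, not part of the diff: the latent rescaling in _inverse_compositional above
# only changes the variance of the Gaussian starting noise from 1 to
# 1 / (n_compositional * bridge(1)), in line with the "var 1/n_compositional" comment there.
# Standalone NumPy check with example numbers:
import numpy as np

rng = np.random.default_rng(1)
n_compositional = 25
bridge_at_one = 1.0                              # compositional_bridge(1.0) for the default d0 == d1

z = rng.standard_normal((4, 1000, 2))            # unit-variance base samples
z_scaled = z / np.sqrt(n_compositional * bridge_at_one)
print(np.var(z_scaled))                          # close to 1 / 25 = 0.04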
In the simplest case, this is just 1 if d0 == d1. - Otherwise, it can be used to scale the compositional score over time. - - Parameters - ---------- - time: Tensor - Time step for the diffusion process. - - Returns - ------- - Tensor - Bridge function value with same shape as time. - - """ - return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) - - def compositional_velocity( - self, - xz: Tensor, - time: float | Tensor, - stochastic_solver: bool, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - mini_batch_size: int | None = None, - training: bool = False, - ) -> Tensor: - """ - Computes the compositional velocity for multiple datasets using the formula: - s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - - Parameters - ---------- - xz : Tensor - The current state of the latent variable, shape (n_datasets, n_compositional, ...) - time : float or Tensor - Time step for the diffusion process - stochastic_solver : bool - Whether to use stochastic (SDE) or deterministic (ODE) formulation - conditions : Tensor - Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - compute_prior_score: Callable - Function to compute the prior score ∇_θ log p(θ). - mini_batch_size : int or None - Mini batch size for computing individual scores. If None, use all conditions. - training : bool, optional - Whether in training mode - - Returns - ------- - Tensor - Compositional velocity of same shape as input xz - """ - # Calculate standard noise schedule components - log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - - compositional_score = self.compositional_score( - xz=xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - - # Compute velocity using standard drift-diffusion formulation - f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) - - if stochastic_solver: - # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW - velocity = f - g_squared * compositional_score - else: - # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt - velocity = f - 0.5 * g_squared * compositional_score - - return velocity - - def compositional_score( - self, - xz: Tensor, - time: float | Tensor, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - mini_batch_size: int | None = None, - training: bool = False, - ) -> Tensor: - """ - Computes the compositional score for multiple datasets using the formula: - s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - - Parameters - ---------- - xz : Tensor - The current state of the latent variable, shape (n_datasets, n_compositional, ...) - time : float or Tensor - Time step for the diffusion process - conditions : Tensor - Conditional inputs with compositional structure (n_datasets, n_compositional, ...) - compute_prior_score: Callable - Function to compute the prior score ∇_θ log p(θ). - mini_batch_size : int or None - Mini batch size for computing individual scores. If None, use all conditions. 
- training : bool, optional - Whether in training mode - - Returns - ------- - Tensor - Compositional velocity of same shape as input xz - """ - if conditions is None: - raise ValueError("Conditions are required for compositional sampling") - - # Get shapes for compositional structure - n_compositional = ops.shape(conditions)[1] - - # Calculate standard noise schedule components - log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) - log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - - # Compute individual dataset scores - if mini_batch_size is not None and mini_batch_size < n_compositional: - # sample random indices for mini-batch processing - mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) - mini_batch_idx = mini_batch_idx[:mini_batch_size] - conditions_batch = conditions[:, mini_batch_idx] - else: - conditions_batch = conditions - individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) - - # Compute prior score component - prior_score = compute_prior_score(xz) - weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score - - # Sum individual scores across compositional dimensions - summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) - - # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) - time_tensor = ops.cast(time, dtype=ops.dtype(xz)) - compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) - return compositional_score - - def _compute_individual_scores( - self, - xz: Tensor, - log_snr_t: Tensor, - conditions: Tensor, - training: bool, - ) -> Tensor: - """ - Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. - - Returns - ------- - Tensor - Individual scores with shape (n_datasets, n_compositional, ...) 
- """ - # Get shapes - xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) - conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) - n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] - conditions_dims = tuple(conditions_shape[3:]) - num_samples = xz_shape[1] - dims = tuple(xz_shape[2:]) - - # Expand xz to match compositional structure - xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) - xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) - - # Expand log_snr_t to match compositional structure - log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) - log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) - - # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) - xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) - log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) - conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) - - # Use standard score function - scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) - - # Reshape back to compositional structure - scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) - return scores - - def _inverse_compositional( - self, - z: Tensor, - conditions: Tensor, - compute_prior_score: Callable[[Tensor], Tensor], - density: bool = False, - training: bool = False, - **kwargs, - ) -> Tensor | tuple[Tensor, Tensor]: - """ - Inverse pass for compositional diffusion sampling. - """ - n_compositional = ops.shape(conditions)[1] - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} - integrate_kwargs = integrate_kwargs | self.integrate_kwargs - integrate_kwargs = integrate_kwargs | kwargs - if keras.backend.backend() == "jax": - mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) - if mini_batch_size is not None: - raise ValueError( - "Mini batching is not supported with JAX backend. Set mini_batch_size to None " - "or use another backend." 
- ) - else: - mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) - self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) - self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) - - # x is sampled from a normal distribution, must be scaled with var 1/n_compositional - scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) - z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) - - if density: - if integrate_kwargs["method"] == "euler_maruyama": - raise ValueError("Stochastic methods are not supported for density computation.") - - def deltas(time, xz): - v = self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) - return {"xz": v, "trace": trace} - - state = { - "xz": z, - "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), - } - state = integrate(deltas, state, **integrate_kwargs) - - x = state["xz"] - log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) - return x, log_density - - state = {"xz": z} - - if integrate_kwargs["method"] == "euler_maruyama": - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=True, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - def diffusion(time, xz): - return {"xz": self.diffusion_term(xz, time=time, training=training)} - - score_fn = None - if "corrector_steps" in integrate_kwargs: - if integrate_kwargs["corrector_steps"] > 0: - - def score_fn(time, xz): - return { - "xz": self.compositional_score( - xz, - time=time, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = integrate_stochastic( - drift_fn=deltas, - diffusion_fn=diffusion, - score_fn=score_fn, - noise_schedule=self.noise_schedule, - state=state, - seed=self.seed_generator, - **integrate_kwargs, - ) - else: - integrate_kwargs.pop("corrector_steps", None) - - def deltas(time, xz): - return { - "xz": self.compositional_velocity( - xz, - time=time, - stochastic_solver=False, - conditions=conditions, - compute_prior_score=compute_prior_score, - mini_batch_size=mini_batch_size, - training=training, - ) - } - - state = integrate(deltas, state, **integrate_kwargs) - - x = state["xz"] - return x From ee1c3209a429f0c03e1fa698ddeb6f9bf5c5fe5a Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 22 Sep 2025 10:34:16 +0200 Subject: [PATCH 56/61] fix import --- .../networks/diffusion_model/compositional_diffusion_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index cde1290ed..8095aba7d 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -11,7 +11,7 @@ integrate_stochastic, ) from bayesflow.utils.serialization import serializable -from diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel from .schedules.noise_schedule import NoiseSchedule @@ -299,7 +299,7 @@ def 
_inverse_compositional( Inverse pass for compositional diffusion sampling. """ n_compositional = ops.shape(conditions)[1] - integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1} + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} integrate_kwargs = integrate_kwargs | self.integrate_kwargs integrate_kwargs = integrate_kwargs | kwargs if keras.backend.backend() == "jax": From e6513c1992fca355cdd0ebbaeecfda9db0710731 Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 15:41:40 +0200 Subject: [PATCH 57/61] add import --- bayesflow/networks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index f71d4b536..fb9819445 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -7,7 +7,7 @@ from .consistency_models import ConsistencyModel from .coupling_flow import CouplingFlow from .deep_set import DeepSet -from .diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel, CompositionalDiffusionModel from .flow_matching import FlowMatching from .inference_network import InferenceNetwork from .point_inference_network import PointInferenceNetwork From e87f9d153dac4b9a1fdfe944b7417a768967c66b Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 18:13:29 +0200 Subject: [PATCH 58/61] fix mini_batch_size --- .../networks/diffusion_model/compositional_diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index 8095aba7d..abd7d49a9 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -310,7 +310,7 @@ def _inverse_compositional( "or use another backend." ) else: - mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)) + mini_batch_size = min(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) From 983cb8d399097dd9d9451b02b830c21a16febc8b Mon Sep 17 00:00:00 2001 From: arrjon Date: Fri, 26 Sep 2025 18:16:02 +0200 Subject: [PATCH 59/61] fix mini_batch_size --- .../networks/diffusion_model/compositional_diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py index abd7d49a9..171184314 100644 --- a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -310,7 +310,7 @@ def _inverse_compositional( "or use another backend." 
) else: - mini_batch_size = min(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) + mini_batch_size = max(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) From 3b887dc8ff8e3a7c8970a8e7a5445db73bc49aa3 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 27 Oct 2025 11:20:26 +0100 Subject: [PATCH 60/61] fix scm --- .../stable_consistency_model/stable_consistency_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index 6ce27fdf4..dc092ab4e 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -307,7 +307,7 @@ def compute_metrics( r = 1.0 # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here def f_teacher(x, t): - o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training") return self.subnet_projector(o) primals = (xt / self.sigma, t) @@ -321,7 +321,7 @@ def f_teacher(x, t): cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt) # calculate output of the network - subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training") student_out = self.subnet_projector(subnet_out) # calculate the tangent From 64516a464ba828a6dd45db7f03f5e4cf4c7cb79a Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 29 Oct 2025 17:31:44 +0100 Subject: [PATCH 61/61] fix saving --- .../stable_consistency_model/stable_consistency_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index dc092ab4e..0f787d44f 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -105,7 +105,6 @@ def __init__( ) embedding_kwargs = embedding_kwargs or {} - self.embedding_kwargs = embedding_kwargs self.time_emb = FourierEmbedding(**embedding_kwargs) self.time_emb_dim = self.time_emb.embed_dim @@ -123,13 +122,14 @@ def get_config(self): config = { "subnet": self.subnet, "sigma": self.sigma, - "embedding_kwargs": self.embedding_kwargs, + "time_emb": self.time_emb, "concatenate_subnet_input": self._concatenate_subnet_input, } return base_config | serialize(config) - def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs): + @staticmethod + def _discretize_time(num_steps: int, rho: float = 3.5): t = keras.ops.linspace(0.0, pi / 2, num_steps) times = keras.ops.exp((t - pi / 2) * rho) * pi / 2 times = keras.ops.concatenate([keras.ops.zeros((1,)), times[1:]], axis=0)
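# Illustrative sketch, not part of the diff: the static _discretize_time above maps a uniform grid
# on [0, pi/2] through exp((t - pi/2) * rho) * pi/2, which clusters steps near zero, ends exactly
# at pi/2, and pins the first entry to 0. NumPy mirror of the same arithmetic:
import numpy as np

def discretize_time(num_steps, rho=3.5):
    t = np.linspace(0.0, np.pi / 2, num_steps)
    times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
    return np.concatenate([np.zeros(1), times[1:]])

grid = discretize_time(6)
print(grid)                              # starts at 0.0, ends at pi/2 ~ 1.5708
print(bool(np.all(np.diff(grid) > 0)))   # True: strictly increasing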