@@ -85,6 +85,12 @@ def __init__(
8585 Additional keyword arguments passed to the noise schedule constructor. Default is None.
8686 integrate_kwargs : dict[str, any], optional
8787 Configuration dictionary for integration during training or inference. Default is None.
88+ concatenate_subnet_input: bool, optional
89+ Flag for advanced users to control whether all inputs to the subnet should be concatenated
90+ into a single vector or passed as separate arguments. If set to False, the subnet
91+ must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
92+ and optional 'conditions'. Default is True.
93+
8894 **kwargs
8995 Additional keyword arguments passed to the base class and internal components.
9096 """
@@ -116,6 +122,7 @@ def __init__(
116122 if subnet == "mlp" :
117123 subnet_kwargs = DiffusionModel .MLP_DEFAULT_CONFIG | subnet_kwargs
118124 self .subnet = find_network (subnet , ** subnet_kwargs )
125+ self ._concatenate_subnet_input = kwargs .get ("concatenate_subnet_input" , True )
119126
120127 self .output_projector = keras .layers .Dense (units = None , bias_initializer = "zeros" , name = "output_projector" )
121128
@@ -128,15 +135,23 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
128135 self .output_projector .units = xz_shape [- 1 ]
129136 input_shape = list (xz_shape )
130137
131- # construct time vector
132- input_shape [- 1 ] += 1
133- if conditions_shape is not None :
134- input_shape [- 1 ] += conditions_shape [- 1 ]
138+ if self ._concatenate_subnet_input :
139+ # construct time vector
140+ input_shape [- 1 ] += 1
141+ if conditions_shape is not None :
142+ input_shape [- 1 ] += conditions_shape [- 1 ]
143+ input_shape = tuple (input_shape )
135144
136- input_shape = tuple (input_shape )
145+ self .subnet .build (input_shape )
146+ out_shape = self .subnet .compute_output_shape (input_shape )
147+ else :
148+ # Multiple separate inputs
149+ time_shape = tuple (xz_shape [:- 1 ]) + (1 ,) # same batch/sequence dims, 1 feature
150+ self .subnet .build (x_shape = xz_shape , t_shape = time_shape , conditions_shape = conditions_shape )
151+ out_shape = self .subnet .compute_output_shape (
152+ x_shape = xz_shape , t_shape = time_shape , conditions_shape = conditions_shape
153+ )
137154
138- self .subnet .build (input_shape )
139- out_shape = self .subnet .compute_output_shape (input_shape )
140155 self .output_projector .build (out_shape )
141156
142157 def get_config (self ):
@@ -149,6 +164,8 @@ def get_config(self):
149164 "prediction_type" : self ._prediction_type ,
150165 "loss_type" : self ._loss_type ,
151166 "integrate_kwargs" : self .integrate_kwargs ,
167+ "concatenate_subnet_input" : self ._concatenate_subnet_input ,
168+ # we do not need to store subnet_kwargs
152169 }
153170 return base_config | serialize (config )
154171
@@ -197,6 +214,35 @@ def convert_prediction_to_x(
197214 return (z + sigma_t ** 2 * pred ) / alpha_t
198215 raise ValueError (f"Unknown prediction type { self ._prediction_type } ." )
199216
217+ def _apply_subnet (
218+ self , xz : Tensor , log_snr : Tensor , conditions : Tensor = None , training : bool = False
219+ ) -> Tensor :
220+ """
221+ Prepares and passes the input to the subnet either by concatenating the latent variable `xz`,
222+ the signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
223+
224+ Parameters
225+ ----------
226+ xz : Tensor
227+ The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
228+ log_snr : Tensor
229+ The log signal-to-noise ratio tensor, typically of shape (..., 1).
230+ conditions : Tensor, optional
231+ The optional conditioning tensor (e.g. parameters).
232+ training : bool, optional
233+ The training mode flag, which can be used to control behavior during training.
234+
235+ Returns
236+ -------
237+ Tensor
238+ The output tensor from the subnet.
239+ """
240+ if self ._concatenate_subnet_input :
241+ xtc = tensor_utils .concatenate_valid ([xz , log_snr , conditions ], axis = - 1 )
242+ return self .subnet (xtc , training = training )
243+ else :
244+ return self .subnet (x = xz , t = log_snr , conditions = conditions , training = training )
245+
200246 def velocity (
201247 self ,
202248 xz : Tensor ,
@@ -221,7 +267,7 @@ def velocity(
221267 If True, computes the velocity for the stochastic formulation (SDE).
222268 If False, uses the deterministic formulation (ODE).
223269 conditions : Tensor, optional
224- Optional conditional inputs to the network, such as conditioning variables
270+ Conditional inputs to the network, such as conditioning variables
225271 or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
226272 training : bool, optional
227273 Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +284,10 @@ def velocity(
238284 log_snr_t = ops .broadcast_to (log_snr_t , ops .shape (xz )[:- 1 ] + (1 ,))
239285 alpha_t , sigma_t = self .noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
240286
241- if conditions is None :
242- xtc = tensor_utils .concatenate_valid ([xz , self ._transform_log_snr (log_snr_t )], axis = - 1 )
243- else :
244- xtc = tensor_utils .concatenate_valid ([xz , self ._transform_log_snr (log_snr_t ), conditions ], axis = - 1 )
245-
246- pred = self .output_projector (self .subnet (xtc , training = training ), training = training )
287+ subnet_out = self ._apply_subnet (
288+ xz , self ._transform_log_snr (log_snr_t ), conditions = conditions , training = training
289+ )
290+ pred = self .output_projector (subnet_out , training = training )
247291
248292 x_pred = self .convert_prediction_to_x (pred = pred , z = xz , alpha_t = alpha_t , sigma_t = sigma_t , log_snr_t = log_snr_t )
249293
@@ -461,11 +505,10 @@ def compute_metrics(
461505 diffused_x = alpha_t * x + sigma_t * eps_t
462506
463507 # calculate output of the network
464- if conditions is None :
465- xtc = tensor_utils .concatenate_valid ([diffused_x , self ._transform_log_snr (log_snr_t )], axis = - 1 )
466- else :
467- xtc = tensor_utils .concatenate_valid ([diffused_x , self ._transform_log_snr (log_snr_t ), conditions ], axis = - 1 )
468- pred = self .output_projector (self .subnet (xtc , training = training ), training = training )
508+ subnet_out = self ._apply_subnet (
509+ diffused_x , self ._transform_log_snr (log_snr_t ), conditions = conditions , training = training
510+ )
511+ pred = self .output_projector (subnet_out , training = training )
469512
470513 x_pred = self .convert_prediction_to_x (
471514 pred = pred , z = diffused_x , alpha_t = alpha_t , sigma_t = sigma_t , log_snr_t = log_snr_t