@@ -166,6 +166,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
166166 def prepare_inputs (self , data : "BlockState" ) -> List ["BlockState" ]:
167167 raise NotImplementedError ("BaseGuidance::prepare_inputs must be implemented in subclasses." )
168168
169+ def prepare_inputs_from_block_state (
170+ self , data : "BlockState" , input_fields : Dict [str , Union [str , Tuple [str , str ]]]
171+ ) -> List ["BlockState" ]:
172+ raise NotImplementedError ("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses." )
173+
169174 def __call__ (self , data : List ["BlockState" ]) -> Any :
170175 if not all (hasattr (d , "noise_pred" ) for d in data ):
171176 raise ValueError ("Expected all data to have `noise_pred` attribute." )
@@ -234,6 +239,51 @@ def _prepare_batch(
234239 data_batch [cls ._identifier_key ] = identifier
235240 return BlockState (** data_batch )
236241
242+ @classmethod
243+ def _prepare_batch_from_block_state (
244+ cls ,
245+ input_fields : Dict [str , Union [str , Tuple [str , str ]]],
246+ data : "BlockState" ,
247+ tuple_index : int ,
248+ identifier : str ,
249+ ) -> "BlockState" :
250+ """
251+ Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
252+ `BaseGuidance` class. It prepares the batch based on the provided tuple index.
253+
254+ Args:
255+ input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
256+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
257+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
258+ to look up the required data provided for preparation. If a string is provided, it will be used as the
259+ conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
260+ length 2 is provided, the first element must be the conditional data identifier and the second element
261+ must be the unconditional data identifier or None.
262+ data (`BlockState`):
263+ The input data to be prepared.
264+ tuple_index (`int`):
265+ The index to use when accessing input fields that are tuples.
266+
267+ Returns:
268+ `BlockState`: The prepared batch of data.
269+ """
270+ from ..modular_pipelines .modular_pipeline import BlockState
271+
272+ data_batch = {}
273+ for key , value in input_fields .items ():
274+ try :
275+ if isinstance (value , str ):
276+ data_batch [key ] = getattr (data , value )
277+ elif isinstance (value , tuple ):
278+ data_batch [key ] = getattr (data , value [tuple_index ])
279+ else :
280+ # We've already checked that value is a string or a tuple of strings with length 2
281+ pass
282+ except AttributeError :
283+ logger .debug (f"`data` does not have attribute(s) { value } , skipping." )
284+ data_batch [cls ._identifier_key ] = identifier
285+ return BlockState (** data_batch )
286+
237287 @classmethod
238288 @validate_hf_hub_args
239289 def from_pretrained (
0 commit comments