
Commit b455dc9

[modular] wan! (#12611)
* update, remove intermediate_inputs
* support image2video
* revert dynamic steps to simplify
* refactor vae encoder block
* support flf2video!
* add support for wan2.2 14B
* style
* Apply suggestions from code review
* input dynamic step -> additional input step
* up
* fix init
* update dtype
1 parent 04f9d2b commit b455dc9

23 files changed: 1996 additions, 363 deletions
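
The headline change is modular Wan support: Wan22AutoBlocks joins the existing WanAutoBlocks and WanModularPipeline exports, and the new blocks cover text-to-video, image-to-video, and first-last-frame-to-video paths, including the Wan 2.2 14B models. As a rough orientation, the sketch below shows how the auto blocks might be assembled into a runnable pipeline; it assumes the Wan presets follow the same workflow as the other modular presets, and the init_pipeline / load_default_components calls, the repo id, and the "videos" output name are assumptions, not taken from this diff.

import torch
from diffusers import WanAutoBlocks  # Wan22AutoBlocks and WanModularPipeline are exported alongside it

# Assemble the preset Wan blocks into a pipeline (hedged sketch; the method names
# and the repo id below follow the general modular-pipeline workflow and are assumptions).
blocks = WanAutoBlocks()  # or Wan22AutoBlocks() for the Wan 2.2 14B variants
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # hypothetical repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# The "videos" output name is likewise an assumption based on the Wan naming.
video = pipe(prompt="a corgi surfing a wave at sunset", num_inference_steps=30, output="videos")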

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -407,6 +407,7 @@
         "QwenImageModularPipeline",
         "StableDiffusionXLAutoBlocks",
         "StableDiffusionXLModularPipeline",
+        "Wan22AutoBlocks",
         "WanAutoBlocks",
         "WanModularPipeline",
     ]
@@ -1090,6 +1091,7 @@
         QwenImageModularPipeline,
         StableDiffusionXLAutoBlocks,
         StableDiffusionXLModularPipeline,
+        Wan22AutoBlocks,
         WanAutoBlocks,
         WanModularPipeline,
     )

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 14 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

@@ -88,6 +88,19 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        if self._step == 0:
+            if self.adaptive_projected_guidance_momentum is not None:
+                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
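
Both adaptive projected guidance variants (this file and adaptive_projected_guidance_mix.py below) differ from the other guiders only in the step-0 branch: the momentum buffer is re-created at the start of every run so that the running average of the guidance direction never leaks across sampling runs. A rough paraphrase of what that buffer does is sketched below; the update rule follows the APG formulation and is not copied from this diff.

import torch

class MomentumBufferSketch:
    # Illustrative stand-in for the MomentumBuffer re-created at step 0 above.
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, update_value: torch.Tensor) -> None:
        # Accumulate the guidance direction across steps; resetting at step 0
        # keeps state from a previous sampling run out of the current one.
        self.running_average = update_value + self.momentum * self.running_average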

src/diffusers/guiders/adaptive_projected_guidance_mix.py

Lines changed: 14 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

@@ -99,6 +99,19 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        if self._step == 0:
+            if self.adaptive_projected_guidance_momentum is not None:
+                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None

src/diffusers/guiders/auto_guidance.py

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 11 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

@@ -99,6 +99,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None
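
Classifier-free guidance prepares a single batch when only the conditional branch is needed (num_conditions == 1) and a conditional plus unconditional pair otherwise. The two batches it tags as pred_cond and pred_uncond are later blended in forward; the snippet below restates the standard CFG rule for context and is not the exact body of forward in this file.

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance: push the unconditional prediction
    # toward the conditional one by the guidance scale.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

noise_cond = torch.randn(1, 16, 9, 60, 104)  # shapes are illustrative only
noise_uncond = torch.randn_like(noise_cond)
noise_pred = cfg_combine(noise_cond, noise_uncond, guidance_scale=5.0)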

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 11 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

@@ -85,6 +85,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 10 additions & 0 deletions
@@ -226,6 +226,16 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
         pred = None

src/diffusers/guiders/guider_utils.py

Lines changed: 50 additions & 0 deletions
@@ -166,6 +166,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
     def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
         raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
+
     def __call__(self, data: List["BlockState"]) -> Any:
         if not all(hasattr(d, "noise_pred") for d in data):
             raise ValueError("Expected all data to have `noise_pred` attribute.")
@@ -234,6 +239,51 @@ def _prepare_batch(
         data_batch[cls._identifier_key] = identifier
         return BlockState(**data_batch)

+    @classmethod
+    def _prepare_batch_from_block_state(
+        cls,
+        input_fields: Dict[str, Union[str, Tuple[str, str]]],
+        data: "BlockState",
+        tuple_index: int,
+        identifier: str,
+    ) -> "BlockState":
+        """
+        Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
+        `BaseGuidance` class. It prepares the batch based on the provided tuple index.
+
+        Args:
+            input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
+                A dictionary where the keys are the names of the fields that will be used to store the data once it is
+                prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
+                to look up the required data provided for preparation. If a string is provided, it will be used as the
+                conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
+                length 2 is provided, the first element must be the conditional data identifier and the second element
+                must be the unconditional data identifier or None.
+            data (`BlockState`):
+                The input data to be prepared.
+            tuple_index (`int`):
+                The index to use when accessing input fields that are tuples.
+
+        Returns:
+            `BlockState`: The prepared batch of data.
+        """
+        from ..modular_pipelines.modular_pipeline import BlockState
+
+        data_batch = {}
+        for key, value in input_fields.items():
+            try:
+                if isinstance(value, str):
+                    data_batch[key] = getattr(data, value)
+                elif isinstance(value, tuple):
+                    data_batch[key] = getattr(data, value[tuple_index])
+                else:
+                    # We've already checked that value is a string or a tuple of strings with length 2
+                    pass
+            except AttributeError:
+                logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
+        data_batch[cls._identifier_key] = identifier
+        return BlockState(**data_batch)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(
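
The new `_prepare_batch_from_block_state` helper is what every `prepare_inputs_from_block_state` method above delegates to: each key in `input_fields` becomes an attribute on the prepared batch, and its value names the `BlockState` attribute to read, either a single name shared by every batch or a (conditional, unconditional) pair selected by `tuple_index`. The self-contained sketch below paraphrases that mapping; the field names (prompt_embeds, negative_prompt_embeds, latents) are placeholders chosen for illustration, not taken from this diff.

from types import SimpleNamespace
from typing import Dict, Tuple, Union

# Stand-in for a BlockState: plain attribute access (the real class lives in
# diffusers.modular_pipelines.modular_pipeline).
block_state = SimpleNamespace(
    prompt_embeds="cond_embeds",
    negative_prompt_embeds="uncond_embeds",
    latents="shared_latents",
)

# Hypothetical input_fields for a CFG-style guider.
input_fields: Dict[str, Union[str, Tuple[str, str]]] = {
    "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),  # (conditional, unconditional)
    "latents": "latents",  # shared by every batch
}

def prepare_batch(data, fields, tuple_index):
    # Paraphrase of BaseGuidance._prepare_batch_from_block_state: strings are read
    # as-is, tuples are indexed by the batch currently being built.
    return {
        key: getattr(data, value if isinstance(value, str) else value[tuple_index])
        for key, value in fields.items()
    }

print(prepare_batch(block_state, input_fields, 0))  # conditional batch
print(prepare_batch(block_state, input_fields, 1))  # unconditional batch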

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 20 additions & 0 deletions
@@ -187,6 +187,26 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
     def forward(
         self,

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 20 additions & 0 deletions
@@ -183,6 +183,26 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
             data_batches.append(data_batch)
         return data_batches

+    def prepare_inputs_from_block_state(
+        self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+    ) -> List["BlockState"]:
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+            data_batches.append(data_batch)
+        return data_batches
+
     def forward(
         self,
         pred_cond: torch.Tensor,
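
PerturbedAttentionGuidance and SkipLayerGuidance share the same three-way variant of prepare_inputs_from_block_state: with all three conditions active they build batches tagged pred_cond, pred_uncond, and pred_cond_skip, and the tuple_indices pattern [0, 1, 0] means the skip/perturbed pass reuses the conditional inputs rather than the unconditional ones. The toy loop below just mirrors the zip() from the diff to make that pairing explicit.

# Mirrors the num_conditions == 3 branch above: the third batch ("pred_cond_skip")
# is built from the conditional entries of input_fields (tuple index 0), since only
# the layer-skipping / attention perturbation differs for that pass.
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

for tuple_idx, identifier in zip(tuple_indices, input_predictions):
    source = "conditional" if tuple_idx == 0 else "unconditional"
    print(f"{identifier}: built from the {source} entries of input_fields")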
