Skip to content

Commit a9cb08a

Browse files
authored
fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled (#12562)
* fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled
* address review comment
* refine

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 9f669e7 commit a9cb08a

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/diffusers/hooks/context_parallel.py

Lines changed: 5 additions & 3 deletions

@@ -203,10 +203,12 @@ def post_forward(self, module, output):

     def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
-            raise ValueError(
-                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+            logger.warning_once(
+                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
             )
-        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+            return x
+        else:
+            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


 class ContextParallelGatherHook(ModelHook):

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 3 additions & 0 deletions

@@ -555,6 +555,9 @@ class WanTransformer3DModel(
             "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        "": {
+            "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+        },
     }

     @register_to_config

Comments (0)