Commit bec2d8e

Fix: Add _skip_keys for AutoencoderKLWan (#12523)
1 parent a0a51eb commit bec2d8e

3 files changed: +21 −13 lines

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -453,14 +453,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
 
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x
 
@@ -494,9 +494,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x + self.avg_shortcut(x_copy)
 
@@ -598,12 +598,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
 
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
 
@@ -767,13 +767,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         x = self.conv_in(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
     """
 
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDevicesHook moves inputs/outputs between devices;
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]
 
     @register_to_config
     def __init__(
```
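The calls above switch feat_cache and feat_idx from positional to keyword form for a reason: accelerate's device-alignment hooks match skip keys against keyword arguments by name, so a cache passed positionally would still be moved between devices. A minimal sketch of the idea (illustrative only, not accelerate's actual AlignDevicesHook, and the helper name is made up):

```python
import torch


def align_inputs(args, kwargs, device, skip_keys=()):
    # Positional tensors are always moved to the module's execution device.
    args = tuple(a.to(device) if torch.is_tensor(a) else a for a in args)
    # Keyword arguments are moved too, unless their name is in skip_keys.
    # Skipped entries keep their object identity, so in-place mutation of the
    # shared feat_cache list and feat_idx counter stays visible to the caller.
    kwargs = {
        k: v if k in skip_keys else (v.to(device) if torch.is_tensor(v) else v)
        for k, v in kwargs.items()
    }
    return args, kwargs
```

With skip_keys=["feat_cache", "feat_idx"], the cache passed as a keyword argument comes back untouched; passed positionally, it would still be moved or copied, breaking the in-place bookkeeping.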

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _repeated_blocks = []
     _parallel_config = None
     _cp_plan = None
+    _skip_keys = None
 
     def __init__(self):
        super().__init__()
```
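The default lives on ModelMixin so every model inherits it; a subclass only declares the attribute when some of its forward() kwargs must be exempt from device alignment. A hypothetical subclass (class and parameter names are illustrative) could look like this:

```python
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class MyChunkedModel(ModelMixin, ConfigMixin):
    # These kwargs are shared mutable state, written in place across chunked
    # forward passes; device-alignment hooks must not move or copy them.
    _skip_keys = ["feat_cache", "feat_idx"]

    @register_to_config
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            feat_cache[feat_idx[0]] = x  # cache slot updated in place
            feat_idx[0] += 1
        return self.proj(x)
```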

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -866,6 +866,9 @@ def load_sub_model(
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"
+        skip_keys = None
+        if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+            skip_keys = loaded_sub_model._skip_keys
 
         if needs_offloading_to_cpu:
             dispatch_model(
@@ -874,9 +877,10 @@
                 device_map=device_map,
                 force_hooks=True,
                 main_device=0,
+                skip_keys=skip_keys,
             )
         else:
-            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
 
     return loaded_sub_model
```
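accelerate's dispatch_model already accepts a skip_keys argument and threads it into the hooks it attaches, so the pipeline loader only needs to forward the model's declaration. A hedged end-to-end sketch (the checkpoint id is an assumption, not part of this commit):

```python
import torch

from diffusers import WanPipeline

# Loading with a device_map routes each sub-model through load_sub_model,
# so the Wan VAE is dispatched with skip_keys=["feat_cache", "feat_idx"].
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)
print(pipe.vae._skip_keys)  # -> ["feat_cache", "feat_idx"]
```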
