@@ -453,14 +453,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
 
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x
 
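Note on this hunk (and the matching call-site edits below): as far as accelerate's hook machinery goes, the `skip_keys` filter is matched against the *keyword* arguments a module receives, so `feat_cache` and `feat_idx` can only be exempted from device movement when they are passed by name. The `_skip_keys` attribute that relies on this is added in the final hunk.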
@@ -494,9 +494,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         return x + self.avg_shortcut(x_copy)
 
@@ -598,12 +598,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
 
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
 
@@ -767,13 +767,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)
 
         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         x = self.conv_in(x)
 
         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
     """
 
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDevicesHook moves inputs/outputs between devices;
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]
 
     @register_to_config
     def __init__(
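To make the intent of `_skip_keys` concrete, here is a minimal sketch of the assumed hook behavior, not code from this PR: accelerate's `AlignDevicesHook` forwards a model's `_skip_keys` as `skip_keys` to `accelerate.utils.send_to_device`, which moves every tensor in the kwargs mapping *except* the named keys, passing the skipped values through as the same Python objects so in-place mutation still hits the shared cache.

```python
# Minimal sketch of the assumed behavior (uses accelerate.utils.send_to_device,
# the helper AlignDevicesHook relies on); not code from this PR.
import torch
from accelerate.utils import send_to_device

kwargs = {
    "x": torch.randn(1, 4),             # ordinary input: moved to the execution device
    "feat_cache": [torch.randn(1, 4)],  # rolling cache: mutated in-place by the VAE blocks
    "feat_idx": [0],                    # shared index into feat_cache
}

moved = send_to_device(kwargs, "cpu", skip_keys=["feat_cache", "feat_idx"])

# Skipped keys pass through untouched: caller and submodule keep mutating
# the *same* list, which is what the chunked encode/decode relies on.
assert moved["feat_cache"] is kwargs["feat_cache"]
assert moved["feat_idx"] is kwargs["feat_idx"]
```

Because `skip_keys` matches dictionary keys, the protection only kicks in for keyword arguments, which is what the call-site edits earlier in the diff arrange.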