File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -1257,7 +1257,7 @@ def forward_intermediates(
12571257 if attn_mask is not None :
12581258 x = blk (x , attn_mask = attn_mask )
12591259 elif self .grad_checkpointing and not torch .jit .is_scripting ():
1260- x = checkpoint (blk . x )
1260+ x = checkpoint (blk , x )
12611261 else :
12621262 x = blk (x )
12631263 if i in take_indices :
Original file line number Diff line number Diff line change 1717from timm .layers import ClassifierHead
1818from ._builder import build_model_with_cfg
1919from ._features import feature_take_indices
20- from ._manipulate import checkpoint_seq
20+ from ._manipulate import checkpoint , checkpoint_seq
2121from ._registry import generate_default_cfgs , register_model
2222
2323__all__ = ['NextViT' ]
@@ -595,7 +595,7 @@ def forward_intermediates(
595595
596596 for feat_idx , stage in enumerate (stages ):
597597 if self .grad_checkpointing and not torch .jit .is_scripting ():
598- x = checkpoint_seq (stage , x )
598+ x = checkpoint (stage , x )
599599 else :
600600 x = stage (x )
601601 if feat_idx in take_indices :
You can’t perform that action at this time.
0 commit comments