Skip to content

Commit b6692ed

Browse files
committed
Fix several grad checkpointing issues
1 parent 869bac2 commit b6692ed

18 files changed

+28
-27
lines changed

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -261,7 +261,7 @@ def forward_intermediates(
261261
blocks = self.blocks[:max_index]
262262
for feat_idx, blk in enumerate(blocks, start=1):
263263
if self.grad_checkpointing and not torch.jit.is_scripting():
264-
x = checkpoint_seq(blk, x, flatten=True)
264+
x = checkpoint_seq(blk, x)
265265
else:
266266
x = blk(x)
267267
if feat_idx in take_indices:

timm/models/fasternet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ def __init__(
142142
def forward(self, x: torch.Tensor) -> torch.Tensor:
143143
x = self.downsample(x)
144144
if self.grad_checkpointing and not torch.jit.is_scripting():
145-
x = checkpoint_seq(self.blocks, x, flatten=True)
145+
x = checkpoint_seq(self.blocks, x)
146146
else:
147147
x = self.blocks(x)
148148
return x

timm/models/focalnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ def forward(self, x):
274274
x = self.downsample(x)
275275
for blk in self.blocks:
276276
if self.grad_checkpointing and not torch.jit.is_scripting():
277-
x = checkpoint.checkpoint(blk, x)
277+
x = checkpoint(blk, x)
278278
else:
279279
x = blk(x)
280280
return x

timm/models/gcvit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -361,7 +361,7 @@ def forward(self, x):
361361
global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
362362
for blk in self.blocks:
363363
if self.grad_checkpointing and not torch.jit.is_scripting():
364-
x = checkpoint.checkpoint(blk, x)
364+
x = checkpoint(blk, x, global_query)
365365
else:
366366
x = blk(x, global_query)
367367
x = self.norm(x)

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -728,7 +728,7 @@ def forward_intermediates(
728728

729729
for feat_idx, stage in enumerate(stages, start=1):
730730
if self.grad_checkpointing and not torch.jit.is_scripting():
731-
x = checkpoint_seq(stage, x, flatten=True)
731+
x = checkpoint_seq(stage, x)
732732
else:
733733
x = stage(x)
734734
if feat_idx in take_indices:

timm/models/hgnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -345,7 +345,7 @@ def __init__(
345345
def forward(self, x):
346346
x = self.downsample(x)
347347
if self.grad_checkpointing and not torch.jit.is_scripting():
348-
x = checkpoint_seq(self.blocks, x, flatten=False)
348+
x = checkpoint_seq(self.blocks, x)
349349
else:
350350
x = self.blocks(x)
351351
return x

timm/models/hieradet_sam2.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,7 @@ def __init__(
288288
norm_layer = get_norm_layer(norm_layer)
289289
act_layer = get_act_layer(act_layer)
290290
assert len(stages) == len(window_spec)
291+
self.grad_checkpointing = False
291292
self.num_classes = num_classes
292293
self.window_spec = window_spec
293294
self.output_fmt = 'NHWC'

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ def forward_intermediates(
229229
blocks = self.blocks[:max_index]
230230
for feat_idx, blk in enumerate(blocks, start=1):
231231
if self.grad_checkpointing and not torch.jit.is_scripting():
232-
x = checkpoint_seq(blk, x, flatten=True)
232+
x = checkpoint_seq(blk, x)
233233
else:
234234
x = blk(x)
235235
if feat_idx in take_indices:

timm/models/mvitv2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -681,7 +681,7 @@ def __init__(
681681
def forward(self, x, feat_size: List[int]):
682682
for blk in self.blocks:
683683
if self.grad_checkpointing and not torch.jit.is_scripting():
684-
x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
684+
x, feat_size = checkpoint(blk, x, feat_size)
685685
else:
686686
x, feat_size = blk(x, feat_size)
687687
return x, feat_size

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -267,7 +267,7 @@ def forward(self, x):
267267
x = x.reshape(B, -1, C)
268268
for blk in self.blocks:
269269
if self.grad_checkpointing and not torch.jit.is_scripting():
270-
x = checkpoint.checkpoint(blk, x, feat_size)
270+
x = checkpoint(blk, x, feat_size)
271271
else:
272272
x = blk(x, feat_size)
273273
x = self.norm(x)

0 commit comments

Comments (0)