
Commit 2cfff05

Add grad_checkpointing support to features_only, test in EfficientDet.
1 parent: 45af496

File tree: 3 files changed, +123 −33 lines
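For context on how the new toggle is used: models built with `features_only=True` now expose `set_grad_checkpointing()` on their feature wrapper (FeatureListNet / FeatureHookNet, or the dedicated EfficientNetFeatures / MobileNetV3Features classes touched below). A rough usage sketch follows; the model name, input size, and dummy loss are illustrative only, not part of this commit.

import torch
import timm

# Any features_only model picks up the new toggle; efficientnet_b0 is just an example.
backbone = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)
backbone.set_grad_checkpointing(True)  # trade extra compute for activation memory

x = torch.randn(2, 3, 224, 224)
feats = backbone(x)                    # list of feature maps at the default out_indices
loss = sum(f.mean() for f in feats)    # dummy loss, just to drive backward
loss.backward()                        # checkpointed blocks are re-run here instead of storing activations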


timm/models/_features.py

Lines changed: 103 additions & 31 deletions

@@ -11,10 +11,11 @@
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
 
 
 __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@@ -88,12 +89,20 @@ class FeatureHooks:
     """ Feature Hook Helper
 
     This module helps with the setup and extraction of hooks for extracting features from
-    internal nodes in a model by node name. This works quite well in eager Python but needs
-    redesign for torchscript.
+    internal nodes in a model by node name.
+
+    FIXME This works well in eager Python but needs redesign for torchscript.
     """
 
-    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+    def __init__(
+            self,
+            hooks: Sequence[str],
+            named_modules: dict,
+            out_map: Sequence[Union[int, str]] = None,
+            default_hook_type: str = 'forward',
+    ):
         # setup feature hooks
+        self._feature_outputs = defaultdict(OrderedDict)
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
             hook_name = h['module']
@@ -107,7 +116,6 @@ def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
                 m.register_forward_hook(hook_fn)
             else:
                 assert False, "Unsupported hook type"
-        self._feature_outputs = defaultdict(OrderedDict)
 
     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
     one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
     All Sequential containers that are directly assigned to the original model will have their
     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model from which we will extract the features
-        out_indices (tuple[int]): model output indices to extract features for
-        out_map (sequence): list or tuple specifying desired return id for each out index,
-            otherwise str(index) is used
-        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureDictNet, self).__init__()
         self.feature_info = _get_feature_info(model, out_indices)
         self.concat = feature_concat
+        self.grad_checkpointing = False
         self.return_layers = {}
+
         return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
@@ -200,10 +215,21 @@ def __init__(
             f'Return layers ({remaining}) are not present in model'
         self.update(layers)
 
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
+
             if name in self.return_layers:
                 out_id = self.return_layers[name]
                 if isinstance(x, (tuple, list)):
@@ -221,15 +247,29 @@ def forward(self, x) -> Dict[str, torch.Tensor]:
 class FeatureListNet(FeatureDictNet):
     """ Feature extractor with list return
 
-    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
-    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureListNet, self).__init__(
-            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
-            flatten_sequential=flatten_sequential)
+            model,
+            out_indices=out_indices,
+            feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential,
+        )
 
     def forward(self, x) -> (List[torch.Tensor]):
         return list(self._collect(x).values())
@@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
     FIXME this does not currently work with Torchscript, see FeatureHooks class
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
-            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            out_as_dict: bool = False,
+            no_rewrite: bool = False,
+            flatten_sequential: bool = False,
+            default_hook_type: str = 'forward',
+    ):
+        """
+
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_as_dict: Output features as a dict.
+            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
+                flatten_sequential arg must also be False if this is set True.
+            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+            default_hook_type: The default hook type to use if not specified in model.feature_info.
+        """
         super(FeatureHookNet, self).__init__()
         assert not torch.jit.is_scripting()
         self.feature_info = _get_feature_info(model, out_indices)
         self.out_as_dict = out_as_dict
+        self.grad_checkpointing = False
+
         layers = OrderedDict()
         hooks = []
         if no_rewrite:
@@ -266,8 +326,10 @@ def __init__(
             hooks.extend(self.feature_info.get_dicts())
         else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
-            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
-                         for f in self.feature_info.get_dicts()}
+            remaining = {
+                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                for f in self.feature_info.get_dicts()
+            }
             for new_name, old_name, module in modules:
                 layers[new_name] = module
                 for fn, fm in module.named_modules(prefix=old_name):
@@ -280,8 +342,18 @@ def __init__(
         self.update(layers)
         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
 
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def forward(self, x):
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.out_as_dict else list(out.values())
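To make the checkpointing logic in `_collect()` / `forward()` above easier to follow in isolation, here is a minimal, self-contained sketch of the same skip-first / skip-last pattern. `ToyStage` and the tensor sizes are made up for illustration and are not part of this commit.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStage(nn.Module):
    # Hypothetical stand-in for one stage of a feature extractor.
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(chs, chs, 3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.conv(x))


stages = nn.ModuleList([ToyStage(16) for _ in range(4)])
x = torch.randn(1, 16, 32, 32)

for i, stage in enumerate(stages):
    # Run the first stage normally so checkpointed stages see an input that already
    # requires grad; run the last normally to sidestep issues with in-place ops at
    # the tail. Everything in between is recomputed during backward instead of cached.
    first_or_last = i == 0 or i == max(len(stages) - 1, 0)
    x = stage(x) if first_or_last else checkpoint(stage, x)

x.mean().backward()
assert stages[1].conv.weight.grad is not None  # grads flow through checkpointed stages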

timm/models/efficientnet.py

Lines changed: 10 additions & 1 deletion

@@ -41,6 +41,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
@@ -211,6 +212,7 @@ def __init__(
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate
+        self.grad_checkpointing = False
 
         # Stem
         if not fix_stem:
@@ -241,6 +243,10 @@ def __init__(
             hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -249,7 +255,10 @@ def forward(self, x) -> List[torch.Tensor]:
         if 0 in self._stage_out_idx:
             features.append(x)  # add stem out
         for i, b in enumerate(self.blocks):
-            x = b(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(b, x)
+            else:
+                x = b(x)
             if i + 1 in self._stage_out_idx:
                 features.append(x)
         return features

timm/models/mobilenetv3.py

Lines changed: 10 additions & 1 deletion

@@ -12,6 +12,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
@@ -188,6 +189,7 @@ def __init__(
         norm_layer = norm_layer or nn.BatchNorm2d
         se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate
+        self.grad_checkpointing = False
 
         # Stem
         if not fix_stem:
@@ -220,6 +222,10 @@ def __init__(
             hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -229,7 +235,10 @@ def forward(self, x) -> List[torch.Tensor]:
         if 0 in self._stage_out_idx:
             features.append(x)  # add stem out
         for i, b in enumerate(self.blocks):
-            x = b(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(b, x)
+            else:
+                x = b(x)
             if i + 1 in self._stage_out_idx:
                 features.append(x)
         return features
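One detail shared by all three files: the checkpoint call sits behind `not torch.jit.is_scripting()`, and the setter is either a plain method or marked `@torch.jit.ignore`, so scripted copies of a model fall back to the plain path rather than trying to compile `torch.utils.checkpoint`. A toy illustration of that guard pattern; `GuardedBlock` is hypothetical, not code from this commit, and the scripting behaviour is my reading of how TorchScript treats `is_scripting()` branches.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GuardedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 8)
        self.grad_checkpointing = True

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Eager only: recompute self.body during backward to save memory.
            return checkpoint(self.body, x)
        return self.body(x)


block = GuardedBlock()
eager_out = block(torch.randn(2, 8, requires_grad=True))

# Scripting should still succeed: the is_scripting() guard keeps the
# (non-scriptable) checkpoint call on a branch the compiler drops.
scripted = torch.jit.script(block)
scripted_out = scripted(torch.randn(2, 8))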
