from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
+from torch.utils.checkpoint import checkpoint


__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@@ -88,12 +89,20 @@ class FeatureHooks:
    """ Feature Hook Helper

    This module helps with the setup and extraction of hooks for extracting features from
-    internal nodes in a model by node name. This works quite well in eager Python but needs
-    redesign for torchscript.
+    internal nodes in a model by node name.
+
+    FIXME This works well in eager Python but needs redesign for torchscript.
    """

-    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+    def __init__(
+            self,
+            hooks: Sequence[str],
+            named_modules: dict,
+            out_map: Sequence[Union[int, str]] = None,
+            default_hook_type: str = 'forward',
+    ):
        # setup feature hooks
+        self._feature_outputs = defaultdict(OrderedDict)
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
@@ -107,7 +116,6 @@ def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forwar
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
-        self._feature_outputs = defaultdict(OrderedDict)

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model from which we will extract the features
-        out_indices (tuple[int]): model output indices to extract features for
-        out_map (sequence): list or tuple specifying desired return id for each out index,
-            otherwise str(index) is used
-        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """
    def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element, e.g. `x[0]`.
+            flatten_sequential: Flatten first two levels of sequential modules in model (re-writes model modules).
+        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.concat = feature_concat
+        self.grad_checkpointing = False
        self.return_layers = {}
+
        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
@@ -200,10 +215,21 @@ def __init__(
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skip checkpointing the first module because a gradient is needed at the input
+                # Skip the last module because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last_module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
+
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
@@ -221,15 +247,29 @@ def forward(self, x) -> Dict[str, torch.Tensor]:
class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

-    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
-    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element, e.g. `x[0]`.
+            flatten_sequential: Flatten first two levels of sequential modules in model (re-writes model modules).
+        """
        super(FeatureListNet, self).__init__(
-            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
-            flatten_sequential=flatten_sequential)
+            model,
+            out_indices=out_indices,
+            feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential,
+        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())
@@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
-            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            out_as_dict: bool = False,
+            no_rewrite: bool = False,
+            flatten_sequential: bool = False,
+            default_hook_type: str = 'forward',
+    ):
+        """
+
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_as_dict: Output features as a dict.
+            no_rewrite: Enforce that model is not re-written if True, i.e. no modules are removed / changed.
+                flatten_sequential arg must also be False if this is set True.
+            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+            default_hook_type: The default hook type to use if not specified in model.feature_info.
+        """
        super(FeatureHookNet, self).__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.out_as_dict = out_as_dict
+        self.grad_checkpointing = False
+
        layers = OrderedDict()
        hooks = []
        if no_rewrite:
@@ -266,8 +326,10 @@ def __init__(
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
-            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
-                         for f in self.feature_info.get_dicts()}
+            remaining = {
+                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                for f in self.feature_info.get_dicts()
+            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
@@ -280,8 +342,18 @@ def __init__(
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
    def forward(self, x):
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skip checkpointing the first module because a gradient is needed at the input
+                # Skip the last module because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last_module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.out_as_dict else list(out.values())
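A rough usage sketch of the gradient checkpointing toggle added in this diff. The `resnet50` model name, tensor shapes, and the `timm.models.features` import path are illustrative assumptions; `FeatureListNet` and `set_grad_checkpointing` are the pieces this change introduces or modifies (FeatureHookNet gains the same toggle):

import torch
import timm
from timm.models.features import FeatureListNet  # import path assumed for this version of timm

# Any model with a populated feature_info works here; resnet50 is just an example.
backbone = timm.create_model('resnet50')
features = FeatureListNet(backbone, out_indices=(0, 1, 2, 3, 4))

# New in this change: once enabled, every module except the first and last is run
# through torch.utils.checkpoint in the forward pass.
features.set_grad_checkpointing(True)
features.train()

x = torch.randn(2, 3, 224, 224)
feats = features(x)  # list of 5 feature maps, one per out_index
loss = sum(f.mean() for f in feats)
loss.backward()  # checkpointed activations are recomputed here, reducing peak memory during training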