@@ -18,7 +18,8 @@ def propagate_bias(model: nn.Module, x: torch.Tensor, pinned_out: List) -> nn.Module:
     Args:
         model (nn.Module):
         x (torch.Tensor): `model`'s input of shape [1, C, N, M], same as the model's usual input.
-        pinned_out (List): List of `nn.Module`s whose output needs to keep its original shape (e.g. layers involved in a residual connection with a sum operation).
+        pinned_out (List): List of `nn.Module`s whose output needs to keep its original shape
+            (e.g. layers involved in a residual connection with a sum operation).
 
     Returns:
         nn.Module: Model with propagated bias.
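For readers skimming the diff, a minimal usage sketch of the signature documented above; the toy model, shapes, and the empty `pinned_out` are illustrative assumptions, not taken from this repository's tests:

```python
import torch
import torch.nn as nn

# Hypothetical usage: "prune" a channel by zeroing its weights, then let
# propagate_bias fold the now-dangling bias into the layers downstream.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
model[0].weight.data[0] = 0.   # prune output channel 0 of the first conv

x = torch.randn(1, 3, 32, 32)  # [1, C, N, M], as the docstring requires
model = propagate_bias(model, x, pinned_out=[])  # no residual sums to pin here
```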
@@ -36,7 +37,7 @@ def __remove_nan(module, input):
         return input
 
     @torch.no_grad()
-    def __propagate_biases_hook(module, input, output):
+    def __propagate_biases_hook(module, input, output, name=None):
        """
        PyTorch hook used to propagate the biases of pruned neurons to the following non-pruned layers.
        """
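The hook's job, in miniature: a neuron whose weights have been zeroed still emits a constant equal to its post-activation bias, and that constant can be absorbed into the next layer's bias. A self-contained sketch of that identity, assuming a plain Linear-ReLU-Linear stack instead of this function's hook machinery:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2, relu = nn.Linear(4, 3), nn.Linear(3, 2), nn.ReLU()

j = 0
fc1.weight.data[j] = 0.                 # pruned neuron j now emits relu(b1[j])
const = torch.relu(fc1.bias.data[j])    # its constant post-activation output
new_bias2 = fc2.bias.data + fc2.weight.data[:, j] * const

x = torch.randn(1, 4)
ref = fc2(relu(fc1(x)))                 # output before moving the bias

fc1.bias.data[j] = 0.                   # drop the dangling bias...
fc2.bias.data = new_bias2               # ...after absorbing it downstream
assert torch.allclose(fc2(relu(fc1(x))), ref, atol=1e-6)
```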
@@ -47,7 +48,14 @@ def __propagate_biases_hook(module, input, output):
 
         bias_feature_maps = output[0].clone()
 
-        if isinstance(module, nn.Conv2d):
+        if isinstance(module, nn.Linear):
+            # TODO: handle missing bias
+            # For a linear layer, we can just update the scalar bias values
+            # if getattr(module, 'bias', None) is not None:
+            #     module.bias.data = bias_feature_maps
+            module.register_parameter('bias', nn.Parameter(bias_feature_maps))
+
+        elif isinstance(module, nn.Conv2d):
             # For a conv layer, we remove the scalar biases
             # and use bias matrices (ConvB)
             if bias_feature_maps.abs().sum() != 0.:
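The relocated `nn.Linear` branch boils down to replacing the layer's bias with the accumulated bias feature map, which for a linear layer is one scalar per output neuron. An isolated sketch with made-up values, outside the hook; it also shows that `register_parameter` covers a layer built with `bias=False`, the case the TODO above alludes to:

```python
import torch
import torch.nn as nn

fc = nn.Linear(4, 3, bias=False)                    # even a missing bias gets replaced
bias_feature_maps = torch.tensor([0.5, -1.0, 2.0])  # hypothetical accumulated biases
fc.register_parameter('bias', nn.Parameter(bias_feature_maps))

# With zero input, the layer now outputs exactly the absorbed biases.
assert torch.allclose(fc(torch.zeros(1, 4)), bias_feature_maps.unsqueeze(0))
```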
@@ -107,13 +115,6 @@ def __propagate_biases_hook(module, input, output):
             # if getattr(module, 'bias', None) is not None and module.bias.abs().sum() == 0:
             #     module.register_parameter('bias', None)
 
-        elif isinstance(module, nn.Linear):
-            # TODO: handle missing bias
-            # For a linear layer, we can just update the scalar bias values
-            # if getattr(module, 'bias', None) is not None:
-            #     module.bias.data = bias_feature_maps
-            module.register_parameter('bias', nn.Parameter(bias_feature_maps))
-
         else:
             error('Unsupported module type:', module)
 
@@ -136,8 +137,7 @@ def __propagate_biases_hook(module, input, output):
                 module.bias.data.mul_(~pruned_channels)
 
         elif isinstance(module, nn.Conv2d):
-            output[~pruned_channels[None, :, None,
-                                    None].expand_as(output)] *= float('nan')
+            output[~pruned_channels[None, :, None, None].expand_as(output)] *= float('nan')
             if isinstance(module, (ConvB, ConvExpand)):
                 if getattr(module, 'bf', None) is not None:
                     module.bf.data.mul_(~pruned_channels[:, None, None])
@@ -146,8 +146,7 @@ def __propagate_biases_hook(module, input, output):
                 module.bias.data.mul_(~pruned_channels)
 
         if isinstance(module, nn.BatchNorm2d):
-            output[~pruned_channels[None, :, None,
-                                    None].expand_as(output)] *= float('nan')
+            output[~pruned_channels[None, :, None, None].expand_as(output)] *= float('nan')
             if isinstance(module, (BatchNormB, BatchNormExpand)):
                 module.bf.data.mul_(~pruned_channels)
             else:
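Both reflowed one-liners (the Conv2d case above and the BatchNorm2d case here) use the same indexing idiom: broadcast a per-channel boolean mask over the whole activation so entire channels can be tagged with NaN for the `__remove_nan` pre-hook downstream. A toy demonstration of just the masking mechanics, with shapes and a mask invented for illustration:

```python
import torch

output = torch.ones(2, 3, 4, 4)              # [B, C, H, W] activation
mask = torch.tensor([True, False, True])     # per-channel selector

# [C] -> [1, C, 1, 1] -> expand to the full activation, then index in place.
output[mask[None, :, None, None].expand_as(output)] *= float('nan')

assert torch.isnan(output[:, 0]).all() and not torch.isnan(output[:, 1]).any()
```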
@@ -164,7 +163,7 @@ def __propagate_biases_hook(module, input, output):
         if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
             handle = module.register_forward_pre_hook(__remove_nan)
             handles.append(handle)
-            handle = module.register_forward_hook(lambda m, i, o: __propagate_biases_hook(m, i, o))
+            handle = module.register_forward_hook(lambda m, i, o, n=name: __propagate_biases_hook(m, i, o, n))
             handles.append(handle)
 
     # Propagate biases
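The changed registration line sidesteps Python's late-binding closures: binding `n=name` as a default argument freezes the loop's current value per hook, whereas a bare closure would read `name` only when the hook fires, after the loop has finished. A two-line demonstration of the pitfall:

```python
hooks = [lambda: name for name in ("a", "b")]
print([h() for h in hooks])           # ['b', 'b'] -- late binding
hooks = [lambda n=name: n for name in ("a", "b")]
print([h() for h in hooks])           # ['a', 'b'] -- value captured per hook
```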