Skip to content

Commit bf8aa4c

Browse files
committed
fix missing name
1 parent dde3920 commit bf8aa4c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

simplify/propagate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __remove_nan(module, input):
3737
return input
3838

3939
@torch.no_grad()
40-
def __propagate_biases_hook(module, input, output):
40+
def __propagate_biases_hook(module, input, output, name=None):
4141
"""
4242
PyTorch hook used to propagate the biases of pruned neurons to following non-pruned layers.
4343
"""
@@ -163,7 +163,7 @@ def __propagate_biases_hook(module, input, output):
163163
if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
164164
handle = module.register_forward_pre_hook(__remove_nan)
165165
handles.append(handle)
166-
handle = module.register_forward_hook(lambda m, i, o: __propagate_biases_hook(m, i, o))
166+
handle = module.register_forward_hook(lambda m, i, o, n=name: __propagate_biases_hook(m, i, o, n))
167167
handles.append(handle)
168168

169169
# Propagate biases

0 commit comments

Comments
 (0)