Skip to content

Commit 16702e5

Browse files
authored
Merge pull request #9 from EIDOSLAB/v1.1.4
V1.1.4
2 parents d4de117 + d8d2ab3 commit 16702e5

File tree

10 files changed

+81
-43
lines changed

10 files changed

+81
-43
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
setuptools==45.2.0
2-
torch==1.11.0
3-
torchvision==0.12.0
2+
torch==1.12.0
3+
torchvision==0.13.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
HERE = pathlib.Path(__file__).parent
66
README = (HERE / "README.md").read_text()
77

8-
__version__ = "1.1.3"
8+
__version__ = "1.1.4"
99

1010
setup(
1111
name='torch-simplify',

simplify/layers.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@
55
import torch.nn as nn
66

77

8+
class LinearExpand(nn.Linear):
9+
@staticmethod
10+
def from_linear(module: nn.Linear, idxs: torch.Tensor, bias):
11+
module.__class__ = LinearExpand
12+
13+
module.register_parameter('bf', torch.nn.Parameter(bias.clone()))
14+
module.bf[idxs] = 0
15+
16+
module.register_buffer('idxs', idxs.to(module.weight.device))
17+
module.register_buffer('zeros', torch.zeros(bias.shape, dtype=bias.dtype, device=module.weight.device))
18+
19+
setattr(module, 'idxs_cache', module.idxs)
20+
setattr(module, 'zero_cache', module.zeros)
21+
22+
return module
23+
24+
def forward(self, x):
25+
x = super().forward(x)
26+
27+
expanded = torch.scatter(self.zeros, 0, self.idxs, x)
28+
return expanded + self.bf
29+
30+
831
class ConvB(nn.Conv2d):
932
@staticmethod
1033
def from_conv(module: nn.Conv2d, bias):
@@ -98,4 +121,5 @@ def forward(self, x):
98121
return expanded + self.bf[:, None, None].expand_as(expanded)
99122

100123
def __repr__(self):
101-
return f'BatchNormExpand({self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={self.affine}, track_running_stats={self.track_running_stats})'
124+
return f'BatchNormExpand({self.num_features}, eps={self.eps}, momentum={self.momentum}, ' \
125+
f'affine={self.affine}, track_running_stats={self.track_running_stats})'

simplify/propagate.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def propagate_bias(model: nn.Module, x: torch.Tensor, pinned_out: List) -> nn.Mo
1818
Args:
1919
model (nn.Module):
2020
x (torch.Tensor): `model`'s input of shape [1, C, N, M], same as the model usual input.
21-
pinned_out (List): List of `nn.Modules` which output needs to remain of the original shape (e.g. layers related to a residual connection with a sum operation).
21+
pinned_out (List): List of `nn.Modules` which output needs to remain of the original shape
22+
(e.g. layers related to a residual connection with a sum operation).
2223
2324
Returns:
2425
nn.Module: Model with propagated bias.
@@ -36,7 +37,7 @@ def __remove_nan(module, input):
3637
return input
3738

3839
@torch.no_grad()
39-
def __propagate_biases_hook(module, input, output):
40+
def __propagate_biases_hook(module, input, output, name=None):
4041
"""
4142
PyTorch hook used to propagate the biases of pruned neurons to following non-pruned layers.
4243
"""
@@ -47,7 +48,14 @@ def __propagate_biases_hook(module, input, output):
4748

4849
bias_feature_maps = output[0].clone()
4950

50-
if isinstance(module, nn.Conv2d):
51+
if isinstance(module, nn.Linear):
52+
# TODO: handle missing bias
53+
# For a linear layer, we can just update the scalar bias values
54+
# if getattr(module, 'bias', None) is not None:
55+
# module.bias.data = bias_feature_maps
56+
module.register_parameter('bias', nn.Parameter(bias_feature_maps))
57+
58+
elif isinstance(module, nn.Conv2d):
5159
# For a conv layer, we remove the scalar biases
5260
# and use bias matrices (ConvB)
5361
if bias_feature_maps.abs().sum() != 0.:
@@ -107,13 +115,6 @@ def __propagate_biases_hook(module, input, output):
107115
# if getattr(module, 'bias', None) is not None and module.bias.abs().sum() == 0:
108116
# module.register_parameter('bias', None)
109117

110-
elif isinstance(module, nn.Linear):
111-
# TODO: handle missing bias
112-
# For a linear layer, we can just update the scalar bias values
113-
# if getattr(module, 'bias', None) is not None:
114-
# module.bias.data = bias_feature_maps
115-
module.register_parameter('bias', nn.Parameter(bias_feature_maps))
116-
117118
else:
118119
error('Unsupported module type:', module)
119120

@@ -136,8 +137,7 @@ def __propagate_biases_hook(module, input, output):
136137
module.bias.data.mul_(~pruned_channels)
137138

138139
elif isinstance(module, nn.Conv2d):
139-
output[~pruned_channels[None, :, None,
140-
None].expand_as(output)] *= float('nan')
140+
output[~pruned_channels[None, :, None, None].expand_as(output)] *= float('nan')
141141
if isinstance(module, (ConvB, ConvExpand)):
142142
if getattr(module, 'bf', None) is not None:
143143
module.bf.data.mul_(~pruned_channels[:, None, None])
@@ -146,8 +146,7 @@ def __propagate_biases_hook(module, input, output):
146146
module.bias.data.mul_(~pruned_channels)
147147

148148
if isinstance(module, nn.BatchNorm2d):
149-
output[~pruned_channels[None, :, None,
150-
None].expand_as(output)] *= float('nan')
149+
output[~pruned_channels[None, :, None, None].expand_as(output)] *= float('nan')
151150
if isinstance(module, (BatchNormB, BatchNormExpand)):
152151
module.bf.data.mul_(~pruned_channels)
153152
else:
@@ -164,7 +163,7 @@ def __propagate_biases_hook(module, input, output):
164163
if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
165164
handle = module.register_forward_pre_hook(__remove_nan)
166165
handles.append(handle)
167-
handle = module.register_forward_hook(lambda m, i, o: __propagate_biases_hook(m, i, o))
166+
handle = module.register_forward_hook(lambda m, i, o, n=name: __propagate_biases_hook(m, i, o, n))
168167
handles.append(handle)
169168

170169
# Propagate biases

simplify/remove.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import torch.nn as nn
88

9-
from .layers import BatchNormB, ConvExpand, BatchNormExpand
9+
from .layers import BatchNormB, ConvExpand, BatchNormExpand, LinearExpand
1010

1111

1212
@torch.no_grad()
@@ -37,16 +37,16 @@ def __remove_zeroed_channels_hook(module, input, output, name):
3737
nonzero_idx = ~(input.view(input.shape[0], -1).sum(dim=1) == 0)
3838
# print('input:', input.shape)
3939

40-
if isinstance(module, nn.Conv2d):
40+
if isinstance(module, nn.Linear):
41+
module.weight = nn.Parameter(module.weight[:, nonzero_idx])
42+
module.in_features = module.weight.shape[1]
43+
44+
elif isinstance(module, nn.Conv2d):
4145
if module.groups == 1:
4246
module.weight = nn.Parameter(module.weight[:, nonzero_idx])
4347
module.in_channels = module.weight.shape[1]
4448
# TODO: handle when groups > 1 (if possible)
4549

46-
elif isinstance(module, nn.Linear):
47-
module.weight = nn.Parameter(module.weight[:, nonzero_idx])
48-
module.in_features = module.weight.shape[1]
49-
5050
elif isinstance(module, nn.BatchNorm2d):
5151
module.weight.data.mul_(nonzero_idx)
5252
module.running_mean.data.mul_(nonzero_idx)
@@ -104,24 +104,27 @@ def __remove_zeroed_channels_hook(module, input, output, name):
104104
module.running_mean = module.running_mean[nonzero_idx]
105105
module.running_var = module.running_var[nonzero_idx]
106106

107-
# 3. If it is a pinned layer, convert it into ConvExpand or BatchNormExpand
107+
# 3. If it is a pinned layer, convert it into LinearExpand, ConvExpand or BatchNormExpand
108108
if name in pinned_out:
109109
idxs = torch.where(nonzero_idx)[0]
110110

111+
if isinstance(module, nn.Linear):
112+
module = LinearExpand.from_linear(module, idxs, module.bias)
113+
111114
# Keep bias (bf) full size
112-
if isinstance(module, nn.Conv2d):
115+
elif isinstance(module, nn.Conv2d):
113116
module_bf = getattr(module, 'bf', None)
114117
if module_bf is None:
115118
module_bf = torch.zeros_like(output[0])
116119

117120
module = ConvExpand.from_conv(module, idxs, module_bf)
118121

119-
if isinstance(module, BatchNormB):
120-
module = BatchNormExpand.from_bn(module, idxs, module.bf, output.shape)
121-
122122
elif isinstance(module, nn.BatchNorm2d):
123-
module = BatchNormExpand.from_bn(module, idxs, module.bias, output.shape)
124-
module.register_parameter("bias", None)
123+
bias = module.bf if isinstance(module, BatchNormB) else module.bias
124+
module = BatchNormExpand.from_bn(module, idxs, bias, output.shape)
125+
126+
if not isinstance(module, BatchNormB):
127+
module.register_parameter("bias", None)
125128
else:
126129
if getattr(module, 'bf', None) is not None:
127130
module.bf = nn.Parameter(module.bf[nonzero_idx])

simplify/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def get_previous_layer(connections: Dict, module: fx.Node) -> fx.Node:
9494

9595
def get_pinned(model: torch.nn.Module) -> List[str]:
9696
"""
97-
Try to find all the modules for which the output shape needs to stay fixed, (e.g. modules involved in residual connections with a sum).
97+
Try to find all the modules for which the output shape needs to stay fixed,
98+
(e.g. modules involved in residual connections with a sum).
9899
99100
Args:
100101
model (torch.nn.Module): The model on which to perform the research.

test/modules/fuse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@ def test_arch(arch, x):
2626
x = im / 255.
2727

2828
for architecture in models:
29+
print(f"Testing with {architecture.__name__}")
2930
with self.subTest(arch=architecture):
3031
self.assertTrue(test_arch(architecture, x))

test/modules/propagate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class Test(unittest.TestCase):
1313
def test(self):
1414
@torch.no_grad()
1515
def test_arch(arch, x, fuse_bn):
16+
print(f"Fuse: {fuse_bn}")
17+
1618
model = get_model(architecture, arch)
1719

1820
if fuse_bn:
@@ -31,6 +33,8 @@ def test_arch(arch, x, fuse_bn):
3133
x = im / 255.
3234

3335
for architecture in models:
36+
print(f"Testing with {architecture.__name__}")
37+
3438
with self.subTest(arch=architecture, fuse_bn=True):
3539
self.assertTrue(test_arch(architecture, x, fuse_bn=True))
3640

test/modules/remove.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class Test(unittest.TestCase):
1313
def test(self):
1414
@torch.no_grad()
1515
def test_arch(arch, x, fuse_bn):
16+
print(f"Fuse: {fuse_bn}")
17+
1618
model = get_model(architecture, arch)
1719

1820
if fuse_bn:
@@ -33,6 +35,8 @@ def test_arch(arch, x, fuse_bn):
3335
x = im / 255.
3436

3537
for architecture in models:
38+
print(f"Testing with {architecture.__name__}")
39+
3640
with self.subTest(arch=architecture, fuse_bn=True):
3741
self.assertTrue(test_arch(architecture, x, fuse_bn=True))
3842

test/utils.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313

1414
models = [
1515
alexnet,
16-
vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn,
17-
resnet18, resnet34, resnet50, resnet101, resnet152,
18-
squeezenet1_0, squeezenet1_1,
19-
densenet121, densenet161, densenet169, densenet201,
16+
vgg11, vgg11_bn,
17+
resnet18, resnet50,
18+
squeezenet1_0,
19+
densenet121,
2020
inception_v3,
2121
googlenet,
22-
shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0,
23-
mobilenet_v2, mobilenet_v3_small, mobilenet_v3_large,
24-
resnext50_32x4d, resnext101_32x8d,
25-
wide_resnet50_2, wide_resnet101_2,
26-
mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3,
22+
shufflenet_v2_x0_5,
23+
mobilenet_v2, mobilenet_v3_small,
24+
resnext50_32x4d,
25+
wide_resnet50_2,
26+
mnasnet0_5, mnasnet1_0,
2727
densenet121
2828
]
2929

@@ -41,6 +41,8 @@ def get_model(architecture, arch):
4141
if isinstance(model, SqueezeNet) and 'classifier.1' in name:
4242
continue
4343

44-
if isinstance(module, nn.Conv2d):
44+
if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
4545
prune.random_structured(module, 'weight', amount=0.8, dim=0)
4646
prune.remove(module, 'weight')
47+
48+
return model

0 commit comments

Comments
 (0)