Skip to content

Commit a1a03f9

Browse files
committed
WIP: Implement widen (by a different name)
1 parent db1f944 commit a1a03f9

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

morph/nn/widen.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,50 @@
1-
1+
import torch.nn as nn
2+
3+
from morph.nn.utils import make_children_list, group_layers_by_algo
4+
5+
# TODO: refactor out width_factor
6+
def new_resize_layers(net: nn.Module):
7+
8+
old_layers = make_children_list(net.named_children())
9+
(first_name, first_layer), middle, last = group_layers_by_algo(old_layers)
10+
11+
last_out = first_layer.out_channels # count of the last layer's out features
12+
13+
new_out_next_in = int(last_out * 1.4)
14+
15+
# NOTE: is there a better way to do this part?
16+
network = nn.Module() # new network
17+
18+
network.add_module(first_name, nn.Conv2d(
19+
first_layer.in_channels, new_out_next_in, kernel_size=first_layer.kernel_size,
20+
stride=first_layer.stride
21+
))
22+
23+
# TODO: format and utilize the functions in utils for making layers
24+
for name, child in middle:
25+
# otherwise, we want to
26+
type_name = type(child).__name__
27+
if type_name in ['Conv2d', 'Linear']:
28+
29+
temp = 0
30+
if type_name == 'Conv2d':
31+
temp = int(child.out_channels * 1.4)
32+
network.add_module(name, nn.Conv2d(
33+
new_out_next_in, out_channels=temp,
34+
kernel_size=child.kernel_size, stride=child.stride
35+
))
36+
else: # type_name == 'Linear'
37+
temp = int(child.out_features * 1.4)
38+
network.add_module(name, nn.Linear(
39+
in_features=new_out_next_in, out_features=temp
40+
))
41+
42+
new_out_next_in = temp
43+
44+
last_name, last_layer = last
45+
network.add_module(last_name, nn.Conv2d(
46+
new_out_next_in, last_layer.out_channels,
47+
kernel_size=last_layer.kernel_size, stride=last_layer.stride
48+
))
49+
50+
return network

0 commit comments

Comments
 (0)