import torch.nn as nn

from morph.nn.utils import make_children_list, group_layers_by_algo

# TODO: refactor out width_factor
def new_resize_layers(net: nn.Module):
    """Return a new network whose Conv2d/Linear layers are widened by a
    factor of 1.4, keeping the first layer's input size and the last
    layer's output size unchanged. Middle layers that are neither Conv2d
    nor Linear are dropped."""
    old_layers = make_children_list(net.named_children())
    (first_name, first_layer), middle, last = group_layers_by_algo(old_layers)

    last_out = first_layer.out_channels  # out size of the most recently processed layer

    new_out_next_in = int(last_out * 1.4)  # widened out size, reused as the next layer's in size

    # NOTE: is there a better way to do this part?
    network = nn.Module()  # new network

    # rebuild the first layer: original in_channels, widened out_channels
    network.add_module(first_name, nn.Conv2d(
        first_layer.in_channels, new_out_next_in, kernel_size=first_layer.kernel_size,
        stride=first_layer.stride
    ))

    # TODO: format and utilize the functions in utils for making layers
    for name, child in middle:
        # only Conv2d and Linear layers are widened; any other layer type
        # is skipped and does not appear in the new network
        type_name = type(child).__name__
        if type_name in ['Conv2d', 'Linear']:

            temp = 0
            if type_name == 'Conv2d':
                temp = int(child.out_channels * 1.4)
                network.add_module(name, nn.Conv2d(
                    new_out_next_in, out_channels=temp,
                    kernel_size=child.kernel_size, stride=child.stride
                ))
            else:  # type_name == 'Linear'
                temp = int(child.out_features * 1.4)
                network.add_module(name, nn.Linear(
                    in_features=new_out_next_in, out_features=temp
                ))

            new_out_next_in = temp

    # rebuild the last layer: widened in_channels, original out_channels
    last_name, last_layer = last
    network.add_module(last_name, nn.Conv2d(
        new_out_next_in, last_layer.out_channels,
        kernel_size=last_layer.kernel_size, stride=last_layer.stride
    ))

    return network
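

# Usage sketch (an assumption for illustration, not part of the original module):
# for an all-Conv2d nn.Sequential, new_resize_layers should return a module whose
# hidden layers are ~1.4x wider while the final out_channels is preserved. This
# assumes group_layers_by_algo splits the children into (first, middle, last).
# The returned nn.Module has no forward(), so we only inspect its children.
if __name__ == '__main__':
    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=1),
        nn.Conv2d(16, 32, kernel_size=3, stride=1),
        nn.Conv2d(32, 10, kernel_size=3, stride=1),
    )

    widened = new_resize_layers(net)
    for name, layer in widened.named_children():
        print(f'{name}: {layer.in_channels} -> {layer.out_channels}')
    # expected widths: 3 -> 22, 22 -> 44, 44 -> 10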