 import torch.nn as nn
 
-from morph.nn.utils import make_children_list, group_layers_by_algo
+# TODO: nope. This is really long
+from morph.nn.utils import group_layers_by_algo, layer_is_conv2d, make_children_list, new_input_layer, new_output_layer, redo_layer, type_name, type_supported
+
 
 # TODO: refactor out width_factor
-def new_resize_layers(net: nn.Module):
-
+def resize_layers(net: nn.Module):
+
     old_layers = make_children_list(net.named_children())
     (first_name, first_layer), middle, last = group_layers_by_algo(old_layers)
-
-    last_out = first_layer.out_channels  # count of the last layer's out features
-
-    new_out_next_in = int(last_out * 1.4)
-
-    # NOTE: is there a better way to do this part?
-    network = nn.Module()  # new network
-
-    network.add_module(first_name, nn.Conv2d(
-        first_layer.in_channels, new_out_next_in, kernel_size=first_layer.kernel_size,
-        stride=first_layer.stride
-    ))
-
+
+    first_layer_output_size = first_layer.out_channels  # the first layer's output channel count
+
+    new_out_next_in = int(first_layer_output_size * 1.4)
+
+    # NOTE: is there a better way to do this part? Maybe nn.Sequential?
+    network = nn.Module()  # new network
+
+    network.add_module(
+        first_name,
+        new_input_layer(first_layer, type_name(first_layer), out_dim=new_out_next_in))
+
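The morph.nn.utils helpers are imported above, but their bodies are not part of this diff. A minimal sketch of what type_name and new_input_layer might look like, inferred from the call sites and from the inline nn.Conv2d construction they replace; the implementations below are assumptions, not the library's actual code:

import torch.nn as nn

def type_name(layer: nn.Module) -> str:
    # The layer's class name, e.g. 'Conv2d' or 'Linear'.
    return type(layer).__name__

def new_input_layer(layer: nn.Module, kind: str, out_dim: int) -> nn.Module:
    # Rebuild the input layer with a widened output, keeping its input size
    # and, for convolutions, its kernel size and stride.
    if kind == 'Conv2d':
        return nn.Conv2d(layer.in_channels, out_dim,
                         kernel_size=layer.kernel_size, stride=layer.stride)
    if kind == 'Linear':
        return nn.Linear(layer.in_features, out_dim)
    raise ValueError('unsupported layer type: ' + kind)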
     # TODO: format and utilize the functions in utils for making layers
-    for name, child in middle:
-        # otherwise, we want to
-        type_name = type(child).__name__
-        if type_name in ['Conv2d', 'Linear']:
-
-            temp = 0
-            if type_name == 'Conv2d':
-                temp = int(child.out_channels * 1.4)
-                network.add_module(name, nn.Conv2d(
-                    new_out_next_in, out_channels=temp,
-                    kernel_size=child.kernel_size, stride=child.stride
-                ))
-            else:  # type_name == 'Linear'
-                temp = int(child.out_features * 1.4)
-                network.add_module(name, nn.Linear(
-                    in_features=new_out_next_in, out_features=temp
-                ))
-
-            new_out_next_in = temp
-
+    for name, child in middle:
+        # resize only the supported layer types; anything else is skipped
+        _t = type_name(child)
+        if type_supported(_t):
+
+            new_out = 0
+            # TODO: look up performance on type name access. Maybe this could just be layer_is_conv2d(child)
+            if layer_is_conv2d(_t):
+                new_out = int(child.out_channels * 1.4)
+            else:  # _t == 'Linear'
+                new_out = int(child.out_features * 1.4)
+
+            new_layer = redo_layer(child, new_in=new_out_next_in, new_out=new_out)
+            new_out_next_in = new_out
+            network.add_module(name, new_layer)
+
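Judging only from the call sites in the loop, type_supported, layer_is_conv2d, and redo_layer could plausibly look like the following, mirroring the inline branching that was removed above; this is a guess at the utils module, not its real contents:

import torch.nn as nn

SUPPORTED_TYPES = ('Conv2d', 'Linear')  # assumed set of resizable layers

def type_supported(kind: str) -> bool:
    return kind in SUPPORTED_TYPES

def layer_is_conv2d(kind: str) -> bool:
    return kind == 'Conv2d'

def redo_layer(layer: nn.Module, new_in: int, new_out: int) -> nn.Module:
    # Rebuild a layer with new input and output widths, copying the
    # spatial hyperparameters when the layer is a convolution.
    if isinstance(layer, nn.Conv2d):
        return nn.Conv2d(new_in, new_out,
                         kernel_size=layer.kernel_size, stride=layer.stride)
    return nn.Linear(new_in, new_out)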
     last_name, last_layer = last
-    network.add_module(last_name, nn.Conv2d(
-        new_out_next_in, last_layer.out_channels,
-        kernel_size=last_layer.kernel_size, stride=last_layer.stride
-    ))
-
+    network.add_module(
+        last_name,
+        new_output_layer(last_layer, type_name(last_layer), in_dim=new_out_next_in))
+
     return network
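Taken together, resize_layers widens every output dimension except the final layer's by the hard-coded factor 1.4, threading each new output width into the next layer's input. A hypothetical usage on a toy three-conv network, assuming group_layers_by_algo splits the children into (first, middle, last):

import torch.nn as nn

class Tiny(nn.Module):
    # forward() omitted for brevity; resize_layers only walks named_children().
    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(3, 16, kernel_size=3, stride=1)
        self.c2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.out = nn.Conv2d(32, 10, kernel_size=3, stride=1)

wider = resize_layers(Tiny())
# c1: 3 -> 22 channels (int(16 * 1.4)), c2: 22 -> 44, out: 44 -> 10

Note that the returned nn.Module() only registers submodules and defines no forward; the author's "Maybe nn.Sequential?" note would make the result directly callable.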