11import torch .nn as nn
22
3+
34def layer_has_bias (layer : nn .Module ) -> bool :
45 return not layer .bias is None
56
7+
68def make_children_list (children_or_named_children ):
79 """Receives `nn.Module.children()` or `nn.Module.named_children()`.
810 Returns: that generator collected as a list
911 """
1012 return [c for c in children_or_named_children ]
1113
14+
1215#################### NEW LAYERS ####################
1316
14- def new_layer (base_layer : nn .Module , type_name : str , in_dim : int , out_dim : int ) -> nn .Module :
17+
18+ def new_layer (base_layer : nn .Module , type_name : str , in_dim : int ,
19+ out_dim : int ) -> nn .Module :
1520
1621 has_bias = layer_has_bias (base_layer )
1722
1823 if layer_is_linear (type_name ):
1924 return nn .Linear (in_dim , out_dim , bias = has_bias )
20-
25+
2126 if layer_is_conv2d (type_name ):
22- return nn .Conv2d (in_dim , out_dim , kernel_size = base_layer .kernel_size ,
23- stride = base_layer .stride , bias = has_bias )
27+ return nn .Conv2d (
28+ in_dim ,
29+ out_dim ,
30+ kernel_size = base_layer .kernel_size ,
31+ stride = base_layer .stride ,
32+ bias = has_bias )
2433
2534 raise ValueError ('User got around type check ;)' )
2635
36+
2737def new_input_layer (base_layer : nn .Module , type_name : str , out_dim : int ) -> nn .Module :
2838 has_bias = layer_has_bias (base_layer )
29-
39+
3040 if layer_is_linear (type_name ):
3141 return nn .Linear (base_layer .in_features , out_features = out_dim , bias = has_bias )
32-
42+
3343 if layer_is_conv2d (type_name ):
34- return nn .Conv2d (base_layer .in_channels , out_channels = out_dim ,
35- kernel_size = base_layer .kernel_size , stride = base_layer .stride , bias = has_bias )
44+ return nn .Conv2d (
45+ base_layer .in_channels ,
46+ out_channels = out_dim ,
47+ kernel_size = base_layer .kernel_size ,
48+ stride = base_layer .stride ,
49+ bias = has_bias )
50+
3651
3752def new_output_layer (base_layer : nn .Module , type_name : str , in_dim : int ) -> nn .Module :
3853 has_bias = layer_has_bias (base_layer )
39-
54+
4055 if layer_is_linear (type_name ):
4156 return nn .Linear (in_dim , base_layer .out_features , bias = has_bias )
42-
57+
4358 if layer_is_conv2d (type_name ):
44- return nn .Conv2d (in_dim , base_layer .out_channels ,
45- kernel_size = base_layer .kernel_size , stride = base_layer .stride , bias = has_bias )
59+ return nn .Conv2d (
60+ in_dim ,
61+ base_layer .out_channels ,
62+ kernel_size = base_layer .kernel_size ,
63+ stride = base_layer .stride ,
64+ bias = has_bias )
4665
4766
4867def redo_layer (layer : nn .Module , new_in = None , new_out = None ) -> nn .Module :
4968 if new_in is None and new_out is None :
50- return layer
51-
69+ return layehr
70+
5271 _type = type_name (layer )
5372 if not type_supported (_type ):
5473 raise ValueError ('Unsupported layer type:' , _type )
55-
56- if new_in is not None and new_out is not None :
74+
75+ received_new_input = new_in is not None
76+ received_new_output = new_out is not None
77+
78+ if received_new_input and received_new_output :
5779 return new_layer (layer , _type , new_in , new_out )
58-
59- if new_in is not None :
80+
81+ if received_new_input :
6082 # we need a new input dim, but retain the same output dim
6183 return new_output_layer (layer , _type , new_in )
62-
63- if new_out is not None :
84+
85+ if received_new_output :
6486 # we need a new output dim, but retain the same input dim
6587 return new_input_layer (layer , _type , new_out )
6688
89+
6790#################### TYPE HELPERS ####################
6891
92+
6993def layer_is_conv2d (name : str ):
7094 return name == 'Conv2d'
7195
96+
7297def layer_is_linear (name : str ):
7398 return name == 'Linear'
7499
100+
75101def type_name (o ):
76102 '''Returns the simplified type name of the given object.
77103 Eases type checking, rather than any(isinstance(some_obj, _type) for _type in [my, types, to, check])
78104 '''
79105 return type (o ).__name__
80106
107+
81108def type_supported (type_name : str ) -> bool :
82- return type_name in ['Conv2d' , 'Linear' ]
109+ return type_name in ['Conv2d' , 'Linear' ]
0 commit comments