import torch
import torch.nn as nn

-from ..nn.utils import layer_has_bias
+from ..nn.utils import layer_has_bias, redo_layer
+from .._utils import check, round


# NOTE: should factor be {smaller, default at all}?
-# TODO: Research - is there a better type for layer than nn.Module?
def widen(layer: nn.Module, factor=1.4, in_place=False) -> nn.Module:
    """
    Args:
@@ -23,23 +23,18 @@ def widen(layer: nn.Module, factor=1.4, in_place=False) -> nn.Module:
    Returns:
        A new layer of the base type (e.g. nn.Linear) or `None` if in_place=True
    """
-    if factor < 1.0:
-        raise ValueError('Cannot shrink with the widen() function')
-    if factor == 1.0:
-        raise ValueError("You shouldn't waste compute time if you're not changing anything")
+    check(factor > 1.0, "Your call to widen() should be increasing the size of your layers")
    # we know that layer.weight.size()[0] is the __output__ dimension in the linear case
    output_dim = 0
    if isinstance(layer, nn.Linear):
        output_dim = layer.weight.size()[0]  # FIXME: switch to layer.out_features?
        input_dim = layer.weight.size()[1]  # FIXME: switch to layer.in_features?
-        # TODO: other classes, for robustness?
-        # TODO: Use dictionary look-ups instead, because they're faster?
    else:
        raise ValueError('unsupported layer type:', type(layer))

    logging.debug(f"current dimensions: {(output_dim, input_dim)}")

-    new_size = round(factor * output_dim + .5)  # round up, not down, if we can
+    new_size = round(factor * output_dim)  # round up, not down, if we can

    # We're increasing layer width from output_dim to new_size, so let's save that for later
    size_diff = new_size - output_dim
@@ -56,20 +51,26 @@ def widen(layer: nn.Module, factor=1.4, in_place=False) -> nn.Module:

    # TODO: cleanup duplication? Missing properties that will affect usability?
    if in_place:
-        layer.out_features = new_size
-        layer.weight = p_weights
-        layer.bias = p_bias
-        logging.warning(
-            'Using experimental "in-place" version. May have unexpected affects on activation.'
-        )
+        write_layer_properties(layer, new_size, p_weights, p_bias)
        return layer
    else:
-        print(f"New shape = {expanded_weights.shape}")
-        l = nn.Linear(*expanded_weights.shape[::-1], bias=utils.layer_has_bias(layer))
-        l.weight = p_weights
-        l.bias = p_bias
+        logging.debug(f"New shape = {expanded_weights.shape}")
+        new_input, new_output = expanded_weights.shape[1], expanded_weights.shape[0]
+        l = redo_layer(layer, new_in=new_input, new_out=new_output)
+        write_layer_properties(l, new_size=None, new_weights=p_weights, new_bias=p_bias)
+
        return l

+def write_layer_properties(layer, new_size, new_weights, new_bias):
+    """Assigns properties to this `layer`, making the changes on the model in place
+    """
+    if new_size is not None: layer.out_features = new_size
+    if new_weights is not None: layer.weight = new_weights
+    if new_bias is not None: layer.bias = new_bias
+    logging.warning(
+        'Using experimental "in-place" version. May have unexpected effects on activation.'
+    )
+

def _expand_bias_or_weight(t: nn.Parameter, increase: int) -> torch.Tensor:
7576 """Returns a tensor of shape `t`, with padding values drawn from a Guassian distribution
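For a quick sanity check of the non-in-place path, here is an illustrative call to widen(); the layer sizes and the factor are arbitrary, and the expected shapes assume the round-up behavior sketched above.

# Illustrative usage of widen(); the sizes and factor are arbitrary.
import torch.nn as nn

layer = nn.Linear(16, 32)             # in_features=16, out_features=32
wider = widen(layer, factor=1.5)      # returns a new, wider nn.Linear

print(wider.weight.shape)             # expected: torch.Size([48, 16]), since 32 * 1.5 = 48
print(wider.bias is not None)         # True: the original layer had a bias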