 from morph.layers.sparse import percent_waste
-from morph._utils import check, round
-from morph.nn.utils import in_dim, out_dim
+from morph.utils import check, round
+from .resizing import Resizing
+from .utils import in_dim, out_dim, group_layers_by_algo
+from .widen import resize_layers
+from ._types import type_name
+
+from typing import List
 
 import torch.nn as nn
 
 
-def calc_reduced_size(layer: nn.Module) -> (int, int):
-    """Calculates the reduced size of the layer, post training (initial or morphed re-training)
-    so the layers can be resized.
+class Shrinkage:
+    """
+    An intermediary for the "Shrink" step of the three-step Morphing algorithm.
+    Rather than leaving all of this state loose in the scope of one mega-function,
+    these abstractions ease the implementation of the shrinking and pruning of the
+    network.
+    * Given that we have access to the total count of nodes, and to how wasteful a
+      layer was, we can deduce any necessary changes once given a new input dimension.
+    * We expect input dimensions to change to accommodate the trimmed-down earlier
+      layers, but we want an expansion further along to allow the opening of
+      bottlenecks in the architecture.
+    """
+
+    def __init__(self, input_dimension: int, initial_parameters: int,
+                 waste_percentage: float):
+        self.input_dimension = input_dimension  # TODO: is this relevant in any non-Linear case?
+        self.initial_parameters = initial_parameters
+        self.waste_percentage = waste_percentage
+        self.reduced_parameters = Shrinkage.reduce_parameters(initial_parameters,
+                                                              waste_percentage)
+
+    @staticmethod
+    def reduce_parameters(initial_parameters: int, waste: float) -> int:
+        """Calculates the new, smaller number of parameters that this instance encapsulates"""
+        percent_keep = (1. - waste)
+        unrounded_params_to_keep = percent_keep * initial_parameters
+        # round the decimal up to the nearest integer
+        return round(unrounded_params_to_keep)
+
+
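A worked example of the arithmetic above, with made-up numbers; it assumes `morph.utils.round` rounds up, as the comment suggests:

```python
# a 64-in, 64-out Linear layer carries 64 * 64 = 4096 weights
shrinkage = Shrinkage(input_dimension=64, initial_parameters=4096,
                      waste_percentage=0.3)
# keeping 70% of 4096 weights: 2867.2 rounds up to 2868
assert shrinkage.reduced_parameters == round(0.7 * 4096)
```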
+def shrink_to_resize(shrinkage: Shrinkage, new_input_dimension: int) -> Resizing:
+    """Given the `new_input_dimension`, calculate a reshaping/resizing for the parameters
+    of the supplied `shrinkage`.
+    We round the new output dimension up, generously allowing for the opening of bottlenecks.
+    Any waste this rounding introduces is pruned on later iterations. (Needs proof/unit test)
     """
-    # TODO: remove this guard when properly we protect access to this function
-    check(
-        type(layer) == nn.Conv2d or type(layer) == nn.Linear,
-        'Invalid layer type: ' + type(layer))
+    new_output_dimension = round(shrinkage.reduced_parameters / new_input_dimension)
+    return Resizing(new_input_dimension, new_output_dimension)
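Continuing that example, suppose the upstream layer shrank from 64 to 48 outputs; the 2868 kept weights are then spread over the 48 inputs:

```python
resized = shrink_to_resize(shrinkage, new_input_dimension=48)
# 2868 / 48 = 59.75, which rounds up: roughly Resizing(48, 60)
```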
+
+
+#################### proof of a good implementation ####################
+
+
+def uniform_prune(net: nn.Module) -> nn.Module:
+    """Shrink every layer of the network down to 70% of its width. The network's
+    input and output dimensions are not altered."""
+    return resize_layers(net, width_factor=0.7)
+
+
+#################### the algorithm to end all algorithms ####################
+
+
+def shrink_layer(layer: nn.Module) -> Shrinkage:
+    waste = percent_waste(layer)
+    parameter_count = layer.weight.numel()  # the count is already tracked for us
+    return Shrinkage(in_dim(layer), parameter_count, waste)
+
+
+def fit_layer_sizes(layer_sizes: List[Shrinkage]) -> List[Resizing]:
+    # TODO: where's the invocation site for shrink_to_resize
+    pass
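A minimal sketch of the chaining this TODO hints at, under two assumptions: the first layer keeps its current input dimension, and `Resizing` exposes an `output_dimension` attribute (a hypothetical name for its second constructor argument):

```python
def fit_layer_sizes_sketch(layer_sizes: List[Shrinkage]) -> List[Resizing]:
    # the first layer keeps its input dimension; every later layer inherits
    # the (possibly smaller) output dimension of the layer before it
    fitted = [shrink_to_resize(layer_sizes[0], layer_sizes[0].input_dimension)]
    for shrinkage in layer_sizes[1:]:
        fitted.append(shrink_to_resize(shrinkage, fitted[-1].output_dimension))
    return fitted
```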
+
+
+def transform(original_layer: nn.Module, new_shape: Resizing) -> nn.Module:
+    # TODO: this might just be utils.redo_layer, without the primitive obsession
+    pass
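If `utils.redo_layer` doesn't pan out, one plausible shape for this function, assuming `Resizing` carries `input_dimension`/`output_dimension` fields (hypothetical names) and handling only the two layer types the removed guard allowed:

```python
def transform_sketch(original_layer: nn.Module, new_shape: Resizing) -> nn.Module:
    # hypothetical: rebuild the layer at its fitted size (weights re-initialized)
    if isinstance(original_layer, nn.Linear):
        return nn.Linear(new_shape.input_dimension, new_shape.output_dimension)
    if isinstance(original_layer, nn.Conv2d):
        return nn.Conv2d(new_shape.input_dimension, new_shape.output_dimension,
                         kernel_size=original_layer.kernel_size)
    raise TypeError('Invalid layer type: ' + type_name(original_layer))
```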
+
+
+def shrink_prune_fit(net: nn.Module) -> nn.Module:
+    first, middle_layers, last = group_layers_by_algo(net)
+    shrunk = {
+        "first": shrink_layer(first),
+        "middle": [shrink_layer(m) for m in middle_layers],
+        "last": shrink_layer(last)
+    }
+
+    # FIXME: why doesn't the linter like `fitted_layers`
+    fitted_layers = fit_layer_sizes([shrunk["first"], *shrunk["middle"], shrunk["last"]])
+
+    # iteration very similar to `resize_layers`, but matches each fitted Resizing
+    # with its corresponding layer
+    new_first, new_middle_layers, new_last = group_layers_by_algo(fitted_layers)
+
+    new_net = nn.Module()
+
+    new_net.add_module(type_name(first), transform(first, new_first))
+
+    for old, new in zip(middle_layers, new_middle_layers):
+        new_net.add_module(type_name(old), transform(old, new))
 
-    percent_keep = 1 - percent_waste(layer)
-    shrunk_in, shrunk_out = percent_keep * in_dim(layer), percent_keep * out_dim(layer)
+    new_net.add_module(type_name(last), transform(last, new_last))
 
-    return round(shrunk_in), round(shrunk_out)
+    return new_net
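Taken together, the intended call pattern is a train-then-shrink cycle. A hedged usage sketch (the Sequential is an arbitrary stand-in, and it assumes `group_layers_by_algo` accepts such a container):

```python
net = nn.Sequential(nn.Linear(784, 128), nn.Linear(128, 64), nn.Linear(64, 10))
# ... train `net` first, so percent_waste has sparsity statistics to measure ...
smaller_net = shrink_prune_fit(net)
```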