|
1 | 1 | import torch.nn as nn |
2 | 2 |
|
| 3 | +from morph.nn._types import type_name, type_supported |
| 4 | + |
| 5 | +from typing import List, Tuple, TypeVar |
| 6 | + |
| 7 | +ML = TypeVar('MODULES', List[nn.Module]) |
| 8 | +# Type constrained to be the results of nn.Module.children() or ...named_children() |
| 9 | +CL = TypeVar('MODULE_CHILDREN_LIST', ML, List[Tuple[str, nn.Module]]) |
| 10 | + |
| 11 | + |
| 12 | +def group_layers_by_algo(children_list: CL) -> ML: |
| 13 | + """Group the layers into how they will be acted upon by my implementation of the algorithm: |
| 14 | + 1. First child in the list (the "input" layer) |
| 15 | + 2. Slice of all the child, those that are not first nor last |
| 16 | + 3. Last child in the list (the "output" layer) |
| 17 | + """ |
| 18 | + |
| 19 | + list_len = len(children_list) |
| 20 | + |
| 21 | + # validate input in case I slip up |
| 22 | + if list_len < 1: |
| 23 | + raise ValueError('Invalid argument:', children_list) |
| 24 | + |
| 25 | + if list_len <= 2: |
| 26 | + return children_list # interface? |
| 27 | + |
| 28 | + first = children_list[0] |
| 29 | + middle = children_list[1:-1] |
| 30 | + last = children_list[-1] |
| 31 | + |
| 32 | + return first, middle, last |
| 33 | + |
3 | 34 |
|
4 | 35 | def layer_has_bias(layer: nn.Module) -> bool: |
5 | 36 | return not layer.bias is None |
@@ -66,7 +97,7 @@ def new_output_layer(base_layer: nn.Module, type_name: str, in_dim: int) -> nn.M |
66 | 97 |
|
67 | 98 | def redo_layer(layer: nn.Module, new_in=None, new_out=None) -> nn.Module: |
68 | 99 | if new_in is None and new_out is None: |
69 | | - return layehr |
| 100 | + return layer |
70 | 101 |
|
71 | 102 | _type = type_name(layer) |
72 | 103 | if not type_supported(_type): |
@@ -96,14 +127,3 @@ def layer_is_conv2d(name: str): |
96 | 127 |
|
97 | 128 | def layer_is_linear(name: str): |
98 | 129 | return name == 'Linear' |
99 | | - |
100 | | - |
101 | | -def type_name(o): |
102 | | - '''Returns the simplified type name of the given object. |
103 | | - Eases type checking, rather than any(isinstance(some_obj, _type) for _type in [my, types, to, check]) |
104 | | - ''' |
105 | | - return type(o).__name__ |
106 | | - |
107 | | - |
108 | | -def type_supported(type_name: str) -> bool: |
109 | | - return type_name in ['Conv2d', 'Linear'] |
|
0 commit comments