Skip to content

Commit db1f944

Browse files
committed
Reorganize utilities
The code that was in shrink was utility code that I found uses for in at least two other places + Satisfies my rule of three occurences + Led to successful porting of notebook code
1 parent 744bb5a commit db1f944

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

morph/nn/_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def type_name(o):
2+
'''Returns the simplified type name of the given object.
3+
Eases type checking, rather than any(isinstance(some_obj, _type) for _type in [my, types, to, check])
4+
'''
5+
return type(o).__name__
6+
7+
8+
def type_supported(type_name: str) -> bool:
9+
return type_name in ['Conv2d', 'Linear']

morph/nn/utils.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
11
import torch.nn as nn
22

3+
from morph.nn._types import type_name, type_supported
4+
5+
from typing import List, Tuple, TypeVar
6+
7+
ML = TypeVar('MODULES', List[nn.Module])
8+
# Type constrained to be the results of nn.Module.children() or ...named_children()
9+
CL = TypeVar('MODULE_CHILDREN_LIST', ML, List[Tuple[str, nn.Module]])
10+
11+
12+
def group_layers_by_algo(children_list: CL) -> ML:
13+
"""Group the layers into how they will be acted upon by my implementation of the algorithm:
14+
1. First child in the list (the "input" layer)
15+
2. Slice of all the child, those that are not first nor last
16+
3. Last child in the list (the "output" layer)
17+
"""
18+
19+
list_len = len(children_list)
20+
21+
# validate input in case I slip up
22+
if list_len < 1:
23+
raise ValueError('Invalid argument:', children_list)
24+
25+
if list_len <= 2:
26+
return children_list # interface?
27+
28+
first = children_list[0]
29+
middle = children_list[1:-1]
30+
last = children_list[-1]
31+
32+
return first, middle, last
33+
334

435
def layer_has_bias(layer: nn.Module) -> bool:
536
return not layer.bias is None
@@ -66,7 +97,7 @@ def new_output_layer(base_layer: nn.Module, type_name: str, in_dim: int) -> nn.M
6697

6798
def redo_layer(layer: nn.Module, new_in=None, new_out=None) -> nn.Module:
6899
if new_in is None and new_out is None:
69-
return layehr
100+
return layer
70101

71102
_type = type_name(layer)
72103
if not type_supported(_type):
@@ -96,14 +127,3 @@ def layer_is_conv2d(name: str):
96127

97128
def layer_is_linear(name: str):
98129
return name == 'Linear'
99-
100-
101-
def type_name(o):
102-
'''Returns the simplified type name of the given object.
103-
Eases type checking, rather than any(isinstance(some_obj, _type) for _type in [my, types, to, check])
104-
'''
105-
return type(o).__name__
106-
107-
108-
def type_supported(type_name: str) -> bool:
109-
return type_name in ['Conv2d', 'Linear']

0 commit comments

Comments
 (0)