Skip to content

Commit ae4b4cf

Browse files
committed
instantiate from module string
1 parent de5ec7d commit ae4b4cf

File tree

2 files changed

+66
-20
lines changed

2 files changed

+66
-20
lines changed

src/model_constructor/helpers.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import importlib
12
from collections import OrderedDict
23
from functools import partial
3-
from typing import Iterable, Optional, Union
4-
from pydantic import BaseModel
4+
from typing import Any, Iterable, Optional, Union
55

6+
from pydantic import BaseModel
67
from torch import nn
78

8-
99
ListStrMod = list[tuple[str, nn.Module]]
1010
ModSeq = Union[nn.Module, nn.Sequential]
1111

@@ -25,6 +25,49 @@ def init_cnn(module: nn.Module) -> None:
2525
init_cnn(layer)
2626

2727

28+
def is_module(val: Any) -> bool:
29+
"""Check if val is a nn.Module or partial of nn.Module."""
30+
31+
to_check = val
32+
if isinstance(val, partial):
33+
to_check = val.func
34+
try:
35+
return issubclass(to_check, nn.Module)
36+
except TypeError:
37+
return False
38+
39+
40+
def instantiate_module(
41+
name: str,
42+
default_path: Optional[str] = None,
43+
) -> nn.Module:
44+
"""Instantiate model from name."""
45+
if default_path is None:
46+
path_list = name.rsplit(".", 1)
47+
if len(path_list) == 1:
48+
default_path = "torch.nn"
49+
name = path_list[0]
50+
else:
51+
if path_list[0] == "nn":
52+
default_path = "torch.nn"
53+
name = path_list[1]
54+
else:
55+
default_path = path_list[0]
56+
name = path_list[1]
57+
try:
58+
mod = importlib.import_module(default_path)
59+
except ImportError:
60+
raise ImportError(f"Module {default_path} not found")
61+
if hasattr(mod, name):
62+
module = getattr(mod, name)
63+
if is_module(module):
64+
return module
65+
else:
66+
raise ImportError(f"Module {name} is not a nn.Module")
67+
else:
68+
raise ImportError(f"Module {name} not found at {default_path}")
69+
70+
2871
class Cfg(BaseModel):
2972
"""Base class for config."""
3073

src/model_constructor/model_constructor.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from torch import nn
88

99
from .blocks import BasicBlock, BottleneckBlock
10-
from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq
10+
from .helpers import (Cfg, ListStrMod, ModSeq, init_cnn, instantiate_module,
11+
is_module, nn_seq)
1112
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
1213

1314
__all__ = [
@@ -24,16 +25,7 @@
2425
}
2526

2627

27-
def is_module(val: Any) -> bool:
28-
"""Check if val is a nn.Module or partial of nn.Module."""
29-
30-
to_check = val
31-
if isinstance(val, partial):
32-
to_check = val.func
33-
try:
34-
return issubclass(to_check, nn.Module)
35-
except TypeError:
36-
return False
28+
nnModule = Union[type[nn.Module], Callable[[], nn.Module], str]
3729

3830

3931
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
@@ -42,13 +34,13 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4234
name: Optional[str] = None
4335
in_chans: int = 3
4436
num_classes: int = 1000
45-
block: type[nn.Module] = BasicBlock
46-
conv_layer: type[nn.Module] = ConvBnAct
37+
block: nnModule = BasicBlock
38+
conv_layer: nnModule = ConvBnAct
4739
block_sizes: list[int] = [64, 128, 256, 512]
4840
layers: list[int] = [2, 2, 2, 2]
49-
norm: type[nn.Module] = nn.BatchNorm2d
50-
act_fn: type[nn.Module] = nn.ReLU
51-
pool: Optional[Callable[[Any], nn.Module]] = None
41+
norm: nnModule = nn.BatchNorm2d
42+
act_fn: nnModule = nn.ReLU
43+
pool: Optional[nnModule] = None
5244
expansion: int = 1
5345
groups: int = 1
5446
dw: bool = False
@@ -61,11 +53,22 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
6153
zero_bn: bool = True
6254
stem_stride_on: int = 0
6355
stem_sizes: list[int] = [64]
64-
stem_pool: Optional[Callable[[], nn.Module]] = partial(
56+
stem_pool: Optional[nnModule] = partial(
6557
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
6658
)
6759
stem_bn_end: bool = False
6860

61+
@field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool")
62+
def set_modules( # pylint: disable=no-self-argument
63+
cls, value: Union[type[nn.Module], str], info: FieldValidationInfo,
64+
) -> Union[type[nn.Module], Callable[[], nn.Module]]:
65+
"""Check values, if string, convert to nn.Module."""
66+
if is_module(value):
67+
return value
68+
if isinstance(value, str):
69+
return instantiate_module(value)
70+
raise ValueError(f"{info.field_name} must be str or nn.Module")
71+
6972
@field_validator("se", "sa")
7073
def set_se( # pylint: disable=no-self-argument
7174
cls, value: Union[bool, type[nn.Module]], info: FieldValidationInfo,

0 commit comments

Comments
 (0)