Skip to content

Commit b908ef3

Browse files
committed
remove typevar, fix __init__
1 parent 4efee64 commit b908ef3

File tree

2 files changed

+12
-19
lines changed

2 files changed

+12
-19
lines changed

src/model_constructor/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
from model_constructor.convmixer import ConvMixer # noqa F401
2-
from model_constructor.model_constructor import (
3-
ModelConstructor,
4-
ModelCfg,
5-
) # noqa F401
6-
7-
from model_constructor.version import __version__ # noqa F401
1+
from .convmixer import ConvMixer # noqa F401
2+
from .model_constructor import ModelConstructor, ModelCfg # noqa F401
3+
from .version import __version__ # noqa F401

src/model_constructor/model_constructor.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, Optional, TypeVar, Union
3+
from typing import Any, Callable, Optional, Union
44

55
from pydantic import field_validator
66
from torch import nn
@@ -17,9 +17,6 @@
1717
]
1818

1919

20-
TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
21-
22-
2320
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
2421
"""Model constructor Config. As default - xresnet18"""
2522

@@ -64,7 +61,7 @@ def __repr__(self) -> str:
6461
)
6562

6663

67-
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
64+
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
6865
"""Create Resnet stem."""
6966
stem: ListStrMod = [
7067
(
@@ -88,7 +85,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
8885
return nn_seq(stem)
8986

9087

91-
def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
88+
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
9289
"""Create layer (stage)"""
9390
# if no pool on stem - stride = 2 for first layer block in body
9491
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
@@ -119,15 +116,15 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
119116
)
120117

121118

122-
def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
119+
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
123120
"""Create model body."""
124121
return nn_seq(
125122
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
126123
for layer_num in range(len(cfg.layers))
127124
)
128125

129126

130-
def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
127+
def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
131128
"""Create head."""
132129
head = [
133130
("pool", nn.AdaptiveAvgPool2d(1)),
@@ -141,10 +138,10 @@ class ModelConstructor(ModelCfg):
141138
"""Model constructor. As default - resnet18"""
142139

143140
init_cnn: Callable[[nn.Module], None] = init_cnn
144-
make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
145-
make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
146-
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
147-
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
141+
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
142+
make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
143+
make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
144+
make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
148145

149146
@field_validator("se")
150147
def set_se( # pylint: disable=no-self-argument

0 commit comments

Comments
 (0)