Skip to content

Commit 2b3a909

Browse files
committed
type ModSeq
1 parent c42024e commit 2b3a909

File tree

6 files changed

+19
-19
lines changed

6 files changed

+19
-19
lines changed

noxfile_conda.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
@nox.session(python=["3.9", "3.10", "3.11"], venv_backend="mamba")
55
def conda_tests(session: nox.Session) -> None:
66
args = session.posargs or ["--cov"]
7-
# session.install("pytest", "pytest-cov")
87
session.conda_install("pytest", "pytest-cov")
98
session.conda_install("pytorch")
109
session.conda_install("pydantic")

src/model_constructor/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Iterable, Optional
3+
from typing import Iterable, Optional, Union
44
from pydantic import BaseModel
55

66
from torch import nn
77

88

99
ListStrMod = list[tuple[str, nn.Module]]
10+
ModSeq = Union[nn.Module, nn.Sequential]
1011

1112

1213
def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:

src/model_constructor/model_constructor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn
77

88
from .blocks import BasicBlock, BottleneckBlock
9-
from .helpers import Cfg, ListStrMod, init_cnn, nn_seq
9+
from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq
1010
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
1111

1212
__all__ = [
@@ -138,10 +138,10 @@ class ModelConstructor(ModelCfg):
138138
"""Model constructor. As default - resnet18"""
139139

140140
init_cnn: Callable[[nn.Module], None] = init_cnn
141-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
142-
make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
143-
make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
144-
make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
141+
make_stem: Callable[[ModelCfg], ModSeq] = make_stem # type: ignore
142+
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer # type: ignore
143+
make_body: Callable[[ModelCfg], ModSeq] = make_body # type: ignore
144+
make_head: Callable[[ModelCfg], ModSeq] = make_head # type: ignore
145145

146146
@field_validator("se")
147147
def set_se( # pylint: disable=no-self-argument

src/model_constructor/universal_blocks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Callable, Optional, Union
1+
from typing import Callable, Optional
22

33
import torch
44
from torch import nn
55

6-
from .helpers import nn_seq
6+
from .helpers import ModSeq, nn_seq
77
from .layers import ConvBnAct, get_act
88
from .model_constructor import ListStrMod, ModelCfg, ModelConstructor
99

@@ -337,10 +337,10 @@ def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
337337
class XResNet(ModelConstructor):
338338
"""Base Xresnet constructor."""
339339

340-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem
341-
make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer
342-
make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body
343-
make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head
340+
make_stem: Callable[[ModelCfg], ModSeq] = make_stem
341+
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer
342+
make_body: Callable[[ModelCfg], ModSeq] = make_body
343+
make_head: Callable[[ModelCfg], ModSeq] = make_head
344344
block: type[nn.Module] = XResBlock
345345

346346

src/model_constructor/xresnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from functools import partial
2-
from typing import Any, Callable, Optional, Union
2+
from typing import Any, Callable, Optional
33

44
from torch import nn
55

66
from .blocks import BottleneckBlock
7-
from .helpers import ListStrMod, nn_seq
7+
from .helpers import ListStrMod, nn_seq, ModSeq
88
from .model_constructor import ModelCfg, ModelConstructor
99

1010
__all__ = [
@@ -39,7 +39,7 @@ def xresnet_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
3939

4040

4141
class XResNet(ModelConstructor):
42-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = xresnet_stem
42+
make_stem: Callable[[ModelCfg], ModSeq] = xresnet_stem
4343
stem_sizes: list[int] = [32, 32, 64]
4444
pool: Optional[Callable[[Any], nn.Module]] = partial(
4545
nn.AvgPool2d, kernel_size=2, ceil_mode=True

src/model_constructor/yaresnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
# Yet another ResNet.
33

44
from functools import partial
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, Callable, Optional
66

77
import torch
88
from torch import nn
99

10-
from model_constructor.helpers import nn_seq
10+
from model_constructor.helpers import ModSeq, nn_seq
1111

1212
from .layers import ConvBnAct, get_act
1313
from .model_constructor import ListStrMod, ModelConstructor, ModelCfg
@@ -203,7 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
203203

204204

205205
class YaResNet(ModelConstructor):
206-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = xresnet_stem
206+
make_stem: Callable[[ModelCfg], ModSeq] = xresnet_stem
207207
stem_sizes: list[int] = [32, 64, 64]
208208
block: type[nn.Module] = YaBasicBlock
209209
act_fn: type[nn.Module] = nn.Mish

0 commit comments

Comments
 (0)