|
6 | 6 | from torch import nn |
7 | 7 |
|
8 | 8 | from .blocks import BasicBlock, BottleneckBlock |
9 | | -from .helpers import Cfg, ListStrMod, init_cnn, nn_seq |
| 9 | +from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq |
10 | 10 | from .layers import ConvBnAct, SEModule, SimpleSelfAttention |
11 | 11 |
|
12 | 12 | __all__ = [ |
@@ -138,10 +138,10 @@ class ModelConstructor(ModelCfg): |
138 | 138 | """Model constructor. As default - resnet18""" |
139 | 139 |
|
140 | 140 | init_cnn: Callable[[nn.Module], None] = init_cnn |
141 | | - make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore |
142 | | - make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore |
143 | | - make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore |
144 | | - make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore |
| 141 | + make_stem: Callable[[ModelCfg], ModSeq] = make_stem # type: ignore |
| 142 | + make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer # type: ignore |
| 143 | + make_body: Callable[[ModelCfg], ModSeq] = make_body # type: ignore |
| 144 | + make_head: Callable[[ModelCfg], ModSeq] = make_head # type: ignore |
145 | 145 |
|
146 | 146 | @field_validator("se") |
147 | 147 | def set_se( # pylint: disable=no-self-argument |
|
0 commit comments