Skip to content

Commit e7195aa

Browse files
committed
typing mc, yaresnet
1 parent 92de2d9 commit e7195aa

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

src/model_constructor/model_constructor.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
3+
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
@@ -39,7 +39,7 @@ def __init__(
3939
mid_channels: int,
4040
stride: int = 1,
4141
conv_layer=ConvBnAct,
42-
act_fn: Type[nn.Module] = nn.ReLU,
42+
act_fn: type[nn.Module] = nn.ReLU,
4343
zero_bn: bool = True,
4444
bn_1st: bool = True,
4545
groups: int = 1,
@@ -153,7 +153,7 @@ def forward(self, x):
153153

154154
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
155155
len_stem = len(cfg.stem_sizes)
156-
stem: List[tuple[str, nn.Module]] = [
156+
stem: list[tuple[str, nn.Module]] = [
157157
(
158158
f"conv_{i}",
159159
cfg.conv_layer(
@@ -238,27 +238,27 @@ class ModelCfg(BaseModel):
238238
name: Optional[str] = None
239239
in_chans: int = 3
240240
num_classes: int = 1000
241-
block: Type[nn.Module] = ResBlock
242-
conv_layer: Type[nn.Module] = ConvBnAct
243-
block_sizes: List[int] = [64, 128, 256, 512]
244-
layers: List[int] = [2, 2, 2, 2]
245-
norm: Type[nn.Module] = nn.BatchNorm2d
246-
act_fn: Type[nn.Module] = nn.ReLU
241+
block: type[nn.Module] = ResBlock
242+
conv_layer: type[nn.Module] = ConvBnAct
243+
block_sizes: list[int] = [64, 128, 256, 512]
244+
layers: list[int] = [2, 2, 2, 2]
245+
norm: type[nn.Module] = nn.BatchNorm2d
246+
act_fn: type[nn.Module] = nn.ReLU
247247
pool: Callable[[Any], nn.Module] = partial(
248248
nn.AvgPool2d, kernel_size=2, ceil_mode=True
249249
)
250250
expansion: int = 1
251251
groups: int = 1
252252
dw: bool = False
253253
div_groups: Union[int, None] = None
254-
sa: Union[bool, int, Type[nn.Module]] = False
255-
se: Union[bool, int, Type[nn.Module]] = False
254+
sa: Union[bool, int, type[nn.Module]] = False
255+
se: Union[bool, int, type[nn.Module]] = False
256256
se_module: Union[bool, None] = None
257257
se_reduction: Union[int, None] = None
258258
bn_1st: bool = True
259259
zero_bn: bool = True
260260
stem_stride_on: int = 0
261-
stem_sizes: List[int] = [32, 32, 64]
261+
stem_sizes: list[int] = [32, 32, 64]
262262
stem_pool: Union[Callable[[], nn.Module], None] = partial(
263263
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
264264
)
@@ -286,7 +286,7 @@ def _get_str_value(self, field: str) -> str:
286286
def __repr__(self) -> str:
287287
return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
288288

289-
def __repr_args__(self):
289+
def __repr_args__(self) -> list[tuple[str, str]]:
290290
return [
291291
(field, str_value)
292292
for field in self.__fields__
@@ -325,7 +325,7 @@ def body(self):
325325
def from_cfg(cls, cfg: ModelCfg):
326326
return cls(**cfg.dict())
327327

328-
def __call__(self):
328+
def __call__(self) -> nn.Sequential:
329329
model_name = self.name or self.__class__.__name__
330330
named_sequential = type(model_name, (nn.Sequential,), {})
331331
model = named_sequential(
@@ -338,13 +338,14 @@ def __call__(self):
338338
return model
339339

340340
def _get_extra_repr(self) -> str:
341+
"""Repr for changed fields"""
341342
return " ".join(
342343
f"{field}: {self._get_str_value(field)},"
343344
for field in self.__fields_set__
344345
if field != "name"
345-
)[:-1]
346+
)[:-1] # strip last comma.
346347

347-
def __repr__(self):
348+
def __repr__(self) -> str:
348349
se_repr = self.se.__name__ if self.se else "False" # type: ignore
349350
model_name = self.name or self.__class__.__name__
350351
return (

src/model_constructor/yaresnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Yet another ResNet.
33

44
from collections import OrderedDict
5-
from typing import Callable, List, Type, Union
5+
from typing import Callable, Union
66

77
import torch.nn as nn
88
from torch.nn import Mish
@@ -27,15 +27,15 @@ def __init__(
2727
mid_channels: int,
2828
stride: int = 1,
2929
conv_layer=ConvBnAct,
30-
act_fn: Type[nn.Module] = nn.ReLU,
30+
act_fn: type[nn.Module] = nn.ReLU,
3131
zero_bn: bool = True,
3232
bn_1st: bool = True,
3333
groups: int = 1,
3434
dw: bool = False,
3535
div_groups: Union[None, int] = None,
3636
pool: Union[Callable[[], nn.Module], None] = None,
37-
se: Union[Type[nn.Module], None] = None,
38-
sa: Union[Type[nn.Module], None] = None,
37+
se: Union[type[nn.Module], None] = None,
38+
sa: Union[type[nn.Module], None] = None,
3939
):
4040
super().__init__()
4141
# pool defined at ModelConstructor.
@@ -115,9 +115,9 @@ def __init__(
115115
), # noqa E501
116116
]
117117
if se:
118-
layers.append(("se", se(out_channels)))
118+
layers.append(("se", se(out_channels))) # type: ignore
119119
if sa:
120-
layers.append(("sa", sa(out_channels)))
120+
layers.append(("sa", sa(out_channels))) # type: ignore
121121
self.convs = nn.Sequential(OrderedDict(layers))
122122
if in_channels != out_channels:
123123
self.id_conv = conv_layer(
@@ -143,7 +143,7 @@ class YaResNet34(ModelConstructor):
143143
expansion: int = 1
144144
layers: list[int] = [3, 4, 6, 3]
145145
stem_sizes: list[int] = [3, 32, 64, 64]
146-
act_fn: Type[nn.Module] = Mish
146+
act_fn: type[nn.Module] = Mish
147147

148148

149149
class YaResNet50(YaResNet34):

0 commit comments

Comments
 (0)