Skip to content

Commit cc4b865

Browse files
committed
repr cfg
1 parent 07297f3 commit cc4b865

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

src/model_constructor/model_constructor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ class CfgMC(BaseModel):
155155
class Config:
156156
arbitrary_types_allowed = True
157157

158+
def extra_repr(self) -> str:
159+
res = ""
160+
for k, v in self.dict().items():
161+
if v is not None:
162+
res += f"{k}: {v}\n"
163+
return res
164+
165+
def pprint(self) -> None:
166+
print(self.extra_repr())
167+
158168

159169
def init_cnn(module: nn.Module):
160170
"Init module - kaiming_normal for Conv2d and 0 for biases."

src/model_constructor/yaresnet.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Yet another ResNet.
33

44
from collections import OrderedDict
5-
from typing import Union
5+
from typing import List, Type, Union
66

77
import torch.nn as nn
88
from torch.nn import Mish
@@ -12,8 +12,6 @@
1212

1313
__all__ = [
1414
'YaResBlock',
15-
'yaresnet34',
16-
'yaresnet50',
1715
]
1816

1917

@@ -126,21 +124,14 @@ def forward(self, x):
126124
return self.merge(self.convs(x) + identity)
127125

128126

129-
yaresnet34 = ModelConstructor.from_cfg(
130-
CfgMC(
131-
name='YaResnet34',
132-
block=YaResBlock,
133-
expansion=1,
134-
layers=[3, 4, 6, 3],
135-
act_fn=Mish(),
136-
)
137-
)
138-
yaresnet50 = ModelConstructor.from_cfg(
139-
CfgMC(
140-
name='YaResnet50',
141-
block=YaResBlock,
142-
act_fn=Mish(),
143-
expansion=4,
144-
layers=[3, 4, 6, 3],
145-
)
146-
)
127+
class YaResNet34(ModelConstructor):
128+
name: str = 'YaResnet34'
129+
block: Type[nn.Module] = YaResBlock
130+
expansion: int = 1
131+
layers: List[int] = [3, 4, 6, 3]
132+
act_fn: nn.Module = Mish()
133+
134+
135+
class YaResNet50(YaResNet34):
136+
name: str = 'YaResnet50'
137+
expansion: int = 4

0 commit comments

Comments
 (0)