|
3 | 3 | from typing import Any, Callable, Optional, TypeVar, Union |
4 | 4 |
|
5 | 5 | import torch |
6 | | -from pydantic import BaseModel, field_validator |
| 6 | +from pydantic import field_validator |
7 | 7 | from torch import nn |
8 | 8 |
|
9 | | -from .helpers import nn_seq |
| 9 | +from .helpers import nn_seq, Cfg, init_cnn |
10 | 10 | from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act |
11 | 11 |
|
12 | 12 | __all__ = [ |
|
23 | 23 | ListStrMod = list[tuple[str, nn.Module]] |
24 | 24 |
|
25 | 25 |
|
26 | | -def init_cnn(module: nn.Module) -> None: |
27 | | - "Init module - kaiming_normal for Conv2d and 0 for biases." |
28 | | - if getattr(module, "bias", None) is not None: |
29 | | - nn.init.constant_(module.bias, 0) # type: ignore |
30 | | - if isinstance(module, (nn.Conv2d, nn.Linear)): |
31 | | - nn.init.kaiming_normal_(module.weight) |
32 | | - for layer in module.children(): |
33 | | - init_cnn(layer) |
34 | | - |
35 | | - |
36 | 26 | class BasicBlock(nn.Module): |
37 | 27 | """Basic Resnet block. |
38 | 28 | Configurable - can use pool to reduce at identity path, change act etc.""" |
@@ -212,6 +202,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore |
212 | 202 | return self.act_fn(self.convs(x) + identity) |
213 | 203 |
|
214 | 204 |
|
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
    """Model constructor Config. As default - xresnet18.

    Plain container of the settings used to build a model. Unknown field
    names are rejected (``extra="forbid"``) and non-pydantic types such as
    ``nn.Module`` subclasses are accepted as field values
    (``arbitrary_types_allowed=True``).

    NOTE(review): the ``ModelConstructor`` docstring says the default is
    resnet18 while this one says xresnet18 - confirm which is intended.
    """

    # Optional display name; `__repr__` falls back to the class name when unset.
    name: Optional[str] = None
    in_chans: int = 3
    num_classes: int = 1000
    block: type[nn.Module] = BasicBlock
    conv_layer: type[nn.Module] = ConvBnAct
    block_sizes: list[int] = [64, 128, 256, 512]
    layers: list[int] = [2, 2, 2, 2]
    norm: type[nn.Module] = nn.BatchNorm2d
    act_fn: type[nn.Module] = nn.ReLU
    pool: Optional[Callable[[Any], nn.Module]] = None
    expansion: int = 1
    groups: int = 1
    dw: bool = False
    div_groups: Union[int, None] = None
    # `sa` / `se` accept either a flag or a module class.
    sa: Union[bool, type[nn.Module]] = False
    se: Union[bool, type[nn.Module]] = False
    se_module: Union[bool, None] = None
    se_reduction: Union[int, None] = None
    bn_1st: bool = True
    zero_bn: bool = True
    stem_stride_on: int = 0
    stem_sizes: list[int] = [64]
    # Zero-argument factory for the stem pooling layer; None is also allowed.
    stem_pool: Union[Callable[[], nn.Module], None] = partial(
        nn.MaxPool2d, kernel_size=3, stride=2, padding=1
    )
    stem_bn_end: bool = False

    def __repr__(self) -> str:
        """Return a short multi-line summary of the main config fields."""
        # `se` may still be a bare bool on this class: `True` has no
        # `__name__`, so fall back to its string form instead of raising.
        se_repr = getattr(self.se, "__name__", str(self.se)) if self.se else "False"
        model_name = self.name or self.__class__.__name__
        return (
            f"{model_name}\n"
            f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
            f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
            f"  act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
            f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
            f"  body sizes {self.block_sizes}\n"
            f"  layers: {self.layers}"
        )
| 247 | + |
| 248 | + |
215 | 249 | def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore |
216 | 250 | """Create Resnet stem.""" |
217 | 251 | stem: ListStrMod = [ |
@@ -285,87 +319,15 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore |
285 | 319 | return nn_seq(head) |
286 | 320 |
|
287 | 321 |
|
288 | | -class ModelCfg(BaseModel, arbitrary_types_allowed=True, extra="forbid"): |
289 | | - """Model constructor Config. As default - xresnet18""" |
| 322 | +class ModelConstructor(ModelCfg): |
| 323 | + """Model constructor. As default - resnet18""" |
290 | 324 |
|
291 | | - name: Optional[str] = None |
292 | | - in_chans: int = 3 |
293 | | - num_classes: int = 1000 |
294 | | - block: type[nn.Module] = BasicBlock |
295 | | - conv_layer: type[nn.Module] = ConvBnAct |
296 | | - block_sizes: list[int] = [64, 128, 256, 512] |
297 | | - layers: list[int] = [2, 2, 2, 2] |
298 | | - norm: type[nn.Module] = nn.BatchNorm2d |
299 | | - act_fn: type[nn.Module] = nn.ReLU |
300 | | - pool: Optional[Callable[[Any], nn.Module]] = None |
301 | | - expansion: int = 1 |
302 | | - groups: int = 1 |
303 | | - dw: bool = False |
304 | | - div_groups: Union[int, None] = None |
305 | | - sa: Union[bool, type[nn.Module]] = False |
306 | | - se: Union[bool, type[nn.Module]] = False |
307 | | - se_module: Union[bool, None] = None |
308 | | - se_reduction: Union[int, None] = None |
309 | | - bn_1st: bool = True |
310 | | - zero_bn: bool = True |
311 | | - stem_stride_on: int = 0 |
312 | | - stem_sizes: list[int] = [64] |
313 | | - stem_pool: Union[Callable[[], nn.Module], None] = partial( |
314 | | - nn.MaxPool2d, kernel_size=3, stride=2, padding=1 |
315 | | - ) |
316 | | - stem_bn_end: bool = False |
317 | 325 | init_cnn: Callable[[nn.Module], None] = init_cnn |
318 | 326 | make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore |
319 | 327 | make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore |
320 | 328 | make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore |
321 | 329 | make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore |
322 | 330 |
|
323 | | - def _get_str_value(self, field: str) -> str: |
324 | | - value = getattr(self, field) |
325 | | - if isinstance(value, type): |
326 | | - value = value.__name__ |
327 | | - elif isinstance(value, partial): |
328 | | - value = f"{value.func.__name__} {value.keywords}" |
329 | | - elif callable(value): |
330 | | - value = value.__name__ |
331 | | - return value |
332 | | - |
333 | | - def __repr__(self) -> str: |
334 | | - return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})" |
335 | | - |
336 | | - def __repr_args__(self) -> list[tuple[str, str]]: |
337 | | - return [ |
338 | | - (field, str_value) |
339 | | - for field in self.model_fields |
340 | | - if (str_value := self._get_str_value(field)) |
341 | | - ] |
342 | | - |
343 | | - def __repr_changed_args__(self) -> list[str]: |
344 | | - """Return list repr for changed fields""" |
345 | | - return [ |
346 | | - f"{field}: {self._get_str_value(field)}" |
347 | | - for field in self.model_fields_set |
348 | | - if field != "name" |
349 | | - ] |
350 | | - |
351 | | - def print_cfg(self) -> None: |
352 | | - """Print full config""" |
353 | | - print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})") |
354 | | - |
355 | | - def print_changed(self) -> None: |
356 | | - """Print changed fields.""" |
357 | | - changed_fields = self.__repr_changed_args__() |
358 | | - if changed_fields: |
359 | | - print("Changed fields:") |
360 | | - for i in changed_fields: |
361 | | - print(" ", i) |
362 | | - else: |
363 | | - print("Nothing changed") |
364 | | - |
365 | | - |
366 | | -class ModelConstructor(ModelCfg): |
367 | | - """Model constructor. As default - resnet18""" |
368 | | - |
369 | 331 | @field_validator("se") |
370 | 332 | def set_se( # pylint: disable=no-self-argument |
371 | 333 | cls, value: Union[bool, type[nn.Module]] |
@@ -430,19 +392,6 @@ def __call__(self) -> nn.Sequential: |
430 | 392 | model.extra_repr = lambda: ", ".join(extra_repr) |
431 | 393 | return model |
432 | 394 |
|
433 | | - def __repr__(self) -> str: |
434 | | - se_repr = self.se.__name__ if self.se else "False" # type: ignore |
435 | | - model_name = self.name or self.__class__.__name__ |
436 | | - return ( |
437 | | - f"{model_name}\n" |
438 | | - f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n" |
439 | | - f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n" |
440 | | - f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n" |
441 | | - f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n" |
442 | | - f" body sizes {self.block_sizes}\n" |
443 | | - f" layers: {self.layers}" |
444 | | - ) |
445 | | - |
446 | 395 |
|
447 | 396 | class ResNet34(ModelConstructor): |
448 | 397 | layers: list[int] = [3, 4, 6, 3] |
|
0 commit comments