Skip to content

Commit c42024e

Browse files
committed
typing
1 parent 25fa5e9 commit c42024e

File tree

4 files changed

+30
-30
lines changed

4 files changed

+30
-30
lines changed

src/model_constructor/blocks.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Union
1+
from typing import Callable, Optional
22

33
import torch
44
from torch import nn
@@ -22,10 +22,10 @@ def __init__(
2222
bn_1st: bool = True,
2323
groups: int = 1,
2424
dw: bool = False,
25-
div_groups: Union[None, int] = None,
26-
pool: Union[Callable[[], nn.Module], None] = None,
27-
se: Union[nn.Module, None] = None,
28-
sa: Union[nn.Module, None] = None,
25+
div_groups: Optional[int] = None,
26+
pool: Optional[Callable[[], nn.Module]] = None,
27+
se: Optional[nn.Module] = None,
28+
sa: Optional[nn.Module] = None,
2929
):
3030
super().__init__()
3131
# pool defined at ModelConstructor.
@@ -107,10 +107,10 @@ def __init__(
107107
bn_1st: bool = True,
108108
groups: int = 1,
109109
dw: bool = False,
110-
div_groups: Union[None, int] = None,
111-
pool: Union[Callable[[], nn.Module], None] = None,
112-
se: Union[nn.Module, None] = None,
113-
sa: Union[nn.Module, None] = None,
110+
div_groups: Optional[int] = None,
111+
pool: Optional[Callable[[], nn.Module]] = None,
112+
se: Optional[nn.Module] = None,
113+
sa: Optional[nn.Module] = None,
114114
):
115115
super().__init__()
116116
# pool defined at ModelConstructor.

src/model_constructor/model_constructor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3333
expansion: int = 1
3434
groups: int = 1
3535
dw: bool = False
36-
div_groups: Union[int, None] = None
36+
div_groups: Optional[int] = None
3737
sa: Union[bool, type[nn.Module]] = False
3838
se: Union[bool, type[nn.Module]] = False
39-
se_module: Union[bool, None] = None
40-
se_reduction: Union[int, None] = None
39+
se_module: Optional[bool] = None
40+
se_reduction: Optional[int] = None
4141
bn_1st: bool = True
4242
zero_bn: bool = True
4343
stem_stride_on: int = 0
@@ -186,7 +186,7 @@ def from_cfg(cls, cfg: ModelCfg):
186186

187187
@classmethod
188188
def create_model(
189-
cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]
189+
cls, cfg: Optional[ModelCfg] = None, **kwargs: dict[str, Any]
190190
) -> nn.Sequential:
191191
if cfg:
192192
return cls(**cfg.model_dump())()

src/model_constructor/universal_blocks.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Union
1+
from typing import Callable, Optional, Union
22

33
import torch
44
from torch import nn
@@ -32,10 +32,10 @@ def __init__(
3232
bn_1st: bool = True,
3333
groups: int = 1,
3434
dw: bool = False,
35-
div_groups: Union[None, int] = None,
36-
pool: Union[Callable[[], nn.Module], None] = None,
37-
se: Union[nn.Module, None] = None,
38-
sa: Union[nn.Module, None] = None,
35+
div_groups: Optional[int] = None,
36+
pool: Optional[Callable[[], nn.Module]] = None,
37+
se: Optional[nn.Module] = None,
38+
sa: Optional[nn.Module] = None,
3939
):
4040
super().__init__()
4141
# pool defined at ModelConstructor.
@@ -156,10 +156,10 @@ def __init__(
156156
bn_1st: bool = True,
157157
groups: int = 1,
158158
dw: bool = False,
159-
div_groups: Union[None, int] = None,
160-
pool: Union[Callable[[], nn.Module], None] = None,
161-
se: Union[type[nn.Module], None] = None,
162-
sa: Union[type[nn.Module], None] = None,
159+
div_groups: Optional[int] = None,
160+
pool: Optional[Callable[[], nn.Module]] = None,
161+
se: Optional[type[nn.Module]] = None,
162+
sa: Optional[type[nn.Module]] = None,
163163
):
164164
super().__init__()
165165
# pool defined at ModelConstructor.

src/model_constructor/yaresnet.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def __init__(
3838
bn_1st: bool = True,
3939
groups: int = 1,
4040
dw: bool = False,
41-
div_groups: Union[None, int] = None,
42-
pool: Union[Callable[[], nn.Module], None] = None,
43-
se: Union[nn.Module, None] = None,
44-
sa: Union[nn.Module, None] = None,
41+
div_groups: Optional[int] = None,
42+
pool: Optional[Callable[[], nn.Module]] = None,
43+
se: Optional[nn.Module] = None,
44+
sa: Optional[nn.Module] = None,
4545
):
4646
super().__init__()
4747
# pool defined at ModelConstructor.
@@ -123,10 +123,10 @@ def __init__(
123123
bn_1st: bool = True,
124124
groups: int = 1,
125125
dw: bool = False,
126-
div_groups: Union[None, int] = None,
127-
pool: Union[Callable[[], nn.Module], None] = None,
128-
se: Union[nn.Module, None] = None,
129-
sa: Union[nn.Module, None] = None,
126+
div_groups: Optional[int] = None,
127+
pool: Optional[Callable[[], nn.Module]] = None,
128+
se: Optional[nn.Module] = None,
129+
sa: Optional[nn.Module] = None,
130130
):
131131
super().__init__()
132132
# pool defined at ModelConstructor.

0 commit comments

Comments
 (0)