Skip to content

Commit c3a2086

Browse files
committed
typing
1 parent d349cf0 commit c3a2086

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/model_constructor/convmixer.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Adopted from https://github.com/tmp-iclr/convmixer
44
# Home for convmixer: https://github.com/locuslab/convmixer
55
from collections import OrderedDict
6-
from typing import Callable, Optional
6+
from typing import Callable, Optional, Union
77
import torch.nn as nn
88

99

@@ -33,7 +33,7 @@ def ConvMixerOriginal(dim, depth,
3333
nn.Conv2d(dim, dim, kernel_size=1),
3434
act_fn,
3535
nn.BatchNorm2d(dim)
36-
) for i in range(depth)],
36+
) for _i in range(depth)],
3737
nn.AdaptiveAvgPool2d((1, 1)),
3838
nn.Flatten(),
3939
nn.Linear(dim, n_classes)
@@ -45,16 +45,16 @@ class ConvLayer(nn.Sequential):
4545

4646
def __init__(
4747
self,
48-
in_channels,
49-
out_channels,
50-
kernel_size,
51-
stride=1,
52-
act_fn=nn.GELU(),
53-
padding=0,
54-
groups=1,
55-
bn_1st=False,
56-
pre_act=False
57-
):
48+
in_channels: int,
49+
out_channels: int,
50+
kernel_size: Union[int, tuple[int, int]],
51+
stride: int = 1,
52+
act_fn: nn.Module = nn.GELU(),
53+
padding: Union[int, str] = 0,
54+
groups: int = 1,
55+
bn_1st: bool = False,
56+
pre_act: bool = False,
57+
):
5858

5959
conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
6060
padding=padding, groups=groups))]
@@ -86,8 +86,8 @@ def __init__(
8686
in_chans: int = 3,
8787
bn_1st: bool = False,
8888
pre_act: bool = False,
89-
init_func: Optional[Callable] = None
90-
):
89+
init_func: Optional[Callable[[nn.Module], None]] = None
90+
):
9191
"""ConvMixer constructor.
9292
Adopted from https://github.com/tmp-iclr/convmixer
9393

0 commit comments

Comments
 (0)