Skip to content

Commit d349cf0

Browse files
committed
in_chans in ConvMixer
1 parent 3ad183d commit d349cf0

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/model_constructor/convmixer.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Adopted from https://github.com/tmp-iclr/convmixer
44
# Home for convmixer: https://github.com/locuslab/convmixer
55
from collections import OrderedDict
6-
from typing import Callable
6+
from typing import Callable, Optional
77
import torch.nn as nn
88

99

@@ -43,9 +43,18 @@ def ConvMixerOriginal(dim, depth,
4343
class ConvLayer(nn.Sequential):
4444
"""Basic conv layers block"""
4545

46-
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
47-
act_fn=nn.GELU(), padding=0, groups=1,
48-
bn_1st=False, pre_act=False):
46+
def __init__(
47+
self,
48+
in_channels,
49+
out_channels,
50+
kernel_size,
51+
stride=1,
52+
act_fn=nn.GELU(),
53+
padding=0,
54+
groups=1,
55+
bn_1st=False,
56+
pre_act=False
57+
):
4958

5059
conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
5160
padding=padding, groups=groups))]
@@ -65,12 +74,20 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
6574

6675
class ConvMixer(nn.Sequential):
6776

68-
def __init__(self, dim: int, depth: int,
69-
kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
70-
act_fn: nn.Module = nn.GELU(),
71-
stem: nn.Module = None,
72-
bn_1st: bool = False, pre_act: bool = False,
73-
init_func: Callable = None):
77+
def __init__(
78+
self,
79+
dim: int,
80+
depth: int,
81+
kernel_size: int = 9,
82+
patch_size: int = 7,
83+
n_classes: int = 1000,
84+
act_fn: nn.Module = nn.GELU(),
85+
stem: Optional[nn.Module] = None,
86+
in_chans: int = 3,
87+
bn_1st: bool = False,
88+
pre_act: bool = False,
89+
init_func: Optional[Callable] = None
90+
):
7491
"""ConvMixer constructor.
7592
Adopted from https://github.com/tmp-iclr/convmixer
7693
@@ -91,7 +108,7 @@ def __init__(self, dim: int, depth: int,
91108
if pre_act:
92109
bn_1st = False
93110
if stem is None:
94-
stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
111+
stem = ConvLayer(in_chans, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
95112

96113
super().__init__(
97114
stem,
@@ -100,7 +117,7 @@ def __init__(self, dim: int, depth: int,
100117
ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
101118
groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
102119
ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
103-
for i in range(depth)],
120+
for _ in range(depth)],
104121
nn.AdaptiveAvgPool2d((1, 1)),
105122
nn.Flatten(),
106123
nn.Linear(dim, n_classes))

0 commit comments

Comments
 (0)