Skip to content

Commit e15c388

Browse files
committed
Defaul lambda r=7. Define '26t' stage 4/5 256x256 variants for all of bot/halo/lambda nets for experiment. Add resnet50t for exp. Fix a few comments.
1 parent d15ad3e commit e15c388

File tree

4 files changed

+45
-9
lines changed

4 files changed

+45
-9
lines changed

timm/models/byoanet.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@ def _cfg(url='', **kwargs):
4545

4646
default_cfgs = {
4747
# GPU-Efficient (ResNet) weights
48+
'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256)),
4849
'botnet50t_224': _cfg(url='', fixed_input_size=True),
4950
'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
5051

5152
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
5253
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
53-
'halonet26t': _cfg(url=''),
54+
'halonet26t': _cfg(url='', input_size=(3, 256, 256)),
5455
'halonet50t': _cfg(url=''),
5556

56-
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128)),
57+
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256)),
5758
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
5859
}
5960

@@ -92,6 +93,21 @@ def interleave_attn(
9293

9394
model_cfgs = dict(
9495

96+
botnet26t=ByoaCfg(
97+
blocks=(
98+
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
99+
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
100+
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
101+
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
102+
),
103+
stem_chs=64,
104+
stem_type='tiered',
105+
stem_pool='maxpool',
106+
num_features=0,
107+
self_attn_layer='bottleneck',
108+
self_attn_fixed_size=True,
109+
self_attn_kwargs=dict()
110+
),
95111
botnet50t=ByoaCfg(
96112
blocks=(
97113
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
@@ -161,15 +177,15 @@ def interleave_attn(
161177
blocks=(
162178
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
163179
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
164-
ByoaBlocksCfg(type='bottle', d=2, c=1024, s=2, gs=0, br=0.25),
180+
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
165181
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
166182
),
167183
stem_chs=64,
168184
stem_type='tiered',
169185
stem_pool='maxpool',
170186
num_features=0,
171187
self_attn_layer='halo',
172-
self_attn_kwargs=dict(block_size=7, halo_size=2)
188+
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
173189
),
174190
halonet50t=ByoaCfg(
175191
blocks=(
@@ -370,6 +386,14 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
370386
**kwargs)
371387

372388

389+
@register_model
390+
def botnet26t_256(pretrained=False, **kwargs):
391+
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
392+
"""
393+
kwargs.setdefault('img_size', 256)
394+
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
395+
396+
373397
@register_model
374398
def botnet50t_224(pretrained=False, **kwargs):
375399
""" Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage.

timm/models/layers/halo_attn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(
115115
self.win_size = block_size + halo_size * 2 # neighbourhood window size
116116
self.scale = self.dim_head ** -0.5
117117

118-
# FIXME not clear if this stride behaviour is what the paper intended, not really clear
118+
# FIXME not clear if this stride behaviour is what the paper intended
119119
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
120120
# data in unfolded block form. I haven't wrapped my head around how that'd look.
121121
self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
@@ -139,10 +139,10 @@ def forward(self, x):
139139

140140
kv = self.kv(x)
141141
# FIXME I 'think' this unfold does what I want it to, but I should investigate
142-
k = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
143-
k = k.reshape(
142+
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
143+
kv = kv.reshape(
144144
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
145-
k, v = torch.split(k, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
145+
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
146146

147147
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
148148
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2

timm/models/layers/lambda_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class LambdaLayer(nn.Module):
3434
"""
3535
def __init__(
3636
self,
37-
dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=5, qkv_bias=False):
37+
dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
3838
super().__init__()
3939
self.dim_out = dim_out or dim
4040
self.dim_k = dim_head # query depth 'k'

timm/models/resnet.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def _cfg(url='', **kwargs):
5454
'resnet50d': _cfg(
5555
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
5656
interpolation='bicubic', first_conv='conv1.0'),
57+
'resnet50t': _cfg(
58+
url='',
59+
interpolation='bicubic', first_conv='conv1.0'),
5760
'resnet101': _cfg(url='', interpolation='bicubic'),
5861
'resnet101d': _cfg(
5962
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
@@ -706,6 +709,15 @@ def resnet50d(pretrained=False, **kwargs):
706709
return _create_resnet('resnet50d', pretrained, **model_args)
707710

708711

712+
@register_model
713+
def resnet50t(pretrained=False, **kwargs):
714+
"""Constructs a ResNet-50-T model.
715+
"""
716+
model_args = dict(
717+
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
718+
return _create_resnet('resnet50t', pretrained, **model_args)
719+
720+
709721
@register_model
710722
def resnet101(pretrained=False, **kwargs):
711723
"""Constructs a ResNet-101 model.

0 commit comments

Comments
 (0)