Skip to content

Commit 3193931

Browse files
committed
Added SelecSLS models
2 parents 32012a4 + 1f4498f commit 3193931

39 files changed

+5247
-3056
lines changed

README.md

Lines changed: 170 additions & 68 deletions
Large diffs are not rendered by default.

clean_checkpoint.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22
import argparse
33
import os
44
import hashlib
5+
import shutil
56
from collections import OrderedDict
67

78
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
89
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
910
help='path to latest checkpoint (default: none)')
10-
parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
11+
parser.add_argument('--output', default='', type=str, metavar='PATH',
1112
help='output path')
1213
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
1314
help='use ema version of weights if present')
1415

1516

17+
_TEMP_NAME = './_checkpoint.pth'
18+
19+
1620
def main():
1721
args = parser.parse_args()
1822

@@ -31,19 +35,27 @@ def main():
3135
if state_dict_key in checkpoint:
3236
state_dict = checkpoint[state_dict_key]
3337
else:
34-
print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
35-
exit(1)
38+
state_dict = checkpoint
3639
else:
37-
state_dict = checkpoint
40+
assert False
3841
for k, v in state_dict.items():
3942
name = k[7:] if k.startswith('module') else k
4043
new_state_dict[name] = v
4144
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
4245

43-
torch.save(new_state_dict, args.output)
44-
with open(args.output, 'rb') as f:
46+
torch.save(new_state_dict, _TEMP_NAME)
47+
with open(_TEMP_NAME, 'rb') as f:
4548
sha_hash = hashlib.sha256(f.read()).hexdigest()
46-
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
49+
50+
if args.output:
51+
checkpoint_root, checkpoint_base = os.path.split(args.output)
52+
checkpoint_base = os.path.splitext(checkpoint_base)[0]
53+
else:
54+
checkpoint_root = ''
55+
checkpoint_base = os.path.splitext(args.checkpoint)[0]
56+
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
57+
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
58+
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
4759
else:
4860
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
4961

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
torch>=1.1.0
2-
torchvision>=0.3.0
1+
torch>=1.2.0
2+
torchvision>=0.4.0
33
pyyaml

results/results-all.csv

Lines changed: 153 additions & 95 deletions
Large diffs are not rendered by default.

results/results-inv2-matched-frequency.csv

Lines changed: 152 additions & 94 deletions
Large diffs are not rendered by default.

sotabench.py

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
7878
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
7979
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
8080
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
81-
_entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244',
81+
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
8282
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
8383
'paper as closely as possible.'),
8484
_entry('resnet18', 'ResNet-18', '1812.01187'),
@@ -108,11 +108,35 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
108108
model_desc='Ported from official Google AI Tensorflow weights'),
109109
_entry('tf_efficientnet_b4', 'EfficientNet-B4 (AutoAugment)', '1905.11946', batch_size=BATCH_SIZE//2,
110110
model_desc='Ported from official Google AI Tensorflow weights'),
111-
_entry('tf_efficientnet_b5', 'EfficientNet-B5 (AutoAugment)', '1905.11946', batch_size=BATCH_SIZE//4,
111+
_entry('tf_efficientnet_b5', 'EfficientNet-B5 (RandAugment)', '1905.11946', batch_size=BATCH_SIZE//4,
112112
model_desc='Ported from official Google AI Tensorflow weights'),
113113
_entry('tf_efficientnet_b6', 'EfficientNet-B6 (AutoAugment)', '1905.11946', batch_size=BATCH_SIZE//8,
114114
model_desc='Ported from official Google AI Tensorflow weights'),
115-
_entry('tf_efficientnet_b7', 'EfficientNet-B7 (AutoAugment)', '1905.11946', batch_size=BATCH_SIZE//8,
115+
_entry('tf_efficientnet_b7', 'EfficientNet-B7 (RandAugment)', '1905.11946', batch_size=BATCH_SIZE//8,
116+
model_desc='Ported from official Google AI Tensorflow weights'),
117+
_entry('tf_efficientnet_b0_ap', 'EfficientNet-B0 (AdvProp)', '1911.09665',
118+
model_desc='Ported from official Google AI Tensorflow weights'),
119+
_entry('tf_efficientnet_b1_ap', 'EfficientNet-B1 (AdvProp)', '1911.09665',
120+
model_desc='Ported from official Google AI Tensorflow weights'),
121+
_entry('tf_efficientnet_b2_ap', 'EfficientNet-B2 (AdvProp)', '1911.09665',
122+
model_desc='Ported from official Google AI Tensorflow weights'),
123+
_entry('tf_efficientnet_b3_ap', 'EfficientNet-B3 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2,
124+
model_desc='Ported from official Google AI Tensorflow weights'),
125+
_entry('tf_efficientnet_b4_ap', 'EfficientNet-B4 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 2,
126+
model_desc='Ported from official Google AI Tensorflow weights'),
127+
_entry('tf_efficientnet_b5_ap', 'EfficientNet-B5 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 4,
128+
model_desc='Ported from official Google AI Tensorflow weights'),
129+
_entry('tf_efficientnet_b6_ap', 'EfficientNet-B6 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
130+
model_desc='Ported from official Google AI Tensorflow weights'),
131+
_entry('tf_efficientnet_b7_ap', 'EfficientNet-B7 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
132+
model_desc='Ported from official Google AI Tensorflow weights'),
133+
_entry('tf_efficientnet_b8_ap', 'EfficientNet-B8 (AdvProp)', '1911.09665', batch_size=BATCH_SIZE // 8,
134+
model_desc='Ported from official Google AI Tensorflow weights'),
135+
_entry('tf_efficientnet_cc_b0_4e', 'EfficientNet-CondConv-B0 4 experts', '1904.04971',
136+
model_desc='Ported from official Google AI Tensorflow weights'),
137+
_entry('tf_efficientnet_cc_b0_8e', 'EfficientNet-CondConv-B0 8 experts', '1904.04971',
138+
model_desc='Ported from official Google AI Tensorflow weights'),
139+
_entry('tf_efficientnet_cc_b1_8e', 'EfficientNet-CondConv-B1 8 experts', '1904.04971',
116140
model_desc='Ported from official Google AI Tensorflow weights'),
117141
_entry('tf_efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
118142
model_desc='Ported from official Google AI Tensorflow weights'),
@@ -124,6 +148,18 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
124148
_entry('tf_mixnet_l', 'MixNet-L', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
125149
_entry('tf_mixnet_m', 'MixNet-M', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
126150
_entry('tf_mixnet_s', 'MixNet-S', '1907.09595', model_desc='Ported from official Google AI Tensorflow weights'),
151+
_entry('tf_mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
152+
model_desc='Ported from official Google AI Tensorflow weights'),
153+
_entry('tf_mobilenetv3_large_075', 'MobileNet V3-Large 0.75', '1905.02244',
154+
model_desc='Ported from official Google AI Tensorflow weights'),
155+
_entry('tf_mobilenetv3_large_minimal_100', 'MobileNet V3-Large Minimal 1.0', '1905.02244',
156+
model_desc='Ported from official Google AI Tensorflow weights'),
157+
_entry('tf_mobilenetv3_small_100', 'MobileNet V3-Small 1.0', '1905.02244',
158+
model_desc='Ported from official Google AI Tensorflow weights'),
159+
_entry('tf_mobilenetv3_small_075', 'MobileNet V3-Small 0.75', '1905.02244',
160+
model_desc='Ported from official Google AI Tensorflow weights'),
161+
_entry('tf_mobilenetv3_small_minimal_100', 'MobileNet V3-Small Minimal 1.0', '1905.02244',
162+
model_desc='Ported from official Google AI Tensorflow weights'),
127163

128164
## Cadene ported weights (to remove if Cadene adds sotabench)
129165
_entry('inception_resnet_v2', 'Inception ResNet V2', '1602.07261'),
@@ -154,18 +190,87 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
154190
# _entry('wide_resnet101_2', , ),
155191

156192
## Facebook WSL weights
157-
_entry('ig_resnext101_32x8d', 'ResNeXt-101 32x8d', '1805.00932'),
158-
_entry('ig_resnext101_32x16d', 'ResNeXt-101 32x16d', '1805.00932'),
159-
_entry('ig_resnext101_32x32d', 'ResNeXt-101 32x32d', '1805.00932', batch_size=BATCH_SIZE // 2),
160-
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d', '1805.00932', batch_size=BATCH_SIZE // 4),
193+
_entry('ig_resnext101_32x8d', 'ResNeXt-101 32x8d', '1805.00932',
194+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
195+
_entry('ig_resnext101_32x16d', 'ResNeXt-101 32x16d', '1805.00932',
196+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
197+
_entry('ig_resnext101_32x32d', 'ResNeXt-101 32x32d', '1805.00932', batch_size=BATCH_SIZE // 2,
198+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
199+
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d', '1805.00932', batch_size=BATCH_SIZE // 4,
200+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
201+
161202
_entry('ig_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1805.00932',
162-
ttp=True, args=dict(img_size=288)),
203+
ttp=True, args=dict(img_size=288),
204+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
163205
_entry('ig_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1805.00932',
164-
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2),
206+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
207+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
165208
_entry('ig_resnext101_32x32d', 'ResNeXt-101 32x32d (288x288 Mean-Max Pooling)', '1805.00932',
166-
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 4),
209+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 4,
210+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
167211
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d (288x288 Mean-Max Pooling)', '1805.00932',
168-
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8),
212+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8,
213+
model_desc='Weakly-Supervised pre-training on 1B Instagram hashtag dataset by Facebook Research'),
214+
215+
## Facebook SSL weights
216+
_entry('ssl_resnet18', 'ResNet-18', '1905.00546',
217+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
218+
_entry('ssl_resnet50', 'ResNet-50', '1905.00546',
219+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
220+
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
221+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
222+
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
223+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
224+
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
225+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
226+
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546',
227+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
228+
229+
_entry('ssl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
230+
ttp=True, args=dict(img_size=288),
231+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
232+
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
233+
ttp=True, args=dict(img_size=288),
234+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
235+
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
236+
ttp=True, args=dict(img_size=288),
237+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
238+
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
239+
ttp=True, args=dict(img_size=288),
240+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
241+
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
242+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
243+
model_desc='Semi-Supervised pre-training on YFCC100M dataset by Facebook Research'),
244+
245+
## Facebook SWSL weights
246+
_entry('swsl_resnet18', 'ResNet-18', '1905.00546',
247+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
248+
_entry('swsl_resnet50', 'ResNet-50', '1905.00546',
249+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
250+
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
251+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
252+
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
253+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
254+
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
255+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
256+
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546',
257+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
258+
259+
_entry('swsl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
260+
ttp=True, args=dict(img_size=288),
261+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
262+
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
263+
ttp=True, args=dict(img_size=288),
264+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
265+
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
266+
ttp=True, args=dict(img_size=288),
267+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
268+
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
269+
ttp=True, args=dict(img_size=288),
270+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
271+
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
272+
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
273+
model_desc='Semi-Weakly-Supervised pre-training on 1 billion unlabelled dataset by Facebook Research'),
169274

170275
## DLA official impl weights (to remove if sotabench added to source)
171276
_entry('dla34', 'DLA-34', '1707.06484'),
@@ -189,6 +294,17 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
189294
_entry('res2next50', 'Res2NeXt-50', '1904.01169'),
190295
_entry('dla60_res2net', 'Res2Net-DLA-60', '1904.01169'),
191296
_entry('dla60_res2next', 'Res2NeXt-DLA-60', '1904.01169'),
297+
298+
## HRNet official impl weights
299+
_entry('hrnet_w18_small', 'HRNet-W18-C-Small-V1', '1908.07919'),
300+
_entry('hrnet_w18_small_v2', 'HRNet-W18-C-Small-V2', '1908.07919'),
301+
_entry('hrnet_w18', 'HRNet-W18-C', '1908.07919'),
302+
_entry('hrnet_w30', 'HRNet-W30-C', '1908.07919'),
303+
_entry('hrnet_w32', 'HRNet-W32-C', '1908.07919'),
304+
_entry('hrnet_w40', 'HRNet-W40-C', '1908.07919'),
305+
_entry('hrnet_w44', 'HRNet-W44-C', '1908.07919'),
306+
_entry('hrnet_w48', 'HRNet-W48-C', '1908.07919'),
307+
_entry('hrnet_w64', 'HRNet-W64-C', '1908.07919'),
192308
]
193309

194310
for m in model_list:

timm/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
from .transforms import *
55
from .loader import create_loader, create_transform
66
from .mixup import mixup_target, FastCollateMixup
7+
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
8+
rand_augment_transform, auto_augment_transform

0 commit comments

Comments
 (0)