|
1 | 1 | from torchbench.image_classification import ImageNet |
2 | 2 | from timm import create_model, list_models |
3 | 3 | from timm.data import resolve_data_config, create_transform |
| 4 | +import os |
4 | 5 |
|
5 | 6 | NUM_GPU = 1 |
6 | 7 | BATCH_SIZE = 256 * NUM_GPU |
@@ -86,13 +87,13 @@ def _attrib(paper_model_name='', paper_arxiv_id='', batch_size=BATCH_SIZE): |
86 | 87 | gluon_xception65=_attrib( |
87 | 88 | paper_model_name='Modified Aligned Xception', paper_arxiv_id='1802.02611', batch_size=BATCH_SIZE//2), |
88 | 89 | ig_resnext101_32x8d=_attrib( |
89 | | - paper_model_name='ResNeXt-101 32×8d', paper_arxiv_id='1805.00932'), |
| 90 | + paper_model_name='ResNeXt-101 32x8d', paper_arxiv_id='1805.00932'), |
90 | 91 | ig_resnext101_32x16d=_attrib( |
91 | | - paper_model_name='ResNeXt-101 32×16d', paper_arxiv_id='1805.00932'), |
| 92 | + paper_model_name='ResNeXt-101 32x16d', paper_arxiv_id='1805.00932'), |
92 | 93 | ig_resnext101_32x32d=_attrib( |
93 | | - paper_model_name='ResNeXt-101 32×32d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//2), |
| 94 | + paper_model_name='ResNeXt-101 32x32d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//2), |
94 | 95 | ig_resnext101_32x48d=_attrib( |
95 | | - paper_model_name='ResNeXt-101 32×48d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//4), |
| 96 | + paper_model_name='ResNeXt-101 32x48d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//4), |
96 | 97 | inception_resnet_v2=_attrib( |
97 | 98 | paper_model_name='Inception ResNet V2', paper_arxiv_id='1602.07261'), |
98 | 99 | #inception_v3=dict(paper_model_name='Inception V3', paper_arxiv_id=), # same weights as torchvision |
@@ -167,6 +168,12 @@ def _attrib(paper_model_name='', paper_arxiv_id='', batch_size=BATCH_SIZE): |
167 | 168 | paper_model_name='EfficientNet-B6', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//8), |
168 | 169 | tf_efficientnet_b7=_attrib( |
169 | 170 | paper_model_name='EfficientNet-B7', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//8), |
| 171 | + tf_efficientnet_es=_attrib( |
| 172 | + paper_model_name='EfficientNet-EdgeTPU-S', paper_arxiv_id='1905.11946'), |
| 173 | + tf_efficientnet_em=_attrib( |
| 174 | + paper_model_name='EfficientNet-EdgeTPU-M', paper_arxiv_id='1905.11946'), |
| 175 | + tf_efficientnet_el=_attrib( |
| 176 | + paper_model_name='EfficientNet-EdgeTPU-L', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//2), |
170 | 177 | tf_inception_v3=_attrib( |
171 | 178 | paper_model_name='Inception V3', paper_arxiv_id='1512.00567'), |
172 | 179 | tf_mixnet_l=_attrib( |
@@ -208,7 +215,7 @@ def _attrib(paper_model_name='', paper_arxiv_id='', batch_size=BATCH_SIZE): |
208 | 215 | input_transform=input_transform, |
209 | 216 | batch_size=model_map[model_name]['batch_size'], |
210 | 217 | num_gpu=NUM_GPU, |
211 | | - #data_root=DATA_ROOT |
| 218 | + data_root=os.environ.get('IMAGENET_DIR', './imagenet') |
212 | 219 | ) |
213 | 220 |
|
214 | 221 |
|
0 commit comments