Skip to content

Commit f37f9c2

Browse files
committed
Release code for iNaturalist 2018 (#197)
1 parent cfd2462 commit f37f9c2

File tree

8 files changed

+1069
-22
lines changed

8 files changed

+1069
-22
lines changed

classification/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,27 @@ We use standard ImageNet dataset, you can download it from http://image-net.org/
167167

168168
</details>
169169

170+
<details>
171+
<summary>iNaturalist 2018</summary>
172+
173+
- For the iNaturalist 2018, please download the dataset from the [official repository](https://github.com/visipedia/inat_comp/blob/master/2018/README.md).
174+
The file structure should look like:
175+
176+
```bash
177+
$ tree inat2018/
178+
inat2018/
179+
├── categories.json
180+
├── test2018
181+
├── test2018.json
182+
├── train2018.json
183+
├── train2018_locations.json
184+
├── val2018
185+
├── val2018.json
186+
└── val2018_locations.json
187+
```
188+
189+
</details>
190+
170191
## Released Models
171192

172193
<details open>
@@ -204,6 +225,19 @@ We use standard ImageNet dataset, you can download it from http://image-net.org/
204225

205226
</details>
206227

228+
<details open>
229+
<summary> iNaturalist 2018 Image Classification </summary>
230+
<br>
231+
<div>
232+
233+
| name | pretrain | resolution | acc@1 | #param | download |
234+
| :-----------: | :--------: | :--------: | :---: | :----: | :-----------------------------------------------------------------------------: |
235+
| InternImage-H | Joint 427M | 384x384 | 92.6 | 1.1B | [ckpt](<>) \| [cfg](configs/inaturalist2018/internimage_h_22ktoinat18_384.yaml) |
236+
237+
</div>
238+
239+
</details>
240+
207241
## Evaluation
208242

209243
To evaluate a pretrained `InternImage` on ImageNet val, run:
Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
DATA:
22
IMG_SIZE: 384
3-
DATASET: inat18
43
IMG_ON_MEMORY: False
5-
DATA_PATH: "data/inat2018/"
4+
DATASET: inat18
65
AUG:
76
MIXUP: 0.0
87
CUTMIX: 0.0
8+
REPROB: 0.0
99
MODEL:
10-
PRETRAINED: './pretrained/internimage_h_jointto22k_384.pth'
11-
TYPE: intern_image_with_meta
12-
DROP_PATH_RATE: 0.2
10+
TYPE: intern_image_meta_former
11+
DROP_PATH_RATE: 0.6
1312
LABEL_SMOOTHING: 0.3
1413
INTERN_IMAGE:
1514
CORE_OP: 'DCNv3'
@@ -26,22 +25,22 @@ MODEL:
2625
LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29]
2726
CENTER_FEATURE_SCALE: True
2827
USE_CLIP_PROJECTOR: True
28+
PRETRAINED: 'pretrained/internimage_h_jointto22k_384.pth'
2929
TRAIN:
3030
EMA:
31-
ENABLE: false
32-
DECAY: 0.9998
31+
ENABLE: true
32+
DECAY: 0.9999
3333
EPOCHS: 100
3434
WARMUP_EPOCHS: 0
35-
WEIGHT_DECAY: 1e-8
36-
BASE_LR: 3e-05 # 512
37-
WARMUP_LR: 3e-08
38-
MIN_LR: 3e-07
35+
WEIGHT_DECAY: 0.05
36+
BASE_LR: 2e-05 # 512
37+
WARMUP_LR: .0
38+
MIN_LR: .0
3939
LR_LAYER_DECAY: true
40-
LR_LAYER_DECAY_RATIO: 0.8
40+
LR_LAYER_DECAY_RATIO: 0.9
41+
USE_CHECKPOINT: true
4142
RAND_INIT_FT_HEAD: true
42-
USE_CHECKPOINT: false
4343
OPTIMIZER:
44-
USE_ZERO: True
4544
DCN_LR_MUL: 0.1
4645
AMP_OPT_LEVEL: O0
4746
EVAL_FREQ: 1

classification/dataset/build.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from timm.data import Mixup, create_transform
1313
from torchvision import transforms
1414

15-
from .cached_image_folder import CachedImageFolder, ImageCephDataset
15+
from .cached_image_folder import (CachedImageFolder, ImageCephDataset,
16+
INat18ImageCephDataset,
17+
INat18ParserCephImage)
1618
from .samplers import NodeDistributedSampler, SubsetRandomSampler
1719

1820
try:
@@ -229,6 +231,15 @@ def build_dataset(split, config):
229231
root = os.path.join(config.DATA.DATA_PATH, 'val')
230232
dataset = ImageCephDataset(root, 'val', transform=transform)
231233
nb_classes = 1000
234+
elif config.DATA.DATASET == 'inat18':
235+
if prefix == 'train' and not config.EVAL_MODE:
236+
root = config.DATA.DATA_PATH
237+
dataset = INat18ImageCephDataset(
238+
root, 'train', transform=transform, on_memory=config.DATA.IMG_ON_MEMORY)
239+
elif prefix == 'val':
240+
root = config.DATA.DATA_PATH
241+
dataset = INat18ImageCephDataset(root, 'val', transform=transform)
242+
nb_classes = 8142
232243
else:
233244
raise NotImplementedError(
234245
f'build_dataset does support {config.DATA.DATASET}')

classification/dataset/cached_image_folder.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,55 @@ def filenames(self, basename=False, absolute=False):
340340
return self.parser.filenames(basename, absolute)
341341

342342

343+
class INat18ImageCephDataset(data.Dataset):
344+
345+
def __init__(self,
346+
root,
347+
split,
348+
parser=None,
349+
transform=None,
350+
target_transform=None,
351+
on_memory=False):
352+
if split == 'train':
353+
annotation_root = osp.join(root, 'train2018.json')
354+
elif split == 'val':
355+
annotation_root = osp.join(root, 'val2018.json')
356+
elif split == 'test':
357+
annotation_root = osp.join(root, 'test2018.json')
358+
if parser is None or isinstance(parser, str):
359+
parser = INat18ParserCephImage(root=root,
360+
split=split,
361+
annotation_root=annotation_root,
362+
on_memory=on_memory)
363+
self.parser = parser
364+
self.transform = transform
365+
self.target_transform = target_transform
366+
self._consecutive_errors = 0
367+
368+
def __getitem__(self, index):
369+
img, temporal_info, spatial_info, target = self.parser[index]
370+
self._consecutive_errors = 0
371+
if self.transform is not None:
372+
img = self.transform(img)
373+
if target is None:
374+
target = -1
375+
elif self.target_transform is not None:
376+
target = self.target_transform(target)
377+
temporal_info = torch.tensor(temporal_info).to(torch.float32)
378+
spatial_info = torch.tensor(spatial_info).to(torch.float32)
379+
380+
return [img, temporal_info, spatial_info], target
381+
382+
def __len__(self):
383+
return len(self.parser)
384+
385+
def filename(self, index, basename=False, absolute=False):
386+
return self.parser.filename(index, basename, absolute)
387+
388+
def filenames(self, basename=False, absolute=False):
389+
return self.parser.filenames(basename, absolute)
390+
391+
343392
class Parser:
344393

345394
def __init__(self):
@@ -372,7 +421,7 @@ def __init__(self,
372421
self.file_client = None
373422
self.kwargs = kwargs
374423

375-
self.root = root # dataset:s3://imagenet22k
424+
self.root = root
376425
if '22k' in root:
377426
self.io_backend = 'petrel'
378427
with open(osp.join(annotation_root, '22k_class_to_idx.json'),
@@ -497,7 +546,7 @@ def __getitem__(self, index):
497546
else:
498547
target = int(target)
499548
except:
500-
print('aaaaaaaaaaaa', filepath, target)
549+
print(filepath, target)
501550
exit()
502551

503552
return img, target
@@ -512,6 +561,87 @@ def _filename(self, index, basename=False, absolute=False):
512561
return filename
513562

514563

564+
class INat18ParserCephImage(Parser):
565+
566+
def __init__(self,
567+
root,
568+
split,
569+
annotation_root,
570+
on_memory=False,
571+
**kwargs):
572+
super().__init__()
573+
574+
self.file_client = None
575+
self.kwargs = kwargs
576+
self.split = split
577+
self.root = root
578+
579+
self.io_backend = 'disk'
580+
data = mmcv.load(annotation_root)
581+
582+
self.samples = data['annotations']
583+
self.file_names = [each['file_name'] for each in data['images']]
584+
self.meta_data = mmcv.load(
585+
annotation_root.replace('2018.json', '2018_locations.json'))
586+
587+
self.class_to_idx = {}
588+
for i, each in enumerate(data['categories']):
589+
self.class_to_idx[each['id']] = i
590+
self.on_memory = on_memory
591+
self._consecutive_errors = 0
592+
# TODO: support on_memory function
593+
594+
def __getitem__(self, index):
595+
if self.file_client is None:
596+
self.file_client = FileClient(self.io_backend, **self.kwargs)
597+
anns = self.samples[index]
598+
filename = self.file_names[index]
599+
img_id = anns['image_id']
600+
target = anns['category_id']
601+
602+
# load meta information from json file
603+
meta = self.meta_data[index]
604+
date = meta['date']
605+
latitude = meta['lat']
606+
longitude = meta['lon']
607+
location_uncertainty = meta['loc_uncert']
608+
temporal_info = get_temporal_info(date, miss_hour=True)
609+
spatial_info = get_spatial_info(latitude, longitude)
610+
611+
filepath = osp.join(self.root, filename)
612+
try:
613+
if self.on_memory:
614+
img_bytes = self.holder[filepath]
615+
else:
616+
img_bytes = self.file_client.get(filepath)
617+
img = mmcv.imfrombytes(img_bytes)[:, :, ::-1]
618+
619+
except Exception as e:
620+
_logger.warning(
621+
f'Skipped sample (index {index}, file {filepath}). {str(e)}')
622+
self._consecutive_errors += 1
623+
if self._consecutive_errors < _ERROR_RETRY:
624+
return self.__getitem__((index + 1) % len(self))
625+
else:
626+
raise e
627+
self._consecutive_errors = 0
628+
629+
img = Image.fromarray(img)
630+
if self.class_to_idx is not None:
631+
target = self.class_to_idx[target]
632+
else:
633+
target = int(target)
634+
return img, temporal_info, spatial_info, target
635+
636+
def __len__(self):
637+
return len(self.samples)
638+
639+
def _filename(self, index, basename=False, absolute=False):
640+
filename, _ = self.samples[index].split(' ')
641+
filename = osp.join(self.root, filename)
642+
return filename
643+
644+
515645
def get_temporal_info(date, miss_hour=False):
516646
try:
517647
if date:

classification/main.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def parse_option():
7474
type=str,
7575
help='dataset name',
7676
default=None)
77-
parser.add_argument('--data-path', type=str, help='path to dataset',
78-
default='data/imagenet')
77+
parser.add_argument('--data-path', type=str, help='path to dataset')
7978
parser.add_argument('--zip',
8079
action='store_true',
8180
help='use zipped dataset instead of folder dataset')
@@ -146,7 +145,10 @@ def throughput(data_loader, model, logger):
146145
model.eval()
147146

148147
for idx, (images, _) in enumerate(data_loader):
149-
images = images.cuda(non_blocking=True)
148+
if type(images) == list:
149+
images = [item.cuda(non_blocking=True) for item in images]
150+
else:
151+
images = images.cuda(non_blocking=True)
150152
batch_size = images.shape[0]
151153
for i in range(50):
152154
model(images)
@@ -403,7 +405,10 @@ def train_one_epoch(config,
403405
amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
404406
for idx, (samples, targets) in enumerate(data_loader):
405407
iter_begin_time = time.time()
406-
samples = samples.cuda(non_blocking=True)
408+
if type(samples) == list:
409+
samples = [item.cuda(non_blocking=True) for item in samples]
410+
else:
411+
samples = samples.cuda(non_blocking=True)
407412
targets = targets.cuda(non_blocking=True)
408413

409414
if mixup_fn is not None:
@@ -528,7 +533,10 @@ def validate(config, data_loader, model, epoch=None):
528533

529534
end = time.time()
530535
for idx, (images, target) in enumerate(data_loader):
531-
images = images.cuda(non_blocking=True)
536+
if type(images) == list:
537+
images = [item.cuda(non_blocking=True) for item in images]
538+
else:
539+
images = images.cuda(non_blocking=True)
532540
target = target.cuda(non_blocking=True)
533541
output = model(images)
534542

classification/models/build.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# --------------------------------------------------------
66

77
from .intern_image import InternImage
8+
from .intern_image_meta_former import InternImageMetaFormer
89

910

1011
def build_model(config):
@@ -30,6 +31,27 @@ def build_model(config):
3031
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
3132
remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
3233
)
34+
elif model_type == 'intern_image_meta_former':
35+
model = InternImageMetaFormer(
36+
core_op=config.MODEL.INTERN_IMAGE.CORE_OP,
37+
num_classes=config.MODEL.NUM_CLASSES,
38+
channels=config.MODEL.INTERN_IMAGE.CHANNELS,
39+
depths=config.MODEL.INTERN_IMAGE.DEPTHS,
40+
groups=config.MODEL.INTERN_IMAGE.GROUPS,
41+
layer_scale=config.MODEL.INTERN_IMAGE.LAYER_SCALE,
42+
offset_scale=config.MODEL.INTERN_IMAGE.OFFSET_SCALE,
43+
post_norm=config.MODEL.INTERN_IMAGE.POST_NORM,
44+
mlp_ratio=config.MODEL.INTERN_IMAGE.MLP_RATIO,
45+
with_cp=config.TRAIN.USE_CHECKPOINT,
46+
drop_path_rate=config.MODEL.DROP_PATH_RATE,
47+
res_post_norm=config.MODEL.INTERN_IMAGE.RES_POST_NORM, # for InternImage-H/G
48+
dw_kernel_size=config.MODEL.INTERN_IMAGE.DW_KERNEL_SIZE, # for InternImage-H/G
49+
use_clip_projector=config.MODEL.INTERN_IMAGE.USE_CLIP_PROJECTOR, # for InternImage-H/G
50+
level2_post_norm=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM, # for InternImage-H/G
51+
level2_post_norm_block_ids=config.MODEL.INTERN_IMAGE.LEVEL2_POST_NORM_BLOCK_IDS, # for InternImage-H/G
52+
center_feature_scale=config.MODEL.INTERN_IMAGE.CENTER_FEATURE_SCALE, # for InternImage-H/G
53+
remove_center=config.MODEL.INTERN_IMAGE.REMOVE_CENTER,
54+
)
3355
else:
3456
raise NotImplementedError(f'Unkown model: {model_type}')
3557

0 commit comments

Comments
 (0)