
Commit b8e855b

Author: Mark-ZhouWX
Commit message: add sa-1b dataset
Parent: b3a2204

3 files changed: +174, -1 lines
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+#---------------------------------------------
+# Part 1: system basic config setting
+distributed: False
+device: Ascend
+mode: 0  # 0: graph, 1: pynative
+work_root: &work_root ./work_dir/
+log_level: info
+amp_level: O2
+
+# ---------------------------------------------
+# Part 2: module setting
+loss_manager:
+  # type: fixed  # dynamic or fixed
+  # scale_sense: 1024
+  loss_scaler:
+    type: dynamic
+  grad_clip: False
+  drop_overflow_update: False
+
+optimizer:
+  type: segment_anything.optim.optimizer.AdamW
+  weight_decay: 1e-4
+  group_param:
+
+lr_scheduler:
+  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
+  learning_rate: 8e-6
+  warmup_steps: 250
+  decay_steps: [ 60000, 86666 ]
+  decay_factor: 10
+
+
+network:
+  model:
+    type: vit_b
+    checkpoint: ./models/sam_vit_b-35e4849c.ckpt
+    freeze:
+      image_encoder: True
+      prompt_encoder: True
+
+  loss:
+    type: segment_anything.modeling.loss.SAMLoss
+
+
+train_loader:
+  dataset:
+    type: segment_anything.dataset.dataset.SA1BDataset
+    data_dir: ./datasets/sa-1b/
+    transform_pipeline:
+      - type: segment_anything.dataset.transform.ImageResizeAndPad
+        target_size: 1024
+      - type: segment_anything.dataset.transform.ImageNorm
+        hwc2chw: True
+      - type: segment_anything.dataset.transform.LabelPad
+        gt_size: 20
+    output_column: ['image', 'masks', 'boxes', 'valid_boxes']
+
+  model_column: ['image', 'boxes']  # columns for model cell input
+  loss_column: ['masks', 'valid_boxes']  # columns for loss function input
+
+  shuffle: True
+  batch_size: 1
+  epoch_size: 8
+  drop_remainder: True
+  num_workers: 2
+  max_rowsize: 64  # 24 MB shared-memory space per row for the dataloader
+
+
+callback:
+  - type: segment_anything.utils.callbacks.TrainStatusLog
+    loss_item: ['focal_loss', 'dice_loss', 'mse_loss']  # loss components to log
+    interval: 100
+  - type: segment_anything.utils.callbacks.SaveCkpt
+    work_root: *work_root
+    interval: 1  # in epochs

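Each 'type:' entry in this config (optimizer, scheduler, dataset, transforms, callbacks) is a dotted import path that the training code resolves through its registries; that factory code is not part of this commit. As a rough, hypothetical sketch only, such an entry could be resolved like this (the helper name create_from_config is illustrative, not the repository's API):

import importlib


def create_from_config(cfg: dict):
    """Import the class named by the dotted 'type' path and build it from the remaining keys."""
    cfg = dict(cfg)  # don't mutate the caller's config
    module_path, _, cls_name = cfg.pop('type').rpartition('.')
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg)


# Hypothetical usage mirroring the lr_scheduler entry above:
# scheduler = create_from_config({
#     'type': 'segment_anything.optim.scheduler.SAMDynamicDecayLR',
#     'learning_rate': 8e-6, 'warmup_steps': 250,
#     'decay_steps': [60000, 86666], 'decay_factor': 10,
# })
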
official/cv/segment-anything/segment_anything/dataset/dataset.py

Lines changed: 97 additions & 0 deletions
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import List

@@ -6,6 +7,7 @@
 from mindspore.dataset import GeneratorDataset, BatchDataset

 from pycocotools.coco import COCO
+from pycocotools import mask as maskUtils

 from segment_anything.dataset.transform import create_transform_pipeline
 from segment_anything.utils import logger
@@ -142,3 +144,98 @@ def __getitem__(self, idx):
             self.output_column = list(data_dict.keys())

         return tuple(data_dict[k] for k in self.output_column)
+
+
+@DATASET_REGISTRY.registry_module()
+class SA1BDataset:
+
+    def __init__(self,
+                 data_dir,
+                 transform_pipeline,
+                 output_column: List[str] = None,
+                 **kwargs,
+                 ):
+        self.data_dir = data_dir
+        self.output_column = output_column
+        self.transform_pipeline = create_transform_pipeline(transform_pipeline)
+        assert os.path.exists(data_dir), f'SA-1B dataset root does not exist at {data_dir}'
+        parts = sorted(os.listdir(data_dir))  # there are about 11K jpgs in each part
+
+        image_paths = []
+        anno_paths = []
+        for p in parts:
+            part_dir = os.path.join(data_dir, p)
+            all_files = [os.path.join(part_dir, f) for f in sorted(os.listdir(part_dir))]
+            image_paths += list(filter(lambda f: f.endswith('.jpg'), all_files))
+            anno_paths += list(filter(lambda f: f.endswith('.json'), all_files))
+        assert len(image_paths) == len(anno_paths)
+
+        self.image_paths = image_paths
+        self.anno_paths = anno_paths
+
+        logger.info(f'got {len(parts)} parts of the SA-1B dataset, total size: {len(self.image_paths)}')
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        """
+        Below is an example describing the annotation format of the SA-1B dataset:
+
+        annotations:
+          - area: 1418
+            bbox: [1134.0, 119.0, 30.0, 58.0]
+            crop_box: [622.0, 0.0, 567.0, 707.0]  # the sub-part of the image where the mask was generated, see automatic_mask_generator.py for details
+            point_coords: [1153.5625, 132.5625]
+            predicted_iou: 0.8891242146492
+            segmentation:
+              counts: "`]YW23SP2`0D;F:F8H3M1O100O1O1O10000O1100O001O001O1O010O1O001O1^OYQN\\Oin16iQNEXn14X1FbiQe0"
+              size: [2060, 1500]
+            stability_score: 0.960608184337616
+          - area: ...
+            bbox: ...
+            ...
+        image:
+          file_name: "sa_1.jpg"
+          height: 2060
+          image_id: 1
+          width: 1500
+
+        Returns:
+            a tuple of transformed input items
+        """
+        anno_path = self.anno_paths[idx]
+        image_path = self.image_paths[idx]
+
+        assert os.path.exists(image_path), f'image file not found at {image_path}'
+        assert os.path.exists(anno_path), f'annotation file not found at {anno_path}'
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        with open(anno_path, 'r') as f:
+            json_data = json.load(f)
+        anno_list = json_data['annotations']
+
+        boxes = []
+        masks = []
+        for anno in anno_list:
+            x, y, w, h = anno['bbox']  # xywh format
+            mask = maskUtils.decode(anno['segmentation'])  # uint8 array of shape (image_h, image_w)
+
+            # filter out masks whose boxes are small relative to the image
+            image_h, image_w = anno['segmentation']['size']
+            if w / image_w < 0.1 and h / image_h < 0.1:
+                continue
+
+            boxes.append([x, y, x + w, y + h])  # xywh -> xyxy
+            masks.append(mask)
+
+        # letter box (resize and pad) is applied inside the transform pipeline
+        data_dict = dict(image=image, masks=masks, boxes=np.array(boxes, np.float32))
+        data_dict = self.transform_pipeline(data_dict)
+
+        if self.output_column is None:
+            self.output_column = list(data_dict.keys())
+
+        return tuple(data_dict[k] for k in self.output_column)

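dataset.py imports GeneratorDataset, which suggests SA1BDataset is handed to MindSpore's generator-based loader elsewhere in the repository. Below is a minimal usage sketch under that assumption, reusing the values from the config above; the actual loader-building code is not part of this commit and may differ.

import mindspore.dataset as ds
from segment_anything.dataset.dataset import SA1BDataset

columns = ['image', 'masks', 'boxes', 'valid_boxes']  # output_column from the config
dataset = SA1BDataset(
    data_dir='./datasets/sa-1b/',
    transform_pipeline=[
        dict(type='segment_anything.dataset.transform.ImageResizeAndPad', target_size=1024),
        dict(type='segment_anything.dataset.transform.ImageNorm', hwc2chw=True),
        dict(type='segment_anything.dataset.transform.LabelPad', gt_size=20),
    ],
    output_column=columns,
)

# shuffle / num_workers / max_rowsize / batch_size / drop_remainder mirror the train_loader config
loader = ds.GeneratorDataset(dataset, column_names=columns,
                             shuffle=True, num_parallel_workers=2, max_rowsize=64)
loader = loader.batch(1, drop_remainder=True)

for image, masks, boxes, valid_boxes in loader.create_tuple_iterator():
    pass  # feed ['image', 'boxes'] to the model cell and ['masks', 'valid_boxes'] to the loss
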
official/cv/segment-anything/segment_anything/dataset/transform.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __call__(self, result_dict):
         if False:  # show image and mask for debug
             import matplotlib.pyplot as plt
             plt.imshow(result_dict['image'])  # raw image
-            from use_sam_with_promts import show_box, show_mask
+            from segment_anything.utils.visualize import show_box, show_mask
             show_box(result_dict['boxes'][0], plt.gca())
             show_mask(result_dict['masks'][0], plt.gca())
             plt.show()

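The ImageResizeAndPad step in the config (target_size: 1024) together with the "letter box" comment in __getitem__ points to SAM-style preprocessing: scale the longest side to 1024 and pad to a square. The following is only a standalone sketch of that idea, assuming zero padding on the bottom/right; the transform's real implementation in transform.py may differ in details (padding value, mask and box handling).

import cv2
import numpy as np


def resize_and_pad(image: np.ndarray, boxes: np.ndarray, target_size: int = 1024):
    """Scale the longest side to target_size, pad bottom/right to a square, rescale boxes."""
    h, w = image.shape[:2]
    scale = target_size / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    resized = cv2.resize(image, (new_w, new_h))

    padded = np.zeros((target_size, target_size, 3), dtype=resized.dtype)
    padded[:new_h, :new_w] = resized  # zero padding on the bottom/right

    boxes = boxes * scale  # xyxy boxes scale uniformly with the image
    return padded, boxes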