|
| 1 | +import json |
1 | 2 | import os |
2 | 3 | from typing import List |
3 | 4 |
|
|
6 | 7 | from mindspore.dataset import GeneratorDataset, BatchDataset |
7 | 8 |
|
8 | 9 | from pycocotools.coco import COCO |
| 10 | +from pycocotools import mask as maskUtils |
9 | 11 |
|
10 | 12 | from segment_anything.dataset.transform import create_transform_pipeline |
11 | 13 | from segment_anything.utils import logger |
@@ -142,3 +144,98 @@ def __getitem__(self, idx): |
142 | 144 | self.output_column = list(data_dict.key()) |
143 | 145 |
|
144 | 146 | return tuple(data_dict[k] for k in self.output_column) |
| 147 | + |
| 148 | + |
| 149 | +@DATASET_REGISTRY.registry_module() |
| 150 | +class SA1BDataset: |
| 151 | + |
| 152 | + def __init__(self, |
| 153 | + data_dir, |
| 154 | + transform_pipeline, |
| 155 | + output_column: List[str] = None, |
| 156 | + **kwargs, |
| 157 | + ): |
| 158 | + self.data_dir = data_dir |
| 159 | + self.output_column = output_column |
| 160 | + self.transform_pipeline = create_transform_pipeline(transform_pipeline) |
| 161 | + assert os.path.exists(data_dir), f'SA-1B dataset root not exists at {data_dir}' |
| 162 | + parts = sorted(os.listdir(data_dir)) # there are about 11K jpgs in each part |
| 163 | + |
| 164 | + image_paths = [] |
| 165 | + anno_paths = [] |
| 166 | + for p in parts: |
| 167 | + part_dir = os.path.join(data_dir, p) |
| 168 | + all_files = [os.path.join(part_dir, f) for f in sorted(os.listdir(part_dir))] |
| 169 | + image_paths += list(filter(lambda f: f.endswith('.jpg'), all_files)) |
| 170 | + anno_paths += list(filter(lambda f: f.endswith('.json'), all_files)) |
| 171 | + assert len(image_paths) == len(anno_paths) |
| 172 | + |
| 173 | + self.image_paths = image_paths |
| 174 | + self.anno_paths = anno_paths |
| 175 | + |
| 176 | + logger.info(f'got {len(parts)} parts of SA-1B dateset, total size: {len(self.image_paths)}') |
| 177 | + |
| 178 | + def __len__(self): |
| 179 | + return len(self.image_paths) |
| 180 | + |
| 181 | + def __getitem__(self, idx): |
| 182 | + """ |
| 183 | + Below is an example describing the format of sa-1b dataset |
| 184 | + annotations: |
| 185 | + - area: 1418, |
| 186 | + bbox: [1134.0, 119.0, 30.0, 58.0] |
| 187 | + crop_box: [622.0, 0.0, 567.0, 707.0] # the sub part of image where the mask is generated, see automatic_mask_generator.py for details |
| 188 | + point_coords: [1153.5625, 132.5625] |
| 189 | + predicted_iou: 0.8891242146492 |
| 190 | + segmentation: |
| 191 | + counts: "`]YW23SP2`0D;F:F8H3M1O100O1O1O10000O1100O001O001O1O010O1O001O1^OYQN\\Oin16iQNEXn14X1FbiQe0" |
| 192 | + size: [2060, 1500] |
| 193 | + stability_score: 0.960608184337616 |
| 194 | + - area: |
| 195 | + box: |
| 196 | + xxx: |
| 197 | + ] |
| 198 | + image: |
| 199 | + file_name: "sa_1.jpg" |
| 200 | + height: 2060 |
| 201 | + image_id: 1 |
| 202 | + width: 1500 |
| 203 | + Returns: |
| 204 | + a tuple of transformed input items |
| 205 | +
|
| 206 | + """ |
| 207 | + anno_path = self.anno_paths[idx] |
| 208 | + image_path = self.image_paths[idx] |
| 209 | + |
| 210 | + assert os.path.exists(image_path), f'image file not found at {image_path}' |
| 211 | + assert os.path.exists(anno_path), f'anno file not found at {anno_path}' |
| 212 | + |
| 213 | + image = cv2.imread(image_path) |
| 214 | + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
| 215 | + |
| 216 | + with open(anno_path, 'r') as f: |
| 217 | + json_data = json.load(f) |
| 218 | + anno_list = json_data['annotations'] |
| 219 | + |
| 220 | + boxes = [] |
| 221 | + masks = [] |
| 222 | + for anno in anno_list: |
| 223 | + x, y, w, h = anno['bbox'] |
| 224 | + mask = maskUtils.decode(anno['segmentation']) # uint8 |
| 225 | + |
| 226 | + # filter small mask |
| 227 | + image_h, image_w = anno['segmentation']['size'] |
| 228 | + if w / image_w < 0.1 and h / image_h < 0.1: |
| 229 | + continue |
| 230 | + |
| 231 | + boxes.append([x, y, x + w, y + h]) |
| 232 | + masks.append(mask) |
| 233 | + |
| 234 | + # letter box |
| 235 | + data_dict = dict(image=image, masks=masks, boxes=np.array(boxes, np.float32)) |
| 236 | + data_dict = self.transform_pipeline(data_dict) |
| 237 | + |
| 238 | + if self.output_column is None: |
| 239 | + self.output_column = list(data_dict.key()) |
| 240 | + |
| 241 | + return tuple(data_dict[k] for k in self.output_column) |
0 commit comments