# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import multiprocessing
import os

import cv2
from PIL import Image
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV

from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans


# Minimum number of visible keypoints an image must carry to be kept when the
# annotations contain a "keypoints" field (see has_valid_annotation).
min_keypoints_per_image = 10


def _has_only_empty_bbox(anno):
    """Return True when every box in ``anno`` has a width or height <= 1."""
    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)


def _count_visible_keypoints(anno):
    """Count keypoints whose third component (every 3rd value) is > 0."""
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def has_valid_annotation(anno):
    """Check whether a list of annotations is usable.

    Args:
        anno (list[dict]): Annotations of a single image.

    Returns:
        bool: False when the list is empty or all boxes are (close to) empty.
        When the first annotation has no "keypoints" entry the result is True;
        otherwise it is True only if at least ``min_keypoints_per_image``
        keypoints are visible.
    """
    # if it's empty, there is no annotation
    if not anno:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # keypoint tasks have a slightly different criterion for considering
    # whether an annotation is valid
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only images containing at least
    # min_keypoints_per_image visible keypoints are considered valid
    return _count_visible_keypoints(anno) >= min_keypoints_per_image


class COCOYoloDataset:
    """YOLOV3 Dataset for COCO."""

    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        """Load the annotation file and build the image / category indices.

        Args:
            root (str): Directory the image file names are joined onto.
            ann_file (str): Annotation file handed to ``COCO``.
            remove_images_without_annotations (bool): Drop images whose
                annotations fail ``has_valid_annotation``.
            filter_crowd_anno (bool): Drop annotations with ``iscrowd != 0``
                when building targets in ``__getitem__``.
            is_training (bool): When False, ``__getitem__`` returns only the
                image and its id.
        """
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = list(sorted(self.coco.imgs.keys()))
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # filter images without any annotations
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}

        # Map the dataset's category ids onto 0..N-1 and back.
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            When ``is_training`` is False: ``(img, img_id)``.
            Otherwise an 8-tuple ``(img, out_target, [], [], [], [], [], [])``
            where ``out_target`` is a list of
            ``[x_min, y_min, x_max, y_max, label]`` entries built from the
            image's annotations. img is a PIL image converted to RGB.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
        if not self.is_training:
            return img, img_id

        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        # filter crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in target if anno["iscrowd"] == 0]
        else:
            annos = list(target)

        out_target = []
        for anno in annos:
            label = self.cat_ids_to_continuous_ids[anno["category_id"]]
            # convert to [x_min y_min x_max y_max] and append the label:
            # [x_min y_min x_max y_max, label]
            out_target.append(self._convetTopDown(anno["bbox"]) + [int(label)])

        # The six trailing empty lists line up with the "bbox1".."gt_box3"
        # columns declared in create_yolo_dataset; presumably they are filled
        # by the per-batch transform -- confirm against MultiScaleTrans.
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        """Return the number of images kept after filtering."""
        return len(self.img_ids)

    def _convetTopDown(self, bbox):
        """Convert ``[x_min, y_min, w, h]`` to ``[x_min, y_min, x_max, y_max]``."""
        x_min = bbox[0]
        y_min = bbox[1]
        w = bbox[2]
        h = bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]


def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True, num_samples=256):
    """Create dataset for YOLOV3.

    Args:
        image_dir (str): Image root directory passed to ``COCOYoloDataset``.
        anno_path (str): Annotation file passed to ``COCOYoloDataset``.
        batch_size (int): Batch size used for ``ds.batch``.
        max_epoch (int): Number of times the dataset is repeated.
        device_num (int): Number of devices; used for the distributed sampler
            and to split the CPU cores between devices.
        rank (int): Rank passed to ``DistributedSampler``.
        config: Configuration object. Required despite the ``None`` default:
            ``dataset_size`` is written onto it and it is forwarded to
            ``MultiScaleTrans`` / ``reshape_fn``.
        is_training (bool): Selects the training or the evaluation pipeline.
        shuffle (bool): Forwarded to ``DistributedSampler``.
        num_samples (int): Forwarded to ``GeneratorDataset`` (except in the
            8-device training branch) and returned unchanged.

    Returns:
        tuple: ``(ds, num_samples)``.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        # Fail before the (potentially slow) annotation load instead of with
        # an opaque AttributeError on ``config.dataset_size`` afterwards.
        raise ValueError("create_yolo_dataset requires a config object, got None")

    cv2.setNumThreads(0)

    # Training drops crowd annotations and images without usable annotations;
    # evaluation keeps everything.
    filter_crowd = bool(is_training)
    remove_empty_anno = bool(is_training)

    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno, is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)
    cores = multiprocessing.cpu_count()
    # Clamp to 1: with more devices than CPU cores the plain division yields 0,
    # which is not a usable worker count.
    num_parallel_workers = max(1, cores // device_num)
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     num_parallel_workers=min(32, num_parallel_workers),
                                     sampler=distributed_sampler, num_samples=num_samples)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
        else:
            # NOTE(review): this branch does not pass num_samples to the
            # dataset, yet num_samples is still what the function returns.
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler, num_samples=num_samples)

        def compose_map_func(image, img_id):
            return reshape_fn(image, img_id, config)

        ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    num_parallel_workers=8)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)

    return ds, num_samples