# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""YOLOV3 dataset."""
16import os
17
18import multiprocessing
19import cv2
20from PIL import Image
21from pycocotools.coco import COCO
22import mindspore.dataset as de
23import mindspore.dataset.vision.c_transforms as CV
24
25from src.distributed_sampler import DistributedSampler
26from src.transforms import reshape_fn, MultiScaleTrans
27
28
29min_keypoints_per_image = 10
30
31
32def _has_only_empty_bbox(anno):
33    return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
34
35
36def _count_visible_keypoints(anno):
37    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
38
39
40def has_valid_annotation(anno):
41    """Check annotation file."""
42    # if it's empty, there is no annotation
43    if not anno:
44        return False
45    # if all boxes have close to zero area, there is no annotation
46    if _has_only_empty_bbox(anno):
47        return False
48    # keypoints task have a slight different critera for considering
49    # if an annotation is valid
50    if "keypoints" not in anno[0]:
51        return True
52    # for keypoint detection tasks, only consider valid images those
53    # containing at least min_keypoints_per_image
54    if _count_visible_keypoints(anno) >= min_keypoints_per_image:
55        return True
56    return False
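
# Illustration only (not used by the pipeline): COCO boxes are stored as
# [x, y, width, height], so _has_only_empty_bbox inspects bbox[2:] == [width, height].
# A hypothetical annotation list such as [{"bbox": [10.0, 20.0, 0.5, 0.8]}] is
# rejected by has_valid_annotation (near-zero box), while
# [{"bbox": [10.0, 20.0, 30.0, 40.0]}] passes.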


class COCOYoloDataset:
    """YOLOV3 Dataset for COCO."""
    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = sorted(self.coco.imgs.keys())
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # filter out images without any usable annotations
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}

        # map the sparse COCO category ids to continuous ids starting at 0, and back
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into the filtered list of image ids.

        Returns:
            tuple: in evaluation mode, (img, img_id) where img is a PIL RGB image.
                In training mode, (img, out_target, [], [], [], [], [], []) where each
                row of out_target is [x_min, y_min, x_max, y_max, label] and the empty
                lists are placeholders for the columns filled by the per-batch transform.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
        if not self.is_training:
            return img, img_id

        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        # filter out crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in target if anno["iscrowd"] == 0]
        else:
            annos = list(target)

        target = {}
        boxes = [anno["bbox"] for anno in annos]
        target["bboxes"] = boxes

        classes = [anno["category_id"] for anno in annos]
        classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
        target["labels"] = classes

        bboxes = target["bboxes"]
        labels = target["labels"]
        out_target = []
        for bbox, label in zip(bboxes, labels):
            tmp = []
            # convert from [x_min, y_min, w, h] to [x_min, y_min, x_max, y_max]
            bbox = self._convert_top_down(bbox)
            tmp.extend(bbox)
            tmp.append(int(label))
            # tmp is [x_min, y_min, x_max, y_max, label]
            out_target.append(tmp)
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        return len(self.img_ids)

    def _convert_top_down(self, bbox):
        """Convert a COCO [x_min, y_min, w, h] box to [x_min, y_min, x_max, y_max]."""
        x_min = bbox[0]
        y_min = bbox[1]
        w = bbox[2]
        h = bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]
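
# Note on the label remapping in COCOYoloDataset.__init__ (illustrative): for the
# standard COCO 2017 instances annotations there are 80 categories with sparse
# ids (1..90 with gaps), so cat_ids_to_continuous_ids maps them to contiguous
# ids 0..79 (e.g. 1 -> 0, 2 -> 1, ..., 90 -> 79) and continuous_ids_cat_ids is
# the inverse; the labels stored in out_target are the continuous ids.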


def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True, num_samples=256):
    """Create dataset for YOLOV3."""
    cv2.setNumThreads(0)

    if is_training:
        filter_crowd = True
        remove_empty_anno = True
    else:
        filter_crowd = False
        remove_empty_anno = False

    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno, is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)
    cores = multiprocessing.cpu_count()
    num_parallel_workers = int(cores / device_num)
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     num_parallel_workers=min(32, num_parallel_workers),
                                     sampler=distributed_sampler, num_samples=num_samples)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
        else:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler, num_samples=num_samples)
        compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
        ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    num_parallel_workers=8)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)

    return ds, num_samples
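

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): the paths
    # below are placeholders for a local COCO copy and must be adjusted. Only the
    # raw COCOYoloDataset is exercised here, because create_yolo_dataset also
    # needs the project's YOLOv3 config object for the multi-scale transforms.
    data_root = "./coco/val2017"
    ann_file = "./coco/annotations/instances_val2017.json"
    dataset = COCOYoloDataset(root=data_root, ann_file=ann_file, is_training=False)
    print("images with usable annotations:", len(dataset))
    image, image_id = dataset[0]
    print("first sample:", image.size, image_id)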