# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""YOLOv3 dataset"""
17from __future__ import division
18
19import os
20import numpy as np
21from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
22from PIL import Image
23import mindspore.dataset as de
24from mindspore.mindrecord import FileWriter
25import mindspore.dataset.vision.c_transforms as C
26from src.config import ConfigYOLOV3ResNet18
27
28iter_cnt = 0
29_NUM_BOXES = 50
30np.random.seed(1)
31de.config.set_seed(1)
32
def preprocess_fn(image, box, is_training):
    """Preprocess function for dataset."""
    # Nine anchors as (width, height) pairs, three per detection scale.
    config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
    anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
    do_hsv = False
    max_boxes = 20
    num_classes = ConfigYOLOV3ResNet18.num_classes

    def _rand(a=0., b=1.):
        """Return a float drawn uniformly from [a, b)."""
        return np.random.rand() * (b - a) + a

    def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
        """Convert ground-truth boxes into per-scale YOLOv3 training targets."""
        num_layers = anchors.shape[0] // 3
        anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        true_boxes = np.array(true_boxes, dtype='float32')
        input_shape = np.array(in_shape, dtype='int32')
        # Convert (xmin, ymin, xmax, ymax) corners to center/size, normalized by the input size.
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]

        # Grid sizes of the three detection scales (strides 32, 16 and 8).
        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
                            5 + num_classes), dtype='float32') for l in range(num_layers)]

        anchors = np.expand_dims(anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max

        valid_mask = boxes_wh[..., 0] >= 1

        wh = boxes_wh[valid_mask]

        if len(wh) >= 1:
            # Compute the IoU between every valid box and every anchor, both centered at the
            # origin, and assign each box to the scale and cell of its best-matching anchor.
            wh = np.expand_dims(wh, -2)
            boxes_max = wh / 2.
            boxes_min = -boxes_max

            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)

            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for l in range(num_layers):
                    if n in anchor_mask[l]:
                        i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
                        j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
                        k = anchor_mask[l].index(n)

                        c = true_boxes[t, 4].astype('int32')
                        y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                        y_true[l][j, i, k, 4] = 1.
                        y_true[l][j, i, k, 5 + c] = 1.

        # Collect the assigned boxes of each scale into fixed-size (_NUM_BOXES, 4) arrays.
        pad_gt_box0 = np.zeros(shape=[_NUM_BOXES, 4], dtype=np.float32)
        pad_gt_box1 = np.zeros(shape=[_NUM_BOXES, 4], dtype=np.float32)
        pad_gt_box2 = np.zeros(shape=[_NUM_BOXES, 4], dtype=np.float32)

        mask0 = np.reshape(y_true[0][..., 4:5], [-1])
        gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
        gt_box0 = gt_box0[mask0 == 1]
        pad_gt_box0[:gt_box0.shape[0]] = gt_box0

        mask1 = np.reshape(y_true[1][..., 4:5], [-1])
        gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
        gt_box1 = gt_box1[mask1 == 1]
        pad_gt_box1[:gt_box1.shape[0]] = gt_box1

        mask2 = np.reshape(y_true[2][..., 4:5], [-1])
        gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
        gt_box2 = gt_box2[mask2 == 1]
        pad_gt_box2[:gt_box2.shape[0]] = gt_box2

        return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2

    def _infer_data(img_data, input_shape, box):
        """Letterbox an image to the network input size for inference."""
        w, h = img_data.size
        input_h, input_w = input_shape
        scale = min(float(input_w) / float(w), float(input_h) / float(h))
        nw = int(w * scale)
        nh = int(h * scale)
        img_data = img_data.resize((nw, nh), Image.BICUBIC)

        # Paste the aspect-preserving resize onto a gray (128) canvas of the input size.
        new_image = np.zeros((input_h, input_w, 3), np.float32)
        new_image.fill(128)
        img_data = np.array(img_data)
        if len(img_data.shape) == 2:
            # Grayscale image: replicate the single channel three times.
            img_data = np.expand_dims(img_data, axis=-1)
            img_data = np.concatenate([img_data, img_data, img_data], axis=-1)

        dh = int((input_h - nh) / 2)
        dw = int((input_w - nw) / 2)
        new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
        # Scale to [0, 1], move channels first and add the batch dimension.
        new_image /= 255.
        new_image = np.transpose(new_image, (2, 0, 1))
        new_image = np.expand_dims(new_image, 0)
        return new_image, np.array([h, w], np.float32), box

    def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
        """Data augmentation function."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        iw, ih = image.size
        ori_image_shape = np.array([ih, iw], np.int32)
        h, w = image_size

        if not is_training:
            return _infer_data(image, image_size, box)

        flip = _rand() < .5
        # correct boxes
        box_data = np.zeros((max_boxes, 5))
        # Retry the random geometry until at least one box survives clipping, to
        # prevent the situation that all boxes are eliminated.
        while True:
            new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
                     _rand(1 - jitter, 1 + jitter)
            scale = _rand(0.25, 2)

            if new_ar < 1:
                nh = int(scale * h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * w)
                nh = int(nw / new_ar)

            dx = int(_rand(0, w - nw))
            dy = int(_rand(0, h - nh))

            if len(box) == 0:
                # No ground-truth boxes to preserve; accept the sampled geometry as-is.
                break

            t_box = box.copy()
            np.random.shuffle(t_box)
            # Map boxes into the resized/shifted image and clip them to its borders.
            t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
            t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
            if flip:
                t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
            t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
            t_box[:, 2][t_box[:, 2] > w] = w
            t_box[:, 3][t_box[:, 3] > h] = h
            box_w = t_box[:, 2] - t_box[:, 0]
            box_h = t_box[:, 3] - t_box[:, 1]
            t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid boxes

            if len(t_box) >= 1:
                box = t_box
                break

        box_data[:len(box)] = box
        # resize image
        image = image.resize((nw, nh), Image.BICUBIC)
        # place image
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # convert image to gray or not
        gray = _rand() < .25
        if gray:
            image = image.convert('L').convert('RGB')

        # when the image has a single channel, replicate it to three channels
        image = np.array(image)
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        # distort image in HSV space (skipped here because do_hsv is False above)
        hue = _rand(-hue, hue)
        sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
        val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
        image_data = image / 255.
        if do_hsv:
            x = rgb_to_hsv(image_data)
            x[..., 0] += hue
            x[..., 0][x[..., 0] > 1] -= 1
            x[..., 0][x[..., 0] < 0] += 1
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x > 1] = 1
            x[x < 0] = 0
            image_data = hsv_to_rgb(x)  # numpy array, 0 to 1
        image_data = image_data.astype(np.float32)

        # preprocess bounding boxes into per-scale training targets
        bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
            _preprocess_true_boxes(box_data, anchors, image_size)

        return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
               ori_image_shape, gt_box1, gt_box2, gt_box3

    if is_training:
        images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
        return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3

    images, shape, anno = _data_aug(image, box, is_training)
    return images, shape, anno
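
# Note on preprocess_fn outputs (with the default image_size=(352, 640) of
# _data_aug): in training mode it returns an HWC float32 image in [0, 1] of
# shape (352, 640, 3), three target tensors of shapes (11, 20, 3, 5 + num_classes),
# (22, 40, 3, 5 + num_classes) and (44, 80, 3, 5 + num_classes), and three
# (_NUM_BOXES, 4) padded ground-truth box arrays. In inference mode it returns
# the letterboxed image of shape (1, 3, 352, 640), the original [h, w] image
# shape and the unchanged annotation.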


def anno_parser(annos_str):
    """Parse annotation strings into a list of integer lists."""
    annos = []
    for anno_str in annos_str:
        anno = list(map(int, anno_str.strip().split(',')))
        annos.append(anno)
    return annos
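
# Illustrative example (hypothetical values): each annotation string encodes one
# box as "xmin,ymin,xmax,ymax,class_id", so
#     anno_parser(["20,30,200,180,0", "15,60,120,200,2"])
# returns [[20, 30, 200, 180, 0], [15, 60, 120, 200, 2]].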


def filter_valid_data(image_dir, anno_path):
    """Filter valid image files, keeping only those present in both image_dir and anno_path."""
    image_files = []
    image_anno_dict = {}
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as f:
        lines = f.readlines()
    for line in lines:
        line_str = line.decode("utf-8").strip()
        line_split = str(line_str).split(' ')
        file_name = line_split[0]
        # Keep an entry only if the referenced image actually exists on disk.
        if os.path.isfile(os.path.join(image_dir, file_name)):
            image_anno_dict[file_name] = anno_parser(line_split[1:])
            image_files.append(file_name)
    return image_files, image_anno_dict
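
# The annotation file is expected to hold one line per image: the image file
# name followed by space-separated boxes in "xmin,ymin,xmax,ymax,class_id" form.
# Illustrative line (file name and numbers are hypothetical):
#     image_001.jpg 48,240,195,371,11 8,12,352,498,14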


def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
    """Create MindRecord files from image_dir and anno_path."""
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)

    # Each row stores the raw encoded image bytes plus an N x 5 annotation array.
    yolo_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int64", "shape": [-1, 5]},
    }
    writer.add_schema(yolo_json, "yolo_json")

    for image_name in image_files:
        image_path = os.path.join(image_dir, image_name)
        with open(image_path, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name])
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()
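
# Minimal usage sketch for the conversion step (paths here are hypothetical; the
# model scripts normally drive this from train.py / eval.py):
#     if not os.path.isdir("./mindrecord"):
#         os.makedirs("./mindrecord")
#     data_to_mindrecord_byte_image("./images", "./anno.txt", "./mindrecord",
#                                   prefix="yolo.mindrecord", file_num=8)
# With file_num=8, FileWriter writes shards named "yolo.mindrecord0" through
# "yolo.mindrecord7" in the target directory.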


def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=8):
    """Create YOLOv3 dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
                        num_parallel_workers=num_parallel_workers, shuffle=False)
    decode = C.Decode()
    ds = ds.map(operations=decode, input_columns=["image"])
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

    if is_training:
        # Training: augment, build per-scale targets, convert HWC to CHW, batch and repeat.
        hwc_to_chw = C.HWC2CHW()
        ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
                    output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    column_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    num_parallel_workers=num_parallel_workers)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(repeat_num)
    else:
        # Inference: letterbox only, keeping the original image shape and annotation.
        ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "annotation"],
                    column_order=["image", "image_shape", "annotation"],
                    num_parallel_workers=num_parallel_workers)
    return ds

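
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): build a training pipeline from
# MindRecord shards already written by data_to_mindrecord_byte_image. The path
# below is a hypothetical example, not part of the original training scripts.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # The training scripts typically pass only the first shard ("<prefix>0");
    # MindDataset then picks up the sibling shards of the same group.
    example_mindrecord_file = os.path.join("./mindrecord", "yolo.mindrecord0")
    example_ds = create_yolo_dataset(example_mindrecord_file, batch_size=32, repeat_num=1,
                                     device_num=1, rank=0, is_training=True)
    print("batches per epoch:", example_ds.get_dataset_size())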