# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""YOLOv3 dataset"""
from __future__ import division

import os
import numpy as np
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
from PIL import Image
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.vision.c_transforms as C
from src.config import ConfigYOLOV3ResNet18

iter_cnt = 0
_NUM_BOXES = 50
np.random.seed(1)
de.config.set_seed(1)


def preprocess_fn(image, box, is_training):
    """Preprocess function for dataset."""
    config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
    anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
    do_hsv = False
    max_boxes = 20
    num_classes = ConfigYOLOV3ResNet18.num_classes

    def _rand(a=0., b=1.):
        return np.random.rand() * (b - a) + a

    def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
        """Get true boxes."""
        num_layers = anchors.shape[0] // 3
        anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        true_boxes = np.array(true_boxes, dtype='float32')
        # input_shape = np.array([in_shape, in_shape], dtype='int32')
        input_shape = np.array(in_shape, dtype='int32')
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]

        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
                            5 + num_classes), dtype='float32') for l in range(num_layers)]

        anchors = np.expand_dims(anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max

        valid_mask = boxes_wh[..., 0] >= 1

        wh = boxes_wh[valid_mask]

        if len(wh) >= 1:
            wh = np.expand_dims(wh, -2)
            boxes_max = wh / 2.
            boxes_min = -boxes_max

            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)

            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for l in range(num_layers):
                    if n in anchor_mask[l]:
                        i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
                        j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
                        k = anchor_mask[l].index(n)

                        c = true_boxes[t, 4].astype('int32')
                        y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                        y_true[l][j, i, k, 4] = 1.
                        y_true[l][j, i, k, 5 + c] = 1.
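
        # Pad the ground-truth boxes of each output scale to a fixed 50 rows
        # (matching the module-level _NUM_BOXES) so every sample yields
        # fixed-shape ground-truth tensors; rows beyond the real boxes stay zero.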
        pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32)
        pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32)
        pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32)

        mask0 = np.reshape(y_true[0][..., 4:5], [-1])
        gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
        gt_box0 = gt_box0[mask0 == 1]
        pad_gt_box0[:gt_box0.shape[0]] = gt_box0

        mask1 = np.reshape(y_true[1][..., 4:5], [-1])
        gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
        gt_box1 = gt_box1[mask1 == 1]
        pad_gt_box1[:gt_box1.shape[0]] = gt_box1

        mask2 = np.reshape(y_true[2][..., 4:5], [-1])
        gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
        gt_box2 = gt_box2[mask2 == 1]
        pad_gt_box2[:gt_box2.shape[0]] = gt_box2

        return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2

    def _infer_data(img_data, input_shape, box):
        w, h = img_data.size
        input_h, input_w = input_shape
        scale = min(float(input_w) / float(w), float(input_h) / float(h))
        nw = int(w * scale)
        nh = int(h * scale)
        img_data = img_data.resize((nw, nh), Image.BICUBIC)

        new_image = np.zeros((input_h, input_w, 3), np.float32)
        new_image.fill(128)
        img_data = np.array(img_data)
        if len(img_data.shape) == 2:
            img_data = np.expand_dims(img_data, axis=-1)
            img_data = np.concatenate([img_data, img_data, img_data], axis=-1)

        dh = int((input_h - nh) / 2)
        dw = int((input_w - nw) / 2)
        new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
        new_image /= 255.
        new_image = np.transpose(new_image, (2, 0, 1))
        new_image = np.expand_dims(new_image, 0)
        return new_image, np.array([h, w], np.float32), box

    def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
        """Data augmentation function."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        iw, ih = image.size
        ori_image_shape = np.array([ih, iw], np.int32)
        h, w = image_size

        if not is_training:
            return _infer_data(image, image_size, box)

        flip = _rand() < .5
        # correct boxes
        box_data = np.zeros((max_boxes, 5))
        while True:
            # Prevent the situation that all boxes are eliminated
            new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
                _rand(1 - jitter, 1 + jitter)
            scale = _rand(0.25, 2)

            if new_ar < 1:
                nh = int(scale * h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * w)
                nh = int(nw / new_ar)

            dx = int(_rand(0, w - nw))
            dy = int(_rand(0, h - nh))

            if len(box) >= 1:
                t_box = box.copy()
                np.random.shuffle(t_box)
                t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
                t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
                if flip:
                    t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
                t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
                t_box[:, 2][t_box[:, 2] > w] = w
                t_box[:, 3][t_box[:, 3] > h] = h
                box_w = t_box[:, 2] - t_box[:, 0]
                box_h = t_box[:, 3] - t_box[:, 1]
                t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid box

            if len(t_box) >= 1:
                box = t_box
                break

        box_data[:len(box)] = box
        # resize image
        image = image.resize((nw, nh), Image.BICUBIC)
        # place image
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # convert image to gray or not
        gray = _rand() < .25
        if gray:
            image = image.convert('L').convert('RGB')

        # when the image has only one channel
        image = np.array(image)
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.concatenate([image, image, image], axis=-1)

        # distort image
        hue = _rand(-hue, hue)
        sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
        val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
        image_data = image / 255.
        if do_hsv:
            x = rgb_to_hsv(image_data)
            x[..., 0] += hue
            x[..., 0][x[..., 0] > 1] -= 1
            x[..., 0][x[..., 0] < 0] += 1
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x > 1] = 1
            x[x < 0] = 0
            image_data = hsv_to_rgb(x)  # numpy array, 0 to 1
        image_data = image_data.astype(np.float32)

        # preprocess bounding boxes
        bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
            _preprocess_true_boxes(box_data, anchors, image_size)

        return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
            ori_image_shape, gt_box1, gt_box2, gt_box3

    if is_training:
        images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
        return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3

    images, shape, anno = _data_aug(image, box, is_training)
    return images, shape, anno


def anno_parser(annos_str):
    """Parse annotation from string to list."""
    annos = []
    for anno_str in annos_str:
        anno = list(map(int, anno_str.strip().split(',')))
        annos.append(anno)
    return annos


def filter_valid_data(image_dir, anno_path):
    """Filter valid image files that appear in both image_dir and anno_path."""
    image_files = []
    image_anno_dict = {}
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as f:
        lines = f.readlines()
        for line in lines:
            line_str = line.decode("utf-8").strip()
            line_split = str(line_str).split(' ')
            file_name = line_split[0]
            if os.path.isfile(os.path.join(image_dir, file_name)):
                image_anno_dict[file_name] = anno_parser(line_split[1:])
                image_files.append(file_name)
    return image_files, image_anno_dict


def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
    """Create MindRecord files from image_dir and anno_path."""
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)

    yolo_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int64", "shape": [-1, 5]},
    }
    writer.add_schema(yolo_json, "yolo_json")

    for image_name in image_files:
        image_path = os.path.join(image_dir, image_name)
        with open(image_path, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name])
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=8):
    """Create YOLOv3 dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
                        num_parallel_workers=num_parallel_workers, shuffle=False)
    decode = C.Decode()
    ds = ds.map(operations=decode, input_columns=["image"])
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

    if is_training:
        hwc_to_chw = C.HWC2CHW()
        ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
                    output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    column_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    num_parallel_workers=num_parallel_workers)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(repeat_num)
    else:
        ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "annotation"],
                    column_order=["image", "image_shape", "annotation"],
                    num_parallel_workers=num_parallel_workers)
    return ds
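

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); the directory
# and file names below are placeholders. With file_num=8 the FileWriter above
# writes shard files named "yolo.mindrecord0" ... "yolo.mindrecord7", and
# create_yolo_dataset is given the path of one such shard, as in typical
# MindSpore model-zoo training scripts.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    example_image_dir = "./dataset/images"          # placeholder path
    example_anno_path = "./dataset/annotation.txt"  # placeholder path
    example_mindrecord_dir = "./mindrecord"         # placeholder path
    data_to_mindrecord_byte_image(example_image_dir, example_anno_path, example_mindrecord_dir,
                                  prefix="yolo.mindrecord", file_num=8)
    example_file = os.path.join(example_mindrecord_dir, "yolo.mindrecord0")
    example_ds = create_yolo_dataset(example_file, batch_size=32, is_training=True)
    print("dataset size:", example_ds.get_dataset_size())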