# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import hashlib
import json
import os
import itertools
from enum import Enum
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL
import mindspore.dataset as ds
from mindspore import log as logger

# Lists of plot titles for the different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as numpy arrays
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
    """
    Compare the test result to the golden result as dictionaries
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    # Note: The golden results were generated with the PILLOW version used in
    # Jenkins CI, which is compared with below
    if not check_pillow_version or PIL.__version__ == '7.1.2':
        np.testing.assert_equal(result_dict, dict(golden_array))
    else:
        # Beware: On failure, the golden results may have been generated with a
        # PILLOW version that is incompatible with the current one
        np.testing.assert_equal(result_dict, dict(golden_array),
                                'Items are not equal and problem may be due to PILLOW version incompatibility')


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in a json file
    """
    # jsbeautifier is only needed when SAVE_JSON is True, so import it lazily
    import jsbeautifier
    options = jsbeautifier.default_options()
    options.indent_size = 2

    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
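

# Illustrative sketch only (not used by the helpers above): a standalone
# round-trip of the npz format that _save_golden_dict writes and
# _compare_to_golden_dict reads back. The sample dictionary and file name are
# hypothetical; dtype=object makes the ragged rows explicit, which newer
# NumPy versions require.
def _example_golden_round_trip():
    """Demonstrate the save/load round-trip behind the golden-file helpers."""
    import tempfile
    result_dict = {"col1": [[1, 2]], "col2": [[3, 4]]}
    with tempfile.TemporaryDirectory() as tmp_dir:
        golden_path = os.path.join(tmp_dir, "example_golden.npz")
        # Keys and values are stored together as a single 2-column object array
        np.savez(golden_path, np.array(list(result_dict.items()), dtype=object))
        # dict() rebuilds the mapping from the rows of the loaded array
        loaded = dict(np.load(golden_path, allow_pickle=True)['arr_0'])
        np.testing.assert_equal(result_dict, loaded)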


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with the golden file.
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, False)

    if SAVE_JSON:
        # Save the result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with the golden file (md5).
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the md5 digest as a numpy array of four float32 values
            result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, True)
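

# Illustrative sketch only: the fingerprint trick used by save_and_check_md5
# above. The 16-byte md5 digest of the raw tensor bytes is reinterpreted as
# four little-endian float32 values, giving a compact array that can be stored
# and compared like any other golden result. The sample tensor is hypothetical.
def _example_md5_as_float32():
    """Demonstrate hashing a tensor into a comparable float32 array."""
    tensor = np.arange(12, dtype=np.uint8)
    digest = np.frombuffer(hashlib.md5(tensor).digest(), dtype='<f4')
    assert digest.shape == (4,)  # 16 digest bytes / 4 bytes per float32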


def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset output and compare (as numpy array) with the golden file.
    Use create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each item is a list of columns
        for data_key, _ in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save the result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {}".format(seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {} new = {}".format(num_parallel_workers_original,
                                                                      num_parallel_workers_new))
    return num_parallel_workers_original
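

# Illustrative sketch only: the standard pattern for using the two config
# helpers above in a test. Save the original settings, make the pipeline
# deterministic, and restore the settings afterwards so later tests are
# unaffected. The seed and worker values are hypothetical.
def _example_config_restore_pattern():
    """Demonstrate the save/set/restore pattern for dataset configuration."""
    original_seed = config_get_set_seed(987)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    try:
        # ... build and iterate a pipeline that must be deterministic ...
        pass
    finally:
        # Restore the original configuration values
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)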


def diff_mse(in1, in2):
    """
    Return the mean squared error between two images, scaled to [0, 100]
    """
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """
    Return the mean absolute error between two images, scaled to [0, 100]
    """
    mae = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return mae / 255 * 100
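

# Illustrative sketch only: a worked example of diff_mse above. Two uint8
# images that differ by 1 in every pixel give a per-pixel squared error of
# (1/255)^2 after normalization, so the scaled result is
# 100 * (1/255)^2, roughly 0.0015.
def _example_diff_mse():
    """Demonstrate diff_mse on two nearly identical images."""
    image_a = np.zeros((2, 2), dtype=np.uint8)
    image_b = np.ones((2, 2), dtype=np.uint8)
    mse = diff_mse(image_a, image_b)
    assert abs(mse - 100 * (1 / 255) ** 2) < 1e-9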


def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize one-channel grayscale images
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2, visualize_mode=1):
    """
    Visualize two lists of images side by side, titled according to visualize_mode
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        plt.subplot(2, num, i + num + 1)
        plt.imshow(image_list_2[i])
        plt.title(plot_title[1])

    plt.show()


def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualize one example image, with optional mse and an image produced by a
    third-party library op.
    If three images are passed in, the difference image is computed from the
    2nd and 3rd images.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))

    plt.show()


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take lists of un-augmented and augmented images with bounding boxes and
    plot them side by side to verify that BBox augmentation works correctly.
    :param orig: list of original images and bboxes (without augmentation)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in the data, e.g. "bbox"
    :param plot_rows: number of rows per plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Params to Rectangle are slightly shrunk to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows

    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of the required size and add the remainder to the last
        # batch; the check avoids appending empty arrays
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])
        aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, all_data in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(all_data[0])

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(all_data[0], all_data[1])):
            cur_ix = base_ix + x
            # select plotting axes based on the number of image rows on the plot
            # - the else case handles a single row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if (cur_plot > 1) else (axs[0], axs[1])

            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info("Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))

        plt.show()


class InvalidBBoxType(Enum):
    """
    Defines invalid bounding box types for test cases
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Add an invalid bounding box to the dataset, apply the op, and verify that
    the expected error is raised.
    :param data: de object detection pipeline
    :param test_op: augmentation op to test on the image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected due to the bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on the given img.
        :param img: image where the bounding boxes are
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use a box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use a box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use a box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use a box that has an incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    try:
        # map to insert the selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        # map to apply the op under test
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)
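

# Illustrative sketch only: how a test typically drives check_bad_bbox above.
# The dataset pipeline, augmentation op, and expected error substring are all
# hypothetical; real tests pass an actual detection dataset, a bbox-aware op,
# and the exact message the op raises for an out-of-bounds box.
def _example_check_bad_bbox(data, test_op):
    """Demonstrate probing one invalid bounding-box case against an op."""
    # Expect the op to reject a box whose width overflows the image
    check_bad_bbox(data, test_op,
                   InvalidBBoxType.WidthOverflow,
                   "bounding boxes is out of bounds of the image")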


def dataset_equal(data1, data2, mse_threshold):
    """
    Return True if the two datasets are equal (column mse within mse_threshold)
    """
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data1, data2):
        for column1, column2 in itertools.zip_longest(item1, item2):
            mse = diff_mse(column1.asnumpy(), column2.asnumpy())
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal


def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    """
    Return True if the two datasets are equal after modifying data_target.
    :param data_unchanged: dataset kept unchanged
    :param data_target: dataset to be modified by foo
    :param mse_threshold: maximum allowable mse value
    :param foo: function applied to data_target columns BEFORE the comparison
    :param foo_args: arguments passed into foo
    """
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        for column1, column2 in itertools.zip_longest(item1, item2):
            # note the function is applied to the second dataset
            column2 = foo(column2.asnumpy(), *foo_args)
            mse = diff_mse(column1.asnumpy(), column2)
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal
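

# Illustrative sketch only: a typical use of dataset_equal_with_function above.
# Here data_target is expected to differ from data_unchanged by a constant
# offset, so subtracting it before the mse comparison should make the two
# pipelines match. The offset helper, offset value, and threshold are
# hypothetical.
def _example_dataset_equal_with_function(data_unchanged, data_target):
    """Demonstrate comparing two pipelines after normalizing one of them."""
    def remove_offset(column, offset):
        # Applied to each column of data_target before the mse check
        return column - offset

    return dataset_equal_with_function(data_unchanged, data_target, 0.0,
                                       remove_offset, 10)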