# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import hashlib
import json
import os
import itertools
from enum import Enum
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.vision.transforms as vision

# These are list of plot title in different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
# When True, every save_and_check_dict / save_and_check_tuple call also dumps its
# result to a .json file (next to the current working directory) for manual inspection.
SAVE_JSON = False


def _golden_paths(filename):
    """
    Return (cur_dir, golden_ref_dir) where cur_dir is the directory of this module and
    golden_ref_dir is the path of `filename` under ../../data/dataset/golden relative to it.
    """
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    return cur_dir, golden_ref_dir


def _load_golden(golden_ref_dir):
    """
    Load the array stored under the default key 'arr_0' of a golden .npz file.

    The NpzFile handle is closed before returning (the array is fully read first).
    allow_pickle=True is needed because dictionary goldens are stored as object arrays;
    only use this on golden files from the trusted test-data tree.
    """
    with np.load(golden_ref_dir, allow_pickle=True) as golden_file:
        return golden_file['arr_0']


def _pillow_major_version():
    """
    Return the major version of the installed Pillow as an int.

    A numeric parse is required: comparing version strings lexicographically would
    order '10.0.0' before '9.0.0'.
    """
    return int(PIL.__version__.split(".")[0])


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in .npz file.

    The (key, list) pairs are ragged, so dtype=object is passed explicitly; NumPy >= 1.24
    raises ValueError when building a ragged array without it.
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items()), dtype=object))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare as numpy arrays the test result to the golden result
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = _load_golden(golden_ref_dir)
    np.testing.assert_array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
    """
    Compare as dictionaries the test result to the golden result.

    :param golden_ref_dir: path of the golden .npz file holding (key, value) rows
    :param result_dict: dictionary produced by the test
    :param check_pillow_version: if True and Pillow >= 9 is installed, a mismatch only logs
        a warning (goldens may have been generated with an older Pillow); if True with an
        older Pillow, a mismatch fails with a Pillow-specific message.
    :raises AssertionError: when the dictionaries differ (except in the warning case above)
    """
    golden_dict = dict(_load_golden(golden_ref_dir))
    # Note: The version of PILLOW that is used in Jenkins CI is compared with below
    if check_pillow_version and _pillow_major_version() >= 9:
        try:
            np.testing.assert_equal(result_dict, golden_dict)
        except AssertionError:
            logger.warning(
                "Results from Pillow >= 9.0.0 is incompatibale with Pillow < 9.0.0, need more validation.")
    elif check_pillow_version:
        # Note: The version of PILLOW that is used in Jenkins CI is >= 9.0.0 and
        # some of the md5 results files that are generated with PILLOW 7.2.0
        # are not compatible with PILLOW 9.0.0.
        np.testing.assert_equal(result_dict, golden_dict,
                                'Items are not equal and problem may be due to PILLOW version incompatibility')
    else:
        np.testing.assert_equal(result_dict, golden_dict)


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in a json file for inspection.

    The output name is `filename` with its extension replaced by ".json"; it is written
    relative to the current working directory. The file is closed on exit.
    """
    out_dict = {**parameters, "columns": result_dict}
    json_filename = os.path.splitext(filename)[0] + ".json"
    with open(json_filename, "w", encoding="utf-8") as fout:
        json.dump(out_dict, fout, indent=2)


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file.
    Use create_dict_iterator to access the dataset.

    :param data: dataset to iterate (one epoch)
    :param filename: golden file name under ../../data/dataset/golden
    :param generate_golden: if True, (re)write the golden file before comparing
    """
    num_iter = 0
    result_dict = {}

    # each data is a dictionary
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key, value in item.items():
            result_dict.setdefault(data_key, []).append(value.tolist())
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir, golden_ref_dir = _golden_paths(filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, False)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)


def _helper_save_and_check_md5(data, filename, generate_golden=False):
    """
    Helper for save_and_check_md5 for both PIL and non-PIL.

    Builds {column: [md5 digest of each row, viewed as 4 little-endian float32]} and
    optionally saves it as the golden file.

    :return: (golden_ref_dir, result_dict)
    """
    num_iter = 0
    result_dict = {}

    # each data is a dictionary
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key, value in item.items():
            # save the md5 as numpy array
            result_dict.setdefault(data_key, []).append(
                np.frombuffer(hashlib.md5(value).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir, golden_ref_dir = _golden_paths(filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    return golden_ref_dir, result_dict


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file (md5) for non-PIL only.
    Use create_dict_iterator to access the dataset.
    """
    golden_ref_dir, result_dict = _helper_save_and_check_md5(data, filename, generate_golden)
    _compare_to_golden_dict(golden_ref_dir, result_dict, False)


def save_and_check_md5_pil(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as dictionary) with golden file (md5) for PIL only.
    If PIL version >= 9.0.0, only log warning when assertion fails and allow the test to succeed.
    Use create_dict_iterator to access the dataset.
    """
    golden_ref_dir, result_dict = _helper_save_and_check_md5(data, filename, generate_golden)
    _compare_to_golden_dict(golden_ref_dir, result_dict, True)


def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as numpy array) with golden file.
    Use create_tuple_iterator to access the dataset.

    Columns are keyed by their position in the tuple (0, 1, ...).
    """
    num_iter = 0
    result_dict = {}

    # each data is a tuple of columns
    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
        for data_key, column in enumerate(item):
            result_dict.setdefault(data_key, []).append(column.tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir, golden_ref_dir = _golden_paths(filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {} new = {} ".format(
        seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
                                                                       num_parallel_workers_new))
    return num_parallel_workers_original


def config_get_set_enable_shared_mem(enable_shared_mem_new):
    """
    Get and return the original configuration enable_shared_mem value.
    Set the new configuration enable_shared_mem value.
    """
    enable_shared_mem_original = ds.config.get_enable_shared_mem()
    ds.config.set_enable_shared_mem(enable_shared_mem_new)
    logger.info("enable_shared_mem: original = {} new = {} ".format(enable_shared_mem_original,
                                                                    enable_shared_mem_new))
    return enable_shared_mem_original


def diff_mse(in1, in2):
    """
    Return the mean squared error between two arrays after scaling both by 1/255,
    expressed as a percentage (i.e. multiplied by 100).
    """
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """
    Return the mean absolute error between two arrays, divided by 255 and
    expressed as a percentage (i.e. multiplied by 100).
    """
    mean_abs_err = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return mean_abs_err / 255 * 100


def visualize_audio(waveform, expect_waveform):
    """
    Visualizes audio waveform: the actual waveform, the expected waveform and their
    difference, side by side in one figure.
    """
    plt.figure(1)
    plt.subplot(1, 3, 1)
    plt.imshow(waveform)
    plt.title("waveform")

    plt.subplot(1, 3, 2)
    plt.imshow(expect_waveform)
    plt.title("expect waveform")

    plt.subplot(1, 3, 3)
    plt.imshow(waveform - expect_waveform)
    plt.title("difference")

    plt.show()


def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize one channel grayscale images.
    Row 1 shows the original images, row 2 the transformed ones, each titled with its label.
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT.get(1)[0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT.get(1)[1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2=None, visualize_mode=1):
    """
    Visualizes one or two lists of images using DE op.
    The titles come from PLOT_TITLE_DICT[visualize_mode]; the second list (if given)
    is plotted on the second row.
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        if image_list_2 is not None:
            plt.subplot(2, num, i + num + 1)
            plt.imshow(image_list_2[i])
            plt.title(plot_title[1])

    plt.show()


def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualizes one example image with optional inputs: mse, and an image produced by a
    3rd-party op. If three images are passed in, the difference image is computed from the
    2nd and 3rd images; otherwise it is computed from the 1st and 2nd.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))

    plt.show()


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take a list of un-augmented and augmented images with "bbox" bounding boxes
    Plot images to compare test correct BBox augment functionality
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "bbox" (VOC)
    :param plot_rows: number of rows on plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Params to Rectangle slightly modified to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    # number of full batches of `plot_rows` samples; leftovers are appended as a last batch
    num_batches = len(orig) // plot_rows
    split_point = num_batches * plot_rows

    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of required size and add remainder to last batch
        # (the check avoids adding an empty array when there is no remainder)
        orig = np.split(orig[:split_point], num_batches) + \
            ([orig[split_point:]] if split_point < orig.shape[0] else [])
        aug = np.split(aug[:split_point], num_batches) + \
            ([aug[split_point:]] if split_point < aug.shape[0] else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, (orig_batch, aug_batch) in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(orig_batch)

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(orig_batch, aug_batch)):
            cur_ix = base_ix + x
            # select plotting axes based on number of image rows on plot - else case when 1 row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if cur_plot > 1 else (axs[0], axs[1])

            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info(
                "Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info(
                "Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))

        plt.show()


def helper_perform_ops_bbox(data, test_op=None, edge_case=False):
    """
    Transform data based on test_op and whether it is an edge_case
    :param data: original images
    :param test_op: random operation being tested
    :param edge_case: boolean whether edge_case is being augmented (note only for bbox data type edge case)
    :return: transformed data
    """
    if edge_case:
        def full_image_bbox(img, bboxes):
            # replace the bboxes with a single box covering the whole image, keeping the dtype
            return img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)

        operations = [full_image_bbox, test_op] if test_op else [full_image_bbox]
        return data.map(operations=operations,
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])

    if test_op:
        return data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])

    return data


def helper_perform_ops_bbox_edgecase_float(data):
    """
    Transform data based an edge_case covering full image with float32
    :param data: original images
    :return: transformed data
    """
    return data.map(operations=lambda img, bbox: (img, np.array(
        [[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)),
                    input_columns=["image", "bbox"],
                    output_columns=["image", "bbox"])


def helper_test_visual_bbox(plot_vis, data1, data2):
    """
    Create list based of original images and bboxes with and without aug
    :param plot_vis: boolean based on the test argument
    :param data1: data without test_op
    :param data2: data with test_op
    :return: None
    """
    unaug_samp, aug_samp = [], []

    for un_aug, aug in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                           data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        unaug_samp.append(un_aug)
        aug_samp.append(aug)

    if plot_vis:
        visualize_with_bounding_boxes(unaug_samp, aug_samp)


def helper_random_op_pipeline(data_dir, additional_op=None):
    """
    Create an original/transformed images at data_dir based on additional_op
    :param data_dir: directory of the data
    :param additional_op: additional operation to be pipelined, if None, then gives original images
    :return: transformed images, all batches concatenated along axis 0
             (an empty array if the dataset yields nothing)
    """
    data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
    transforms = [vision.Decode(), vision.Resize(size=[224, 224])]
    if additional_op:
        transforms.append(additional_op)
    ds_transformed = data_set.map(operations=transforms, input_columns="image")
    ds_transformed = ds_transformed.batch(512)

    # collect every batch then concatenate once (np.append in a loop re-copies each time)
    batches = [image.asnumpy() for image, _ in ds_transformed]
    if not batches:
        return np.empty((0,))
    return np.concatenate(batches, axis=0)


def helper_invalid_bounding_box_test(data_dir, test_op):
    """
    Helper function for invalid bounding box test by calling check_bad_bbox
    :param data_dir: directory of the data
    :param test_op: operation that is being tested
    :return: None
    """
    cases = [
        (InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image"),
        (InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image"),
        (InvalidBBoxType.NegativeXY, "negative value"),
        (InvalidBBoxType.WrongShape, "4 features"),
    ]
    for invalid_bbox_type, expected_error in cases:
        # a fresh dataset per case, so each check starts from an unmodified pipeline
        data = ds.VOCDataset(
            data_dir, task="Detection", usage="train", shuffle=False, decode=True)
        check_bad_bbox(data, test_op, invalid_bbox_type, expected_error)


class InvalidBBoxType(Enum):
    """
    Defines Invalid Bounding Bbox types for test cases
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Replace the bboxes of `data` with an invalid box, apply test_op, fetch one row, and
    assert that any RuntimeError raised contains `expected_error`.
    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected to get due to bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on given img.
        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    # NOTE(review): if no RuntimeError is raised at all, this helper passes silently;
    # confirm whether callers rely on that before making it fail.
    try:
        # map to use selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])
        # map to apply ops
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])
        # only the first row is fetched; presumably that is enough to run the mapped ops
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)


def dataset_equal(data1, data2, mse_threshold):
    """
    Return True if datasets are equal: same dataset size, and for every row each pair of
    corresponding columns has diff_mse <= mse_threshold. A mismatched number of rows or
    columns yields False.
    """
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False
    for item1, item2 in itertools.zip_longest(data1, data2):
        if item1 is None or item2 is None:
            return False
        for column1, column2 in itertools.zip_longest(item1, item2):
            if column1 is None or column2 is None:
                return False
            if diff_mse(column1.asnumpy(), column2.asnumpy()) > mse_threshold:
                return False
    return True


def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    """
    Return True if datasets are equal after modification to target.

    :param data_unchanged: dataset kept unchanged
    :param data_target: dataset to be modified by foo
    :param mse_threshold: maximum allowable value of mse
    :param foo: function applied to data_target columns (as numpy) BEFORE compare
    :param foo_args: arguments passed into foo
    :return: False on a size/row/column-count mismatch or any column exceeding mse_threshold
    """
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        if item1 is None or item2 is None:
            return False
        for column1, column2 in itertools.zip_longest(item1, item2):
            if column1 is None or column2 is None:
                return False
            # note the function is to be applied to the second dataset
            transformed = foo(column2.asnumpy(), *foo_args)
            if diff_mse(column1.asnumpy(), transformed) > mse_threshold:
                return False
    return True