# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import hashlib
import json
import os
import itertools
from enum import Enum
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL
# Note: jsbeautifier is imported lazily inside _save_json since it is only needed when SAVE_JSON is True
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.vision.transforms as vision

# Lists of plot titles for the different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as numpy arrays
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
    """
    Compare the test result to the golden result as dictionaries
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    # Note: Parse the Pillow version numerically; a plain string comparison would misorder versions like 10.x
    pillow_version = tuple(int(part) for part in PIL.__version__.split('.')[:2])
    if check_pillow_version and pillow_version >= (9, 0):
        try:
            np.testing.assert_equal(result_dict, dict(golden_array))
        except AssertionError:
            logger.warning(
                "Results from Pillow >= 9.0.0 are incompatible with Pillow < 9.0.0; more validation is needed.")
    elif check_pillow_version:
        # Note: The version of Pillow used in Jenkins CI is >= 9.0.0, and some of the md5 result files
        #       that were generated with Pillow 7.2.0 are not compatible with Pillow 9.0.0.
        np.testing.assert_equal(result_dict, dict(golden_array),
                                'Items are not equal and the problem may be due to Pillow version incompatibility')
    else:
        np.testing.assert_equal(result_dict, dict(golden_array))


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in a json file
    """
    # Lazy import: jsbeautifier is only required when SAVE_JSON is enabled
    import jsbeautifier
    options = jsbeautifier.default_options()
    options.indent_size = 2

    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden file.
    Use create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    # each item is a dictionary of column name to ndarray
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(
        cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, False)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)


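# The block below is a usage sketch, not part of the original helpers: it shows how a test would
# exercise save_and_check_dict with a tiny in-memory pipeline. The golden filename
# "example_dict_golden.npz" is hypothetical; run once with generate_golden=True to create it.
def example_test_save_and_check_dict(generate_golden=False):
    data = ds.NumpySlicesDataset({"col1": [1, 2, 3]}, shuffle=False)
    save_and_check_dict(data, "example_dict_golden.npz", generate_golden=generate_golden)

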
def _helper_save_and_check_md5(data, filename, generate_golden=False):
    """
    Helper for save_and_check_md5 for both PIL and non-PIL pipelines
    """
    num_iter = 0
    result_dict = {}

    # each item is a dictionary of column name to ndarray
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the md5 digest as a numpy array
            result_dict[data_key].append(np.frombuffer(
                hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of data in ds1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(
        cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    return golden_ref_dir, result_dict


def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden md5 file, for non-PIL pipelines only.
    Use create_dict_iterator to access the dataset.
    """
    golden_ref_dir, result_dict = _helper_save_and_check_md5(data, filename, generate_golden)
    _compare_to_golden_dict(golden_ref_dir, result_dict, False)


def save_and_check_md5_pil(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden md5 file, for PIL pipelines only.
    If the Pillow version is >= 9.0.0, only log a warning when the assertion fails and allow the test to succeed.
    Use create_dict_iterator to access the dataset.
    """
    golden_ref_dir, result_dict = _helper_save_and_check_md5(data, filename, generate_golden)
    _compare_to_golden_dict(golden_ref_dir, result_dict, True)


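# Usage sketch (hypothetical, not referenced by existing tests): fingerprint a deterministic image
# pipeline with md5 and compare it against a golden file. The folder path argument and the golden
# filename "example_md5_golden.npz" are illustrative; save_and_check_md5_pil would be used instead
# when the mapped ops run their Python/PIL implementations.
def example_test_save_and_check_md5(image_folder_dir, generate_golden=False):
    data = ds.ImageFolderDataset(dataset_dir=image_folder_dir, shuffle=False)
    data = data.map(operations=[vision.Decode(), vision.CenterCrop(32)], input_columns=["image"])
    save_and_check_md5(data, "example_md5_golden.npz", generate_golden=generate_golden)

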
def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a numpy array) with the golden file.
    Use create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    # each item is a list of ndarrays, one per column
    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
        for data_key, column in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(column.tolist())
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(
        cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {}  new = {}".format(
        seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {}  new = {}".format(num_parallel_workers_original,
                                                                       num_parallel_workers_new))
    return num_parallel_workers_original


def config_get_set_enable_shared_mem(enable_shared_mem_new):
    """
    Get and return the original configuration enable_shared_mem value.
    Set the new configuration enable_shared_mem value.
    """
    enable_shared_mem_original = ds.config.get_enable_shared_mem()
    ds.config.set_enable_shared_mem(enable_shared_mem_new)
    logger.info("enable_shared_mem: original = {}  new = {}".format(enable_shared_mem_original,
                                                                    enable_shared_mem_new))
    return enable_shared_mem_original


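# Usage sketch (hypothetical test): the config_get_set_* helpers return the original value so a
# test can run a deterministic pipeline and then restore the global dataset configuration.
def example_test_with_fixed_seed():
    original_seed = config_get_set_seed(987)
    original_num_workers = config_get_set_num_parallel_workers(2)
    try:
        data = ds.NumpySlicesDataset([1, 2, 3], column_names=["col1"]).shuffle(buffer_size=3)
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    finally:
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_workers)

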
def diff_mse(in1, in2):
    """Return the mean squared error of two images (scaled to [0, 1]) as a percentage."""
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """Return the mean absolute error of two images as a percentage of the 255 pixel range."""
    me = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return me / 255 * 100


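# Small self-check sketch (not used by existing tests) illustrating the diff_mse scale: identical
# uint8 images give 0.0, and images that differ by the full 255 range give 100.0.
def example_diff_mse():
    img_a = np.zeros((4, 4, 3), dtype=np.uint8)
    img_b = np.full((4, 4, 3), 255, dtype=np.uint8)
    assert diff_mse(img_a, img_a) == 0.0
    assert np.isclose(diff_mse(img_a, img_b), 100.0)

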
def visualize_audio(waveform, expect_waveform):
    """
    Visualizes audio waveform.
    """
    plt.figure(1)
    plt.subplot(1, 3, 1)
    plt.imshow(waveform)
    plt.title("waveform")

    plt.subplot(1, 3, 2)
    plt.imshow(expect_waveform)
    plt.title("expect waveform")

    plt.subplot(1, 3, 3)
    plt.imshow(waveform - expect_waveform)
    plt.title("difference")

    plt.show()


def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize one-channel (grayscale) images
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT.get(1)[0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT.get(1)[1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2=None, visualize_mode=1):
    """
    Visualize one or two lists of images produced by dataset (DE) ops
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        if image_list_2 is not None:
            plt.subplot(2, num, i + num + 1)
            plt.imshow(image_list_2[i])
            plt.title(plot_title[1])

    plt.show()


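# Usage sketch (hypothetical): compare a handful of original and transformed images side by side.
# Real tests typically collect the images from create_dict_iterator and guard plotting with a flag.
def example_visualize_list(plot=False):
    originals = [np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) for _ in range(3)]
    flipped = [np.flip(img, axis=1) for img in originals]
    if plot:
        visualize_list(originals, flipped, visualize_mode=1)

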
def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualize one example image, with optional inputs: mse and an image produced by a 3rd party op.
    If three images are passed in, the difference image is calculated from the 2nd and 3rd images.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))

    plt.show()


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take lists of un-augmented and augmented images with bounding boxes and
    plot them side by side to verify that the BBox augment op works correctly
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in the data, e.g. "bbox"
    :param plot_rows: number of rows on the plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Width and height passed to Rectangle are slightly scaled down to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    # create batches of images to plot together
    batch_size = int(len(orig) / plot_rows)
    split_point = batch_size * plot_rows

    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of the required size and add the remainder to the last batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # check to avoid adding empty arrays
        aug = np.split(aug[:split_point], batch_size) + \
            ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, all_data in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        cur_plot = len(all_data[0])

        fig, axs = plt.subplots(cur_plot, 2)
        fig.tight_layout(pad=1.5)

        for x, (data_a, data_b) in enumerate(zip(all_data[0], all_data[1])):
            cur_ix = base_ix + x
            # select plotting axes based on the number of image rows on the plot - else case when there is 1 row
            (ax_a, ax_b) = (axs[x, 0], axs[x, 1]) if (
                cur_plot > 1) else (axs[0], axs[1])

            ax_a.imshow(data_a["image"])
            add_bounding_boxes(ax_a, data_a[annot_name])
            ax_a.title.set_text("Original" + str(cur_ix + 1))

            ax_b.imshow(data_b["image"])
            add_bounding_boxes(ax_b, data_b[annot_name])
            ax_b.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info(
                "Original **\n{} : {}".format(str(cur_ix + 1), data_a[annot_name]))
            logger.info(
                "Augmented **\n{} : {}\n".format(str(cur_ix + 1), data_b[annot_name]))

        plt.show()


def helper_perform_ops_bbox(data, test_op=None, edge_case=False):
    """
    Transform data based on test_op and on whether an edge case is requested
    :param data: original images
    :param test_op: random operation being tested
    :param edge_case: boolean; if True, replace each bbox with a single bbox covering the whole image
                      (only for the bbox data type edge case) before applying test_op
    :return: transformed data
    """
    if edge_case:
        if test_op:
            return data.map(
                operations=[lambda img, bboxes: (
                    img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op],
                input_columns=["image", "bbox"],
                output_columns=["image", "bbox"])
        return data.map(
            operations=[lambda img, bboxes: (
                img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))],
            input_columns=["image", "bbox"],
            output_columns=["image", "bbox"])

    if test_op:
        return data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])

    return data


def helper_perform_ops_bbox_edgecase_float(data):
    """
    Transform data based on an edge case: a float32 bbox covering the full image
    :param data: original images
    :return: transformed data
    """
    return data.map(operations=lambda img, bbox: (img, np.array(
        [[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)),
                    input_columns=["image", "bbox"],
                    output_columns=["image", "bbox"])


def helper_test_visual_bbox(plot_vis, data1, data2):
    """
    Build lists of samples (images and bboxes) without and with augmentation and optionally plot them
    :param plot_vis: boolean based on the test argument
    :param data1: data without test_op
    :param data2: data with test_op
    :return: None
    """
    unaug_samp, aug_samp = [], []

    for un_aug, aug in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                           data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        unaug_samp.append(un_aug)
        aug_samp.append(aug)

    if plot_vis:
        visualize_with_bounding_boxes(unaug_samp, aug_samp)


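# Usage sketch (hypothetical test) tying the bbox helpers together: run a bbox-aware op on VOC data
# together with an edge-case bbox covering the whole image, then optionally plot the before/after
# pairs. The dataset directory argument and the RandomHorizontalFlipWithBBox op are illustrative.
def example_bbox_op_test(voc_data_dir, plot_vis=False):
    test_op = vision.RandomHorizontalFlipWithBBox(1)
    data_orig = ds.VOCDataset(voc_data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    data_aug = ds.VOCDataset(voc_data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    data_aug = helper_perform_ops_bbox(data_aug, test_op, edge_case=True)
    helper_test_visual_bbox(plot_vis, data_orig, data_aug)

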
def helper_random_op_pipeline(data_dir, additional_op=None):
    """
    Create transformed images from data_dir based on additional_op
    :param data_dir: directory of the data
    :param additional_op: additional operation to add to the pipeline; if None, the original (decoded
                          and resized) images are returned
    :return: transformed images as a single numpy array
    """
    data_set = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
    transforms = [vision.Decode(), vision.Resize(size=[224, 224])]
    if additional_op:
        transforms.append(additional_op)
    ds_transformed = data_set.map(operations=transforms, input_columns="image")
    ds_transformed = ds_transformed.batch(512)

    for idx, (image, _) in enumerate(ds_transformed):
        if idx == 0:
            images_transformed = image.asnumpy()
        else:
            images_transformed = np.append(images_transformed,
                                           image.asnumpy(),
                                           axis=0)

    return images_transformed


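# Usage sketch (hypothetical): build original and transformed image batches with the pipeline helper
# and log a per-sample mse. The folder path argument and the RandomHorizontalFlip op are illustrative;
# the flip keeps the image shape, so diff_mse can compare the arrays element-wise.
def example_random_op_mse(image_folder_dir, plot=False):
    images_original = helper_random_op_pipeline(image_folder_dir)
    images_flipped = helper_random_op_pipeline(image_folder_dir, vision.RandomHorizontalFlip(1))
    mse = [diff_mse(images_flipped[i], images_original[i]) for i in range(len(images_original))]
    logger.info("MSE between original and flipped images: {}".format(mse))
    if plot:
        visualize_list(list(images_original), list(images_flipped))

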
def helper_invalid_bounding_box_test(data_dir, test_op):
    """
    Helper function for invalid bounding box test by calling check_bad_bbox
    :param data_dir: directory of the data
    :param test_op: operation that is being tested
    :return: None
    """
    data = ds.VOCDataset(
        data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    check_bad_bbox(data, test_op, InvalidBBoxType.WidthOverflow,
                   "bounding boxes is out of bounds of the image")
    data = ds.VOCDataset(
        data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    check_bad_bbox(data, test_op, InvalidBBoxType.HeightOverflow,
                   "bounding boxes is out of bounds of the image")
    data = ds.VOCDataset(
        data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    check_bad_bbox(data, test_op,
                   InvalidBBoxType.NegativeXY, "negative value")
    data = ds.VOCDataset(
        data_dir, task="Detection", usage="train", shuffle=False, decode=True)
    check_bad_bbox(data, test_op,
                   InvalidBBoxType.WrongShape, "4 features")


class InvalidBBoxType(Enum):
    """
    Defines invalid bounding box types for test cases
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Map a bad bounding box onto the data, apply test_op and verify that the expected error is raised
    :param data: de object detection pipeline
    :param test_op: Augmentation Op to test on the image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected to be raised due to the bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on the given img.
        :param img: image where the bounding boxes are
        :param bboxes: bboxes in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use a box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use a box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use a box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use a box that has an incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    try:
        # map to use the selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])
        # map to apply ops
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"])
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)


def dataset_equal(data1, data2, mse_threshold):
    """
    Return True if the two datasets are equal, i.e. every column-wise mse is within mse_threshold
    """
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data1, data2):
        for column1, column2 in itertools.zip_longest(item1, item2):
            mse = diff_mse(column1.asnumpy(), column2.asnumpy())
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal


def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    """
    Return True if the datasets are equal after modifying the target dataset
    :param data_unchanged: dataset kept unchanged
    :param data_target: dataset to be modified by foo
    :param mse_threshold: maximum allowable value of mse
    :param foo: function applied to data_target columns BEFORE the comparison
    :param foo_args: arguments passed into foo
    """
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        for column1, column2 in itertools.zip_longest(item1, item2):
            # note the function is applied to the second dataset's column
            column2 = foo(column2.asnumpy(), *foo_args)
            mse = diff_mse(column1.asnumpy(), column2)
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal


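# Usage sketch (hypothetical): compare a plain dataset against one that was normalized inside its
# pipeline by undoing the normalization on the second dataset's columns before the mse comparison.
# The mean/std arguments and the 0.01 threshold are illustrative values chosen by the caller.
def example_dataset_equal_with_function(data_plain, data_normalized, mean, std):
    def denormalize(image, mean_, std_):
        return image * std_ + mean_
    return dataset_equal_with_function(data_plain, data_normalized, 0.01, denormalize, mean, std)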