# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import hashlib
import json
import os
import itertools
from enum import Enum
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL
# Note: jsbeautifier is imported lazily in _save_json, since it is only
# needed when SAVE_JSON is enabled
import mindspore.dataset as ds
from mindspore import log as logger

# Lists of plot titles for the different visualize modes
PLOT_TITLE_DICT = {
    1: ["Original image", "Transformed image"],
    2: ["c_transform image", "py_transform image"]
}
SAVE_JSON = False


def _save_golden(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary values as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.values())))


def _save_golden_dict(cur_dir, golden_ref_dir, result_dict):
    """
    Save the dictionary (both keys and values) as the golden result in a .npz file
    """
    logger.info("cur_dir is {}".format(cur_dir))
    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
    np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def _compare_to_golden(golden_ref_dir, result_dict):
    """
    Compare the test result to the golden result as numpy arrays
    """
    test_array = np.array(list(result_dict.values()))
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    np.testing.assert_array_equal(test_array, golden_array)


def _compare_to_golden_dict(golden_ref_dir, result_dict, check_pillow_version=False):
    """
    Compare the test result to the golden result as dictionaries
    """
    golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
    # Note: Golden results were generated with the PILLOW version used in Jenkins CI (7.1.2)
    if not check_pillow_version or PIL.__version__ == '7.1.2':
        np.testing.assert_equal(result_dict, dict(golden_array))
    else:
        # Beware: If this fails, the PILLOW version of the golden results may be
        # incompatible with the current PILLOW version
        np.testing.assert_equal(result_dict, dict(golden_array),
                                'Items are not equal and the problem may be due to PILLOW version incompatibility')


def _save_json(filename, parameters, result_dict):
    """
    Save the result dictionary in a json file.
    Note: filename is expected to end in ".npz"; the output name replaces that suffix with ".json".
    """
    # jsbeautifier is imported lazily since it is only needed when SAVE_JSON is enabled
    import jsbeautifier
    options = jsbeautifier.default_options()
    options.indent_size = 2

    out_dict = {**parameters, **{"columns": result_dict}}
    with open(filename[:-3] + "json", "w") as fout:
        fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))


def save_and_check_dict(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden file.
    Uses create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of rows in data: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, False)

    if SAVE_JSON:
        # Save result to a json file for inspection
        parameters = {"params": {}}
        _save_json(filename, parameters, result_dict)


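# Example (a minimal sketch, not part of the original suite): a typical
# golden-file test. The NumpySlicesDataset pipeline and the filename
# "example_result.npz" are hypothetical placeholders; pass
# generate_golden=True once, locally, to (re)generate the golden file.
def _example_save_and_check_dict():
    data = ds.NumpySlicesDataset({"col1": [1, 2, 3]}, shuffle=False)
    save_and_check_dict(data, "example_result.npz", generate_golden=False)

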
def save_and_check_md5(data, filename, generate_golden=False):
    """
    Save the dataset dictionary and compare (as a dictionary) with the golden file (md5).
    Uses create_dict_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):  # each item is a dictionary
        for data_key in list(item.keys()):
            if data_key not in result_dict:
                result_dict[data_key] = []
            # save the 16-byte md5 digest as a numpy array of four float32 values
            result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='<f4'))
        num_iter += 1

    logger.info("Number of rows in data: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden_dict(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden_dict(golden_ref_dir, result_dict, True)


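# Example (a minimal sketch): two byte-identical arrays yield identical
# md5-as-float32 arrays, which is the property save_and_check_md5 relies
# on for bitwise comparison against the golden file.
def _example_md5_digest_equivalence():
    arr = np.arange(12, dtype=np.uint8)
    digest1 = np.frombuffer(hashlib.md5(arr).digest(), dtype='<f4')
    digest2 = np.frombuffer(hashlib.md5(arr.copy()).digest(), dtype='<f4')
    # same bytes -> same 16-byte digest -> equal 4-element float32 arrays
    np.testing.assert_equal(digest1, digest2)

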
def save_and_check_tuple(data, parameters, filename, generate_golden=False):
    """
    Save the dataset output and compare (as numpy arrays) with the golden file.
    Uses create_tuple_iterator to access the dataset.
    """
    num_iter = 0
    result_dict = {}

    for item in data.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each item is a tuple of columns
        for data_key, _ in enumerate(item):
            if data_key not in result_dict:
                result_dict[data_key] = []
            result_dict[data_key].append(item[data_key].tolist())
        num_iter += 1

    logger.info("Number of rows in data: {}".format(num_iter))

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
    if generate_golden:
        # Save as the golden result
        _save_golden(cur_dir, golden_ref_dir, result_dict)

    _compare_to_golden(golden_ref_dir, result_dict)

    if SAVE_JSON:
        # Save result to a json file for inspection
        _save_json(filename, parameters, result_dict)


def config_get_set_seed(seed_new):
    """
    Get and return the original configuration seed value.
    Set the new configuration seed value.
    """
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    logger.info("seed: original = {}  new = {}".format(seed_original, seed_new))
    return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
    """
    Get and return the original configuration num_parallel_workers value.
    Set the new configuration num_parallel_workers value.
    """
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_parallel_workers_new)
    logger.info("num_parallel_workers: original = {}  new = {}".format(num_parallel_workers_original,
                                                                       num_parallel_workers_new))
    return num_parallel_workers_original


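# Example (a minimal sketch): the usual save/restore pattern around a test
# that needs deterministic pipeline behavior; the values 987 and 1 are arbitrary.
def _example_config_get_set():
    original_seed = config_get_set_seed(987)
    original_num_workers = config_get_set_num_parallel_workers(1)
    # ... run the deterministic part of the test here ...
    # Restore the original configuration so later tests are unaffected
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_workers)

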
def diff_mse(in1, in2):
    """
    Return the mean squared error between two images, as a percentage in [0, 100]
    """
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def diff_me(in1, in2):
    """
    Return the mean absolute error between two images, as a percentage in [0, 100]
    """
    me = (np.abs(in1.astype(float) - in2.astype(float))).mean()
    return me / 255 * 100


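# Example (a minimal sketch): both metrics are percentages in [0, 100];
# identical images give 0.0 and a uniform 255-level offset gives 100.0.
def _example_diff_metrics():
    img1 = np.zeros((2, 2), dtype=np.uint8)
    img2 = np.full((2, 2), 255, dtype=np.uint8)
    assert diff_mse(img1, img1) == 0.0
    assert diff_mse(img1, img2) == 100.0
    assert diff_me(img1, img2) == 100.0

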
def visualize_one_channel_dataset(images_original, images_transformed, labels):
    """
    Helper function to visualize single-channel (grayscale) images
    """
    num_samples = len(images_original)
    for i in range(num_samples):
        plt.subplot(2, num_samples, i + 1)
        # Note: Use squeeze() to convert (H, W, 1) images to (H, W)
        plt.imshow(images_original[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][0] + ":" + str(labels[i]))

        plt.subplot(2, num_samples, i + num_samples + 1)
        plt.imshow(images_transformed[i].squeeze(), cmap=plt.cm.gray)
        plt.title(PLOT_TITLE_DICT[1][1] + ":" + str(labels[i]))
    plt.show()


def visualize_list(image_list_1, image_list_2, visualize_mode=1):
    """
    Visualize two lists of images side by side; plot titles are selected by visualize_mode
    """
    plot_title = PLOT_TITLE_DICT[visualize_mode]
    num = len(image_list_1)
    for i in range(num):
        plt.subplot(2, num, i + 1)
        plt.imshow(image_list_1[i])
        plt.title(plot_title[0])

        plt.subplot(2, num, i + num + 1)
        plt.imshow(image_list_2[i])
        plt.title(plot_title[1])

    plt.show()


def visualize_image(image_original, image_de, mse=None, image_lib=None):
    """
    Visualize one example image with optional inputs: mse, and an image produced by a 3rd party op.
    If three images are passed in, the diff image is computed from the 2nd and 3rd images.
    """
    num = 2
    if image_lib is not None:
        num += 1
    if mse is not None:
        num += 1
    plt.subplot(1, num, 1)
    plt.imshow(image_original)
    plt.title("Original image")

    plt.subplot(1, num, 2)
    plt.imshow(image_de)
    plt.title("DE Op image")

    if image_lib is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_lib)
        plt.title("Lib Op image")
        if mse is not None:
            plt.subplot(1, num, 4)
            plt.imshow(image_de - image_lib)
            plt.title("Diff image,\n mse : {}".format(mse))
    elif mse is not None:
        plt.subplot(1, num, 3)
        plt.imshow(image_original - image_de)
        plt.title("Diff image,\n mse : {}".format(mse))

    plt.show()


def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
    """
    Take lists of un-augmented and augmented images with their bounding boxes and
    plot them side by side to verify that the BBox augmentation works correctly.
    :param orig: list of original images and bboxes (without aug)
    :param aug: list of augmented images and bboxes
    :param annot_name: the dict key for bboxes in the data, e.g. "bbox" (COCO/VOC)
    :param plot_rows: number of rows on the plot (rows = samples on one plot)
    :return: None
    """

    def add_bounding_boxes(ax, bboxes):
        for bbox in bboxes:
            # Width/height passed to Rectangle are scaled slightly to prevent drawing overflow
            rect = patches.Rectangle((bbox[0], bbox[1]),
                                     bbox[2] * 0.997, bbox[3] * 0.997,
                                     linewidth=1.80, edgecolor='r', facecolor='none')
            # Add the patch to the Axes
            ax.add_patch(rect)

    # Quick check to confirm correct input parameters
    if not isinstance(orig, list) or not isinstance(aug, list):
        return
    if len(orig) != len(aug) or not orig:
        return

    batch_size = int(len(orig) / plot_rows)  # creates batches of images to plot together
    split_point = batch_size * plot_rows

    orig, aug = np.array(orig), np.array(aug)

    if len(orig) > plot_rows:
        # Create batches of the required size and add the remainder to the last batch
        orig = np.split(orig[:split_point], batch_size) + (
            [orig[split_point:]] if (split_point < orig.shape[0]) else [])  # check to avoid adding empty arrays
        aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else [])
    else:
        orig = [orig]
        aug = [aug]

    for ix, allData in enumerate(zip(orig, aug)):
        base_ix = ix * plot_rows  # current batch starting index
        curPlot = len(allData[0])

        fig, axs = plt.subplots(curPlot, 2)
        fig.tight_layout(pad=1.5)

        for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
            cur_ix = base_ix + x
            # select plotting axes based on the number of image rows on the plot - else case when 1 row
            (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1])

            axA.imshow(dataA["image"])
            add_bounding_boxes(axA, dataA[annot_name])
            axA.title.set_text("Original" + str(cur_ix + 1))

            axB.imshow(dataB["image"])
            add_bounding_boxes(axB, dataB[annot_name])
            axB.title.set_text("Augmented" + str(cur_ix + 1))

            logger.info("Original **\n{} : {}".format(str(cur_ix + 1), dataA[annot_name]))
            logger.info("Augmented **\n{} : {}\n".format(str(cur_ix + 1), dataB[annot_name]))

        plt.show()


class InvalidBBoxType(Enum):
    """
    Defines invalid bounding box types for test cases
    """
    WidthOverflow = 1
    HeightOverflow = 2
    NegativeXY = 3
    WrongShape = 4


def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
    """
    Inject an invalid bounding box into the pipeline and check that applying
    test_op raises the expected error.
    :param data: de object detection pipeline
    :param test_op: augmentation op to test on the image
    :param invalid_bbox_type: type of bad box
    :param expected_error: error expected due to the bad box
    :return: None
    """

    def add_bad_bbox(img, bboxes, invalid_bbox_type_):
        """
        Used to generate erroneous bounding box examples on the given img.
        :param img: image where the bounding boxes are
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :param invalid_bbox_type_: type of bad box
        :return: bboxes with bad examples added
        """
        height = img.shape[0]
        width = img.shape[1]
        if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow:
            # use box that overflows on width
            return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow:
            # use box that overflows on height
            return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.NegativeXY:
            # use box with negative xy
            return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32)

        if invalid_bbox_type_ == InvalidBBoxType.WrongShape:
            # use box that has an incorrect shape
            return img, np.array([[0, 0, width - 1]]).astype(np.float32)
        return img, bboxes

    try:
        # map to inject the selected invalid bounding box type
        data = data.map(operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type),
                        input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        # map to apply the op under test
        data = data.map(operations=[test_op], input_columns=["image", "bbox"],
                        output_columns=["image", "bbox"],
                        column_order=["image", "bbox"])
        for _, _ in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
            break
    except RuntimeError as error:
        logger.info("Got an exception in DE: {}".format(str(error)))
        assert expected_error in str(error)


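# Example (a minimal sketch): exercising check_bad_bbox. The VOC dataset path,
# the chosen op, and the expected-error substring are illustrative placeholders;
# any detection pipeline with "image"/"bbox" columns and any bbox-aware op would do.
def _example_check_bad_bbox():
    import mindspore.dataset.vision.c_transforms as c_vision
    data_voc = ds.VOCDataset("../data/dataset/testVOC2012", task="Detection",
                             usage="train", shuffle=False, decode=True)
    check_bad_bbox(data_voc, c_vision.RandomHorizontalFlipWithBBox(1),
                   InvalidBBoxType.WidthOverflow, "out of bounds of the image")

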
def dataset_equal(data1, data2, mse_threshold):
    """
    Return True if the two datasets are equal (column-wise mse within mse_threshold)
    """
    if data1.get_dataset_size() != data2.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data1, data2):
        for column1, column2 in itertools.zip_longest(item1, item2):
            mse = diff_mse(column1.asnumpy(), column2.asnumpy())
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal


def dataset_equal_with_function(data_unchanged, data_target, mse_threshold, foo, *foo_args):
    """
    Return True if the datasets are equal after modification to the target.
    :param data_unchanged: dataset kept unchanged
    :param data_target: dataset to be modified by foo
    :param mse_threshold: maximum allowable value of mse
    :param foo: function applied to data_target columns BEFORE the comparison
    :param foo_args: arguments passed into foo
    """
    if data_unchanged.get_dataset_size() != data_target.get_dataset_size():
        return False
    equal = True
    for item1, item2 in itertools.zip_longest(data_unchanged, data_target):
        for column1, column2 in itertools.zip_longest(item1, item2):
            # note the function is applied to the second dataset
            column2 = foo(column2.asnumpy(), *foo_args)
            mse = diff_mse(column1.asnumpy(), column2)
            if mse > mse_threshold:
                equal = False
                break
        if not equal:
            break
    return equal

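
# Example (a minimal sketch): comparing a dataset against an unchanged copy,
# applying an identity function to the target's columns first. The
# NumpySlicesDataset inputs are arbitrary illustrative values.
def _example_dataset_equal_with_function():
    arr = [np.ones((2, 2), dtype=np.uint8), np.zeros((2, 2), dtype=np.uint8)]
    data1 = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
    data2 = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
    assert dataset_equal_with_function(data1, data2, 0, lambda x: x)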