# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing SlicePatches Python API
"""
import functools
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.utils as mode

from mindspore import log as logger
from util import diff_mse, visualize_list

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_slice_patches_01(plot=False):
    """
    slice rgb image(100, 200) to 4 patches
    """
    slice_to_patches([100, 200], 2, 2, True, plot=plot)


def test_slice_patches_02(plot=False):
    """
    no op
    """
    slice_to_patches([100, 200], 1, 1, True, plot=plot)


def test_slice_patches_03(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in pad mode
    """
    slice_to_patches([99, 199], 2, 2, True, plot=plot)


def test_slice_patches_04(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in drop mode
    """
    slice_to_patches([99, 199], 2, 2, False, plot=plot)


def test_slice_patches_05(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in pad mode with fill_value 255
    """
    slice_to_patches([99, 199], 2, 2, True, 255, plot=plot)


def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
    """
    Tool function for slice patches

    Runs the same decode + resize pipeline twice, once followed by the
    c_vision.SlicePatches op and once followed by the numpy reference
    `slice_patches`, and asserts that every output patch matches exactly.

    Args:
        ori_size: [height, width] passed to c_vision.Resize before slicing.
        num_h: number of patches along the height.
        num_w: number of patches along the width.
        pad_or_drop: True selects SliceMode.PAD, False selects SliceMode.DROP.
        fill_value: value used for padding in PAD mode.
        plot: when True, visualize the patches of both pipelines.
    """
    logger.info("test_slice_patches_pipeline")

    cols = ['img' + str(x) for x in range(num_h*num_w)]

    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize(ori_size)  # H, W
    if pad_or_drop:
        slice_patches_op = c_vision.SlicePatches(
            num_h, num_w, mode.SliceMode.PAD, fill_value)
    else:
        slice_patches_op = c_vision.SlicePatches(
            num_h, num_w, mode.SliceMode.DROP)

    # First dataset: patches produced by the SlicePatches op
    dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=slice_patches_op,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    # Second dataset: patches produced by the numpy reference implementation
    dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
    dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
    func_slice_patches = functools.partial(
        slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
    dataset2 = dataset2.map(operations=func_slice_patches,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    num_iter = 0
    patches_c = []
    patches_py = []
    for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for col in cols:
            mse = diff_mse(data1[col], data2[col])
            logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
            assert mse == 0
            patches_c.append(data1[col])
            patches_py.append(data2[col])
        num_iter += 1
    if plot:
        visualize_list(patches_py, patches_c)


def _expect_slice_patches_error(expected_error, expected_msg, *args):
    """
    Assert that c_vision.SlicePatches(*args) raises `expected_error` whose
    message contains `expected_msg`.

    Unlike a bare try/except, this fails when no exception is raised at all,
    so a regression in parameter validation cannot pass unnoticed.
    """
    try:
        _ = c_vision.SlicePatches(*args)
    except expected_error as e:
        logger.info("Got an exception in SlicePatches: {}".format(str(e)))
        assert expected_msg in str(e)
    else:
        raise AssertionError(
            "SlicePatches{} did not raise {}".format(args, expected_error.__name__))


def test_slice_patches_exception_01():
    """
    Test SlicePatches with invalid parameters
    """
    logger.info("test_Slice_Patches_exception")
    _expect_slice_patches_error(
        ValueError, "Input num_height is not within", 0, 2)
    _expect_slice_patches_error(
        ValueError, "Input num_width is not within", 2, 0)
    _expect_slice_patches_error(
        TypeError, "Argument slice_mode with value", 2, 2, 1)
    _expect_slice_patches_error(
        ValueError, "Input fill_value is not within", 2, 2, mode.SliceMode.PAD, -1)


def test_slice_patches_06():
    """
    Eager SlicePatches(2, 8) on a (158, 126, 1) int32 image:
    expects 16 patches of shape (79, 16, 1)
    """
    image = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    assert patches[0].shape == (79, 16, 1)


def test_slice_patches_07():
    """
    Eager SlicePatches(2, 8) on a 2D (158, 126) int32 image:
    expects 16 patches of shape (79, 16)
    """
    image = np.random.randint(0, 255, (158, 126)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    assert patches[0].shape == (79, 16)


def test_slice_patches_08():
    """
    SlicePatches(2, 2) in a pipeline on a (56, 82, 256) uint8 image:
    expects patches of shape (28, 41, 256)
    """
    np_data = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
    slice_patches_op = c_vision.SlicePatches(2, 2)
    cols = ["img0", "img1", "img2", "img3"]
    dataset = dataset.map(input_columns=["image"], output_columns=cols,
                          column_order=cols, operations=slice_patches_op)
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        patch_shape = item['img0'].shape
        assert patch_shape == (28, 41, 256)


def test_slice_patches_09():
    """
    Eager SlicePatches(4, 3) in pad mode on a (56, 82, 256) uint8 image:
    expects 12 patches of shape (14, 28, 256)
    """
    image = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
    slice_patches_op = c_vision.SlicePatches(4, 3, mode.SliceMode.PAD)
    patches = slice_patches_op(image)
    assert len(patches) == 12
    assert patches[0].shape == (14, 28, 256)


def skip_test_slice_patches_10():
    """
    Eager SlicePatches(10, 13) in drop mode on a large (7000, 7000, 255) image:
    expects patches of shape (700, 538, 255).
    Named skip_* so that it is not collected as a test.
    """
    image = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
    slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    patches = slice_patches_op(image)
    assert patches[0].shape == (700, 538, 255)


def skip_test_slice_patches_11():
    """
    SlicePatches(10, 13) in drop mode in a pipeline on a large
    (7000, 7000, 256) image: expects patches of shape (700, 538, 256).
    Named skip_* so that it is not collected as a test.
    """
    np_data = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
    slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    cols = ['img' + str(x) for x in range(10*13)]
    dataset = dataset.map(input_columns=["image"], output_columns=cols,
                          column_order=cols, operations=slice_patches_op)
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        patch_shape = item['img0'].shape
        assert patch_shape == (700, 538, 256)


def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
    """
    help function which slice patches with numpy

    Args:
        image: array of shape (H, W, C).
        num_h: number of patches along the height.
        num_w: number of patches along the width.
        pad_or_drop: True pads the image (bottom/right) with fill_value so
            that it divides evenly; False drops the remainder rows/columns.
        fill_value: value written into the padded area in pad mode.

    Returns:
        The input image itself when num_h == num_w == 1, otherwise a tuple of
        num_h * num_w patches in row-major order (top to bottom, left to right).
    """
    if num_h == 1 and num_w == 1:
        return image

    # (H, W, C)
    height, width, channels = image.shape
    patch_h = height // num_h
    patch_w = width // num_w

    if pad_or_drop:
        # Pad mode: round the patch size up and place the image in the
        # top-left corner of a canvas pre-filled with fill_value.
        if height % num_h != 0:
            patch_h += 1
        if width % num_w != 0:
            patch_w += 1
        # Keep the input dtype so non-uint8 images are not silently cast.
        img = np.full([patch_h*num_h, patch_w*num_w, channels],
                      fill_value, dtype=image.dtype)
        img[:height, :width] = image
    else:
        # Drop mode: the slicing below never reaches the remainder.
        img = image

    patches = [img[top*patch_h:(top+1)*patch_h, left*patch_w:(left+1)*patch_w, :]
               for top in range(num_h)
               for left in range(num_w)]
    return tuple(patches)


if __name__ == "__main__":
    test_slice_patches_01(plot=True)
    test_slice_patches_02(plot=True)
    test_slice_patches_03(plot=True)
    test_slice_patches_04(plot=True)
    test_slice_patches_05(plot=True)
    test_slice_patches_06()
    test_slice_patches_07()
    test_slice_patches_08()
    test_slice_patches_09()
    test_slice_patches_exception_01()