# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math

import matplotlib.pyplot as plt
import numpy as np
import pytest

import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.vision.c_transforms as c_vision

DATASET_DIR = "../data/dataset/testSBData/sbd"


def visualize_dataset(images, labels, task):
    """
    Helper function to visualize the dataset samples
    """
    # Segmentation labels are a single mask; Boundaries labels are a stack of
    # per-category masks laid out over 4 subplot rows.
    rows = 1 if task == "Segmentation" else 4
    for idx, image in enumerate(images):
        plt.imshow(image)
        plt.title('Original')
        plt.savefig('./sbd_original_{}.jpg'.format(str(idx)))
        if task == "Segmentation":
            plt.imshow(labels[idx])
            plt.title(task)
            plt.savefig('./sbd_segmentation_{}.jpg'.format(str(idx)))
        else:
            num_masks = labels[idx].shape[0]
            cols = math.ceil(num_masks / rows)
            for mask_idx in range(num_masks):
                plt.subplot(rows, cols, mask_idx + 1)
                plt.imshow(labels[idx][mask_idx])
            plt.savefig('./sbd_boundaries_{}.jpg'.format(str(idx)))
        plt.close()


def test_sbd_basic01(plot=False):
    """
    Validate SBDataset with different usage
    """
    task = 'Segmentation'  # Boundaries, Segmentation

    # usage='all' should expose every sample (train + val).
    all_set = ds.SBDataset(DATASET_DIR, task=task, usage='all', shuffle=False, decode=True)
    image_rows = []
    label_rows = []
    total = 0
    for row in all_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        image_rows.append(row['image'])
        label_rows.append(row['task'])
        total += 1
    assert total == 6
    if plot:
        visualize_dataset(image_rows, label_rows, task)

    # usage='train' yields only the training split.
    train_set = ds.SBDataset(DATASET_DIR, task=task, usage='train', shuffle=False, decode=False)
    train_total = sum(1 for _ in train_set.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert train_total == 4

    # usage='val' yields only the validation split.
    val_set = ds.SBDataset(DATASET_DIR, task=task, usage='val', shuffle=False, decode=False)
    val_total = sum(1 for _ in val_set.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert val_total == 2


def test_sbd_basic02():
    """
    Validate SBDataset with repeat and batch operation
    """
    # Boundaries, Segmentation
    # case 1: test num_samples
    sampled = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=3, shuffle=False)
    assert sum(1 for _ in sampled.create_dict_iterator(num_epochs=1)) == 3

    # case 2: test repeat
    repeated = ds.SBDataset(DATASET_DIR, task='Boundaries', usage='train', num_samples=4, shuffle=False)
    repeated = repeated.repeat(5)
    assert sum(1 for _ in repeated.create_dict_iterator(num_epochs=1)) == 20

    # case 3: test batch with drop_remainder=False
    batched = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True)
    resize_op = c_vision.Resize((100, 100))
    # Resize both columns so rows are uniform and batchable.
    batched = batched.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
    batched = batched.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1)
    assert batched.get_dataset_size() == 4
    assert batched.get_batch_size() == 1
    batched = batched.batch(batch_size=3)  # drop_remainder is default to be False
    assert batched.get_dataset_size() == 2
    assert batched.get_batch_size() == 3
    assert sum(1 for _ in batched.create_dict_iterator(num_epochs=1)) == 2

    # case 4: test batch with drop_remainder=True
    dropped = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, decode=True)
    resize_op = c_vision.Resize((100, 100))
    dropped = dropped.map(operations=resize_op, input_columns=["image"], num_parallel_workers=1)
    dropped = dropped.map(operations=resize_op, input_columns=["task"], num_parallel_workers=1)
    assert dropped.get_dataset_size() == 4
    assert dropped.get_batch_size() == 1
    dropped = dropped.batch(batch_size=3, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert dropped.get_dataset_size() == 1
    assert dropped.get_batch_size() == 3
    assert sum(1 for _ in dropped.create_dict_iterator(num_epochs=1)) == 1


def test_sbd_sequential_sampler():
    """
    Test SBDataset with SequentialSampler
    """
    logger.info("Test SBDataset Op with SequentialSampler")
    num_samples = 5
    sampler = ds.SequentialSampler(num_samples=num_samples)
    # An explicit SequentialSampler must produce the same order as
    # shuffle=False with the same num_samples.
    with_sampler = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', sampler=sampler)
    without_sampler = ds.SBDataset(DATASET_DIR, task='Segmentation', usage='all', shuffle=False,
                                   num_samples=num_samples)
    matched = 0
    for left, right in zip(with_sampler.create_dict_iterator(num_epochs=1, output_numpy=True),
                           without_sampler.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(left["task"], right["task"])
        matched += 1
    assert matched == num_samples


def test_sbd_exception():
    """
    Validate SBDataset with error parameters
    """
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=0,
                     sampler=ds.PKSampler(3))

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shard_id=0)

    # shard_id must lie in [0, num_shards).
    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.SBDataset(DATASET_DIR, task='Segmentation', usage='train', num_shards=2, shard_id="0")


def test_sbd_usage():
    """
    Validate SBDataset image readings
    """

    def run_usage(usage):
        # Returns the row count on success, or the error text on failure.
        try:
            data = ds.SBDataset(DATASET_DIR, task='Segmentation', usage=usage)
            rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return rows

    assert run_usage("train") == 4
    assert run_usage("train_noval") == 4
    assert run_usage("val") == 2
    assert run_usage("all") == 6
    assert "usage is not within the valid set of ['train', 'val', 'train_noval', 'all']" in run_usage("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in run_usage(["list"])


if __name__ == "__main__":
    test_sbd_basic01()
    test_sbd_basic02()
    test_sbd_sequential_sampler()
    test_sbd_exception()
    test_sbd_usage()