# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import sys
import pytest
import numpy as np
import pandas as pd
import mindspore.dataset as de
from mindspore import log as logger
import mindspore.dataset.vision.c_transforms as vision


def test_numpy_slices_list_1():
    """Slice a 1D list and check each row equals the matching list element."""
    logger.info("Test Slicing a 1D list.")

    np_data = [1, 2, 3]
    ds = de.NumpySlicesDataset(np_data, shuffle=False)

    for i, data in enumerate(ds):
        assert data[0].asnumpy() == np_data[i]


def test_numpy_slices_list_2():
    """Slice a 2D list and check each row equals the matching inner list."""
    logger.info("Test Slicing a 2D list into 1D list.")

    np_data = [[1, 2], [3, 4]]
    ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i]).all()


def test_numpy_slices_list_3():
    """Slice a 3D list along its first dimension and compare row by row."""
    logger.info("Test Slicing list in the first dimension.")

    np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i]).all()


def test_numpy_slices_numpy():
    """Slice a NumPy ndarray along its first dimension and compare row by row."""
    logger.info("Test NumPy structure data.")

    np_data = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
    ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i]).all()


def test_numpy_slices_list_append():
    """Feed a list of decoded, resized images read from a TFRecord file back in.

    The images are collected from a TFRecordDataset pipeline into a Python
    list, which is then used as the input of NumpySlicesDataset.
    """
    logger.info("Test reading data of image list.")

    data_files = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    resize_height, resize_width = 2, 2

    data1 = de.TFRecordDataset(data_files)
    resize_op = vision.Resize((resize_height, resize_width))
    data1 = data1.map(operations=[vision.Decode(True), resize_op], input_columns=["image"])

    res = []
    for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        res.append(data["image"])

    ds = de.NumpySlicesDataset(res, column_names=["col1"], shuffle=False)

    for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
        assert np.equal(data, res[i]).all()


def test_numpy_slices_dict_1():
    """Slice a dict of lists and check both output columns of every row."""
    logger.info("Test Dictionary structure data.")

    np_data = {"a": [1, 2], "b": [3, 4]}
    ds = de.NumpySlicesDataset(np_data, shuffle=False)
    # Expected rows: the i-th value of "a" followed by the i-th value of "b".
    res = [[1, 3], [2, 4]]

    for i, data in enumerate(ds):
        assert data[0].asnumpy() == res[i][0]
        assert data[1].asnumpy() == res[i][1]


def test_numpy_slices_tuple_1():
    """Slice a list of tuples; expect one row per list element (3 rows)."""
    logger.info("Test slicing a list of tuple.")

    np_data = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
    ds = de.NumpySlicesDataset(np_data, shuffle=False)

    for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
        assert np.equal(data, np_data[i]).all()

    assert sum(1 for _ in ds) == 3


def test_numpy_slices_tuple_2():
    """Slice a tuple of lists; expect 2 rows, each holding one value per list."""
    logger.info("Test slicing a tuple of list.")

    np_data = ([1, 2], [3, 4], [5, 6])
    expected = [[1, 3, 5], [2, 4, 6]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False)

    for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
        assert np.equal(data, expected[i]).all()

    assert sum(1 for _ in ds) == 2


def test_numpy_slices_tuple_3():
    """Slice a (features, labels) tuple whose members have different widths."""
    logger.info("Test reading different dimension of tuple data.")
    features, labels = np.random.sample((5, 2)), np.random.sample((5, 1))
    # Named np_data so the loop variable below does not clobber the input tuple.
    np_data = (features, labels)

    ds = de.NumpySlicesDataset(np_data, column_names=["col1", "col2"], shuffle=False)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), features[i]).all()
        assert data[1].asnumpy() == labels[i]


def test_numpy_slices_csv_value():
    """Load a CSV with pandas and slice its (values, target) arrays."""
    logger.info("Test loading value of csv file.")
    csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"

    df = pd.read_csv(csv_file)
    target = df.pop("target")
    df.pop("state")
    np_data = (df.values, target.values)

    ds = de.NumpySlicesDataset(np_data, column_names=["col1", "col2"], shuffle=False)

    for i, data in enumerate(ds):
        assert np.equal(np_data[0][i], data[0].asnumpy()).all()
        assert np.equal(np_data[1][i], data[1].asnumpy()).all()


def test_numpy_slices_csv_dict():
    """Load a CSV with pandas and slice it as a dict of columns."""
    logger.info("Test loading csv file as dict.")

    csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
    df = pd.read_csv(csv_file)
    df.pop("state")
    res = df.values

    ds = de.NumpySlicesDataset(dict(df), shuffle=False)

    for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
        assert np.equal(data, res[i]).all()


def test_numpy_slices_num_samplers():
    """Pass num_samples=2 and expect exactly the first two rows."""
    logger.info("Test num_samplers.")

    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False, num_samples=2)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i]).all()

    assert sum(1 for _ in ds) == 2


def test_numpy_slices_distributed_sampler():
    """Read shard 0 of 4 from 8 rows; expect rows 0 and 4 (2 rows in total)."""
    logger.info("Test distributed sampler.")

    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, shuffle=False, shard_id=0, num_shards=4)

    for i, data in enumerate(ds):
        assert np.equal(data[0].asnumpy(), np_data[i * 4]).all()

    assert sum(1 for _ in ds) == 2


def test_numpy_slices_distributed_shard_limit():
    """num_shards=sys.maxsize must be rejected with a ValueError."""
    logger.info("Test num_shards above the allowed upper limit.")

    np_data = [1, 2, 3]
    num = sys.maxsize
    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False)
    assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)


def test_numpy_slices_distributed_zero_shard():
    """num_shards=0 must be rejected with a ValueError."""
    logger.info("Test num_shards equal to zero.")

    np_data = [1, 2, 3]
    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False)
    assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)


def test_numpy_slices_sequential_sampler():
    """Use SequentialSampler with repeat(2); rows must cycle through np_data in order."""
    logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")

    np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
    ds = de.NumpySlicesDataset(np_data, sampler=de.SequentialSampler()).repeat(2)

    for i, data in enumerate(ds):
        # i % 8 wraps the index so the second pass compares against the start again.
        assert np.equal(data[0].asnumpy(), np_data[i % 8]).all()


def test_numpy_slices_invalid_column_names_type():
    """A non-string entry in column_names must raise TypeError."""
    logger.info("Test incorrect column_names input")
    np_data = [1, 2, 3]

    with pytest.raises(TypeError) as err:
        de.NumpySlicesDataset(np_data, column_names=[1], shuffle=False)
    assert "Argument column_names[0] with value 1 is not of type [<class 'str'>]" in str(err.value)


def test_numpy_slices_invalid_column_names_string():
    """An empty-string entry in column_names must raise ValueError."""
    logger.info("Test incorrect column_names input")
    np_data = [1, 2, 3]

    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, column_names=[""], shuffle=False)
    assert "column_names[0] should not be empty" in str(err.value)


def test_numpy_slices_invalid_empty_column_names():
    """An empty column_names list must raise ValueError."""
    logger.info("Test incorrect column_names input")
    np_data = [1, 2, 3]

    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, column_names=[], shuffle=False)
    assert "column_names should not be empty" in str(err.value)


def test_numpy_slices_invalid_empty_data_column():
    """Empty input data must raise ValueError."""
    logger.info("Test empty data input")
    np_data = []

    with pytest.raises(ValueError) as err:
        de.NumpySlicesDataset(np_data, shuffle=False)
    assert "Argument data cannot be empty" in str(err.value)


def test_numpy_slice_empty_output_shape():
    """Batch a 1-row dataset with batch_size=3 and drop_remainder=True; output_shapes() must be []."""
    logger.info("running test_numpy_slice_empty_output_shape")
    dataset = de.NumpySlicesDataset([[[1, 2], [3, 4]]], column_names=["col1"])
    dataset = dataset.batch(batch_size=3, drop_remainder=True)
    assert dataset.output_shapes() == []


if __name__ == "__main__":
    test_numpy_slices_list_1()
    test_numpy_slices_list_2()
    test_numpy_slices_list_3()
    test_numpy_slices_numpy()
    test_numpy_slices_list_append()
    test_numpy_slices_dict_1()
    test_numpy_slices_tuple_1()
    test_numpy_slices_tuple_2()
    test_numpy_slices_tuple_3()
    test_numpy_slices_csv_value()
    test_numpy_slices_csv_dict()
    test_numpy_slices_num_samplers()
    test_numpy_slices_distributed_sampler()
    test_numpy_slices_distributed_shard_limit()
    test_numpy_slices_distributed_zero_shard()
    test_numpy_slices_sequential_sampler()
    test_numpy_slices_invalid_column_names_type()
    test_numpy_slices_invalid_column_names_string()
    test_numpy_slices_invalid_empty_column_names()
    test_numpy_slices_invalid_empty_data_column()
    test_numpy_slice_empty_output_shape()