# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test CSVDataset
"""
import numpy as np
import pytest
import mindspore.dataset as ds

DATA_FILE = '../data/dataset/testCSV/1.csv'


def test_csv_dataset_basic():
    """
    Test CSV with repeat, skip and so on
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'

    buffer = []
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=',',
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
    data = data.repeat(2)
    data = data.skip(2)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    # 1.csv holds 3 rows: repeat(2) yields 6, skip(2) leaves 4
    assert len(buffer) == 4


def test_csv_dataset_one_file():
    """
    Test CSV with one file
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 3


def test_csv_dataset_all_file():
    """
    Test CSV with a list of files
    """
    APPEND_FILE = '../data/dataset/testCSV/2.csv'
    data = ds.CSVDataset(
        [DATA_FILE, APPEND_FILE],
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 10


def test_csv_dataset_num_samples():
    """
    Test CSV with num_samples
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_csv_dataset_distribution():
    """
    Test CSV with num_shards and shard_id
    """
    TEST_FILE = '../data/dataset/testCSV/1.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_shards=2, shard_id=0)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2
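

# The shard count asserted above follows from 1.csv holding 3 rows (see
# test_csv_dataset_one_file): split across num_shards=2, shard 0 receives 2
# of them. A minimal sketch of that even-split arithmetic in plain Python;
# this helper is illustrative only and is not MindSpore's internal code.
def rows_in_shard(total_rows, num_shards, shard_id):
    """Rows assigned to one shard when remainders go to the lowest shard ids."""
    base, extra = divmod(total_rows, num_shards)
    return base + (1 if shard_id < extra else 0)


assert rows_in_shard(3, 2, 0) == 2  # matches test_csv_dataset_distribution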


def test_csv_dataset_quoted():
    """
    Test CSV with quoted fields
    """
    TEST_FILE = '../data/dataset/testCSV/quoted.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_separated():
    """
    Test CSV with '|' as the field delimiter
    """
    TEST_FILE = '../data/dataset/testCSV/separated.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        field_delim='|',
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_embedded():
    """
    Test CSV with delimiters, quotes and newlines embedded in quoted fields
    """
    TEST_FILE = '../data/dataset/testCSV/embedded.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']
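

# test_csv_dataset_embedded expects RFC 4180-style quoting: delimiters,
# doubled quotes and newlines survive inside quoted fields, and surrounding
# spaces are preserved. A stdlib cross-check of the same semantics,
# illustrative only and independent of MindSpore's parser:
def csv_roundtrip(row):
    """Write one row with standard quoting, then parse it back."""
    import csv
    import io
    buf = io.StringIO()
    csv.writer(buf).writerow(row)
    return next(csv.reader(io.StringIO(buf.getvalue())))


assert csv_roundtrip(['a,b', 'c"d', 'e\nf', ' g ']) == ['a,b', 'c"d', 'e\nf', ' g ']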


def test_csv_dataset_chinese():
    """
    Test CSV with UTF-8 (Chinese) field values
    """
    TEST_FILE = '../data/dataset/testCSV/chinese.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8"),
                       d['col5'].item().decode("utf8")])
    assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']


def test_csv_dataset_header():
    """
    Test CSV that reads column names from the file header
    """
    TEST_FILE = '../data/dataset/testCSV/header.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_number():
    """
    Test CSV with numeric column defaults
    """
    TEST_FILE = '../data/dataset/testCSV/number.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item(),
                       d['col2'].item(),
                       d['col3'].item(),
                       d['col4'].item()])
    assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])


def test_csv_dataset_field_delim_none():
    """
    Test CSV with field_delim=None
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'

    buffer = []
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=None,
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
    data = data.repeat(2)
    data = data.skip(2)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 4


def test_csv_dataset_size():
    """
    Test get_dataset_size on a CSV dataset
    """
    TEST_FILE = '../data/dataset/testCSV/size.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    assert data.get_dataset_size() == 5


def test_csv_dataset_type_error():
    """
    Test CSV with a field whose type does not match its column default
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", 0, "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "type does not match" in str(err.value)


def test_csv_dataset_exception():
    """
    Test CSV with a malformed file and with a map op that raises
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "failed to parse file" in str(err.value)

    TEST_FILE1 = '../data/dataset/testCSV/quoted.csv'

    def exception_func(item):
        raise Exception("Error occur!")

    # Apply the raising map op to each column in turn; the pipeline is
    # expected to surface the Python error as a RuntimeError.
    for column in ["col1", "col2", "col3", "col4"]:
        data = ds.CSVDataset(
            TEST_FILE1,
            column_defaults=["", "", "", ""],
            column_names=['col1', 'col2', 'col3', 'col4'],
            shuffle=False)
        data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
        with pytest.raises(RuntimeError) as err:
            for _ in data.__iter__():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(err.value)


def test_csv_dataset_duplicate_columns():
    """
    Test CSV with duplicate names in column_names
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
    assert "column_names" in str(info.value)


if __name__ == "__main__":
    test_csv_dataset_basic()
    test_csv_dataset_one_file()
    test_csv_dataset_all_file()
    test_csv_dataset_num_samples()
    test_csv_dataset_distribution()
    test_csv_dataset_quoted()
    test_csv_dataset_separated()
    test_csv_dataset_embedded()
    test_csv_dataset_chinese()
    test_csv_dataset_header()
    test_csv_dataset_number()
    test_csv_dataset_field_delim_none()
    test_csv_dataset_size()
    test_csv_dataset_type_error()
    test_csv_dataset_exception()
    test_csv_dataset_duplicate_columns()