# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TextFileDataset: reading, shuffle modes, sharding, repeat, to_device and error cases."""
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed


DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"


def _check_shuffle_order(num_parallel_workers, seed, shuffle, expected_lines):
    """Run TextFileDataset over DATA_ALL_FILE with the given config and assert the exact line order.

    Saves and restores the global num_parallel_workers/seed configuration so
    tests remain independent of each other.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(num_parallel_workers)
    original_seed = config_get_set_seed(seed)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=shuffle)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == expected_lines[count]
        count += 1
    assert count == len(expected_lines)
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_one_file():
    """Reading a single text file yields exactly its 3 lines."""
    data = ds.TextFileDataset(DATA_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 3


def test_textline_dataset_all_file():
    """A glob pattern matching all files yields the 5 lines across both files."""
    data = ds.TextFileDataset(DATA_ALL_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 5


def test_textline_dataset_num_samples_none():
    """Omitting num_samples (defaults to None) reads the whole file."""
    # Do not provide a num_samples argument, so it would be None by default
    data = ds.TextFileDataset(DATA_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 3


def test_textline_dataset_shuffle_false4():
    """shuffle=False with 4 parallel workers: lines interleaved across files in a fixed order."""
    _check_shuffle_order(
        num_parallel_workers=4, seed=987, shuffle=False,
        expected_lines=["This is a text file.", "Another file.",
                        "Be happy every day.", "End of file.", "Good luck to everyone."])


def test_textline_dataset_shuffle_false1():
    """shuffle=False with 1 worker: files are read sequentially, in order."""
    _check_shuffle_order(
        num_parallel_workers=1, seed=987, shuffle=False,
        expected_lines=["This is a text file.", "Be happy every day.", "Good luck to everyone.",
                        "Another file.", "End of file."])


def test_textline_dataset_shuffle_files4():
    """Shuffle.FILES with 4 workers and seed 135: deterministic file-level shuffle order."""
    _check_shuffle_order(
        num_parallel_workers=4, seed=135, shuffle=ds.Shuffle.FILES,
        expected_lines=["This is a text file.", "Another file.",
                        "Be happy every day.", "End of file.", "Good luck to everyone."])


def test_textline_dataset_shuffle_files1():
    """Shuffle.FILES with 1 worker and seed 135: deterministic file-level shuffle order."""
    _check_shuffle_order(
        num_parallel_workers=1, seed=135, shuffle=ds.Shuffle.FILES,
        expected_lines=["This is a text file.", "Be happy every day.", "Good luck to everyone.",
                        "Another file.", "End of file."])


def test_textline_dataset_shuffle_global4():
    """Shuffle.GLOBAL with 4 workers and seed 246: deterministic sample-level shuffle order."""
    _check_shuffle_order(
        num_parallel_workers=4, seed=246, shuffle=ds.Shuffle.GLOBAL,
        expected_lines=["Another file.", "Good luck to everyone.", "End of file.",
                        "This is a text file.", "Be happy every day."])


def test_textline_dataset_shuffle_global1():
    """Shuffle.GLOBAL with 1 worker and seed 246: deterministic sample-level shuffle order."""
    _check_shuffle_order(
        num_parallel_workers=1, seed=246, shuffle=ds.Shuffle.GLOBAL,
        expected_lines=["Another file.", "Good luck to everyone.", "This is a text file.",
                        "End of file.", "Be happy every day."])


def test_textline_dataset_num_samples():
    """num_samples=2 caps the number of rows produced."""
    data = ds.TextFileDataset(DATA_FILE, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_textline_dataset_distribution():
    """Sharding into 2 shards: shard 1 of the 5 total lines holds 3 rows."""
    data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3


def test_textline_dataset_repeat():
    """repeat(3) replays the unshuffled 3-line file three times in order."""
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.repeat(3)
    count = 0
    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "This is a text file.", "Be happy every day.", "Good luck to everyone."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 9


def test_textline_dataset_get_datasetsize():
    """get_dataset_size() reports the line count without iterating."""
    data = ds.TextFileDataset(DATA_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_textline_dataset_to_device():
    """to_device()/send() smoke test — only checks that the transfer pipeline builds."""
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.to_device()
    data.send()


def test_textline_dataset_exceptions():
    """Invalid arguments and failing map functions raise the documented errors."""
    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
    assert "num_samples exceeds the boundary" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("does/not/exist/no.txt")
    assert "The following patterns did not match any files" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("")
    assert "The following patterns did not match any files" in str(error_info.value)

    def exception_func(item):
        # Deliberately fail inside the map op to exercise the PyFunc error path.
        raise Exception("Error occur!")

    with pytest.raises(RuntimeError) as error_info:
        data = ds.TextFileDataset(DATA_FILE)
        data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
        # Iterate the dataset directly (idiomatic for-loop) to trigger the map op.
        for _ in data:
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)


if __name__ == "__main__":
    test_textline_dataset_one_file()
    test_textline_dataset_all_file()
    test_textline_dataset_num_samples_none()
    test_textline_dataset_shuffle_false4()
    test_textline_dataset_shuffle_false1()
    test_textline_dataset_shuffle_files4()
    test_textline_dataset_shuffle_files1()
    test_textline_dataset_shuffle_global4()
    test_textline_dataset_shuffle_global1()
    test_textline_dataset_num_samples()
    test_textline_dataset_distribution()
    test_textline_dataset_repeat()
    test_textline_dataset_get_datasetsize()
    test_textline_dataset_to_device()
    test_textline_dataset_exceptions()