1# Copyright 2019 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15import numpy as np 16import mindspore._c_dataengine as cde 17 18import mindspore.common.dtype as mstype 19import mindspore.dataset as ds 20from mindspore.dataset.text import to_str, to_bytes 21 22 23def test_basic(): 24 x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S') 25 n = cde.Tensor(x) 26 arr = n.as_array() 27 np.testing.assert_array_equal(x, arr) 28 29 30def compare(strings, dtype='S'): 31 arr = np.array(strings, dtype=dtype) 32 33 def gen(): 34 (yield arr,) 35 36 data = ds.GeneratorDataset(gen, column_names=["col"]) 37 38 for d in data.create_tuple_iterator(output_numpy=True): 39 np.testing.assert_array_equal(d[0], arr.astype('S')) 40 41 42def test_generator(): 43 compare(["ab"]) 44 compare(["", ""]) 45 compare([""]) 46 compare(["ab", ""]) 47 compare(["ab", "cde", "121"]) 48 compare([["ab", "cde", "121"], ["x", "km", "789"]]) 49 compare([["ab", "", "121"], ["", "km", "789"]]) 50 compare(["ab"], dtype='U') 51 compare(["", ""], dtype='U') 52 compare([""], dtype='U') 53 compare(["ab", ""], dtype='U') 54 compare(["", ""], dtype='U') 55 compare(["", "ab"], dtype='U') 56 compare(["ab", "cde", "121"], dtype='U') 57 compare([["ab", "cde", "121"], ["x", "km", "789"]], dtype='U') 58 compare([["ab", "", "121"], ["", "km", "789"]], dtype='U') 59 60 61line = np.array(["This is a text file.", 62 "Be happy every day.", 63 "Good luck to everyone."]) 64 65words = np.array([["This", "text", "file", "a"], 66 ["Be", "happy", "day", "b"], 67 ["女", "", "everyone", "c"]]) 68 69chinese = np.array(["今天天气太好了我们一起去外面玩吧", 70 "男默女泪", 71 "江州市长江大桥参加了长江大桥的通车仪式"]) 72 73 74def test_batching_strings(): 75 def gen(): 76 for row in chinese: 77 yield (np.array(row),) 78 79 data = ds.GeneratorDataset(gen, column_names=["col"]) 80 data = data.batch(2, drop_remainder=True) 81 82 for d in data.create_tuple_iterator(output_numpy=True): 83 np.testing.assert_array_equal(d[0], to_bytes(chinese[0:2])) 84 85 86def test_map(): 87 def gen(): 88 yield (np.array(["ab cde 121"], dtype='S'),) 89 90 data = ds.GeneratorDataset(gen, column_names=["col"]) 91 92 def split(b): 93 s = to_str(b) 94 splits = s.item().split() 95 return np.array(splits) 96 97 data = data.map(operations=split, input_columns=["col"]) 98 expected = np.array(["ab", "cde", "121"], dtype='S') 99 for d in data.create_tuple_iterator(output_numpy=True): 100 np.testing.assert_array_equal(d[0], expected) 101 102 103def test_map2(): 104 def gen(): 105 yield (np.array(["ab cde 121"], dtype='S'),) 106 107 data = ds.GeneratorDataset(gen, column_names=["col"]) 108 109 def upper(b): 110 out = np.char.upper(b) 111 return out 112 113 data = data.map(operations=upper, input_columns=["col"]) 114 expected = np.array(["AB CDE 121"], dtype='S') 115 for d in data.create_tuple_iterator(output_numpy=True): 116 np.testing.assert_array_equal(d[0], expected) 117 118 119def test_tfrecord1(): 120 s = ds.Schema() 121 s.add_column("line", "string", []) 122 s.add_column("words", "string", [-1]) 123 s.add_column("chinese", "string", []) 124 125 data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s) 126 127 for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)): 128 assert d["line"].shape == line[i].shape 129 assert d["words"].shape == words[i].shape 130 assert d["chinese"].shape == chinese[i].shape 131 np.testing.assert_array_equal(line[i], to_str(d["line"])) 132 np.testing.assert_array_equal(words[i], to_str(d["words"])) 133 np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) 134 135 136def test_tfrecord2(): 137 data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, 138 schema='../data/dataset/testTextTFRecord/datasetSchema.json') 139 for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)): 140 assert d["line"].shape == line[i].shape 141 assert d["words"].shape == words[i].shape 142 assert d["chinese"].shape == chinese[i].shape 143 np.testing.assert_array_equal(line[i], to_str(d["line"])) 144 np.testing.assert_array_equal(words[i], to_str(d["words"])) 145 np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) 146 147 148def test_tfrecord3(): 149 s = ds.Schema() 150 s.add_column("line", mstype.string, []) 151 s.add_column("words", mstype.string, [-1, 2]) 152 s.add_column("chinese", mstype.string, []) 153 154 data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s) 155 156 for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)): 157 assert d["line"].shape == line[i].shape 158 assert d["words"].shape == words[i].reshape([2, 2]).shape 159 assert d["chinese"].shape == chinese[i].shape 160 np.testing.assert_array_equal(line[i], to_str(d["line"])) 161 np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"])) 162 np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) 163 164 165def create_text_mindrecord(): 166 # methood to create mindrecord with string data, used to generate testTextMindRecord/test.mindrecord 167 from mindspore.mindrecord import FileWriter 168 169 mindrecord_file_name = "test.mindrecord" 170 data = [{"english": "This is a text file.", 171 "chinese": "今天天气太好了我们一起去外面玩吧"}, 172 {"english": "Be happy every day.", 173 "chinese": "男默女泪"}, 174 {"english": "Good luck to everyone.", 175 "chinese": "江州市长江大桥参加了长江大桥的通车仪式"}, 176 ] 177 writer = FileWriter(mindrecord_file_name) 178 schema = {"english": {"type": "string"}, 179 "chinese": {"type": "string"}, 180 } 181 writer.add_schema(schema) 182 writer.write_raw_data(data) 183 writer.commit() 184 185 186def test_mindrecord(): 187 data = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord", shuffle=False) 188 189 for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)): 190 assert d["english"].shape == line[i].shape 191 assert d["chinese"].shape == chinese[i].shape 192 np.testing.assert_array_equal(line[i], to_str(d["english"])) 193 np.testing.assert_array_equal(chinese[i], to_str(d["chinese"])) 194 195 196# The following tests cases were copied from test_pad_batch but changed to strings instead 197 198 199# this generator function yield two columns 200# col1d: [0],[1], [2], [3] 201# col2d: [[100],[200]], [[101],[201]], [102],[202]], [103],[203]] 202def gen_2cols(num): 203 for i in range(num): 204 yield (np.array([str(i)]), np.array([[str(i + 100)], [str(i + 200)]])) 205 206 207# this generator function yield one column of variable shapes 208# col: [0], [0,1], [0,1,2], [0,1,2,3] 209def gen_var_col(num): 210 for i in range(num): 211 yield (np.array([str(j) for j in range(i + 1)]),) 212 213 214# this generator function yield two columns of variable shapes 215# col1: [0], [0,1], [0,1,2], [0,1,2,3] 216# col2: [100], [100,101], [100,101,102], [100,110,102,103] 217def gen_var_cols(num): 218 for i in range(num): 219 yield (np.array([str(j) for j in range(i + 1)]), np.array([str(100 + j) for j in range(i + 1)])) 220 221 222# this generator function yield two columns of variable shapes 223# col1: [[0]], [[0,1]], [[0,1,2]], [[0,1,2,3]] 224# col2: [[100]], [[100,101]], [[100,101,102]], [[100,110,102,103]] 225def gen_var_cols_2d(num): 226 for i in range(num): 227 yield (np.array([[str(j) for j in range(i + 1)]]), np.array([[str(100 + j) for j in range(i + 1)]])) 228 229 230def test_batch_padding_01(): 231 data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"]) 232 data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], b"-2"), "col1d": ([2], b"-1")}) 233 data1 = data1.repeat(2) 234 for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 235 np.testing.assert_array_equal([[b"0", b"-1"], [b"1", b"-1"]], data["col1d"]) 236 np.testing.assert_array_equal([[[b"100", b"-2"], [b"200", b"-2"]], [[b"101", b"-2"], [b"201", b"-2"]]], 237 data["col2d"]) 238 239 240def test_batch_padding_02(): 241 data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"]) 242 data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], "")}) 243 data1 = data1.repeat(2) 244 for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 245 np.testing.assert_array_equal([[b"0"], [b"1"]], data["col1d"]) 246 np.testing.assert_array_equal([[[b"100", b""]], [[b"101", b""]]], data["col2d"]) 247 248 249def test_batch_padding_03(): 250 data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"]) 251 data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, "PAD_VALUE")}) # pad automatically 252 data1 = data1.repeat(2) 253 res = dict() 254 for ind, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)): 255 res[ind] = data["col"].copy() 256 np.testing.assert_array_equal(res[0], [[b"0", b"PAD_VALUE"], [0, 1]]) 257 np.testing.assert_array_equal(res[1], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]]) 258 np.testing.assert_array_equal(res[2], [[b"0", b"PAD_VALUE"], [b"0", b"1"]]) 259 np.testing.assert_array_equal(res[3], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]]) 260 261 262def test_batch_padding_04(): 263 data1 = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"]) 264 data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={}) # pad automatically 265 data1 = data1.repeat(2) 266 for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 267 np.testing.assert_array_equal(data["col1"], [[b"0", b""], [b"0", b"1"]]) 268 np.testing.assert_array_equal(data["col2"], [[b"100", b""], [b"100", b"101"]]) 269 270 271def test_batch_padding_05(): 272 data1 = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"]) 273 data1 = data1.batch(batch_size=3, drop_remainder=False, 274 pad_info={"col2": ([2, None], "-2"), "col1": (None, "-1")}) # pad automatically 275 for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 276 np.testing.assert_array_equal(data["col1"], 277 [[[b"0", b"-1", b"-1"]], [[b"0", b"1", b"-1"]], [[b"0", b"1", b"2"]]]) 278 np.testing.assert_array_equal(data["col2"], 279 [[[b"100", b"-2", b"-2"], [b"-2", b"-2", b"-2"]], 280 [[b"100", b"101", b"-2"], [b"-2", b"-2", b"-2"]], 281 [[b"100", b"101", b"102"], [b"-2", b"-2", b"-2"]]]) 282 283 284if __name__ == '__main__': 285 test_generator() 286 test_basic() 287 test_batching_strings() 288 test_map() 289 test_map2() 290 test_tfrecord1() 291 test_tfrecord2() 292 test_tfrecord3() 293 test_mindrecord() 294 test_batch_padding_01() 295 test_batch_padding_02() 296 test_batch_padding_03() 297 test_batch_padding_04() 298 test_batch_padding_05() 299