1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import mindspore._c_dataengine as cde
17
18import mindspore.common.dtype as mstype
19import mindspore.dataset as ds
20from mindspore.dataset.text import to_str, to_bytes
21
22
def test_basic():
    """Round-trip a 2-D byte-string ndarray through a dataengine Tensor."""
    source = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
    tensor = cde.Tensor(source)
    recovered = tensor.as_array()
    np.testing.assert_array_equal(source, recovered)
28
29
def compare(strings, dtype='S'):
    """Push ``strings`` through a one-row GeneratorDataset and verify output.

    Args:
        strings: (nested) list of str convertible to a numpy string array.
        dtype: numpy string dtype for the reference array — 'S' (bytes) or
            'U' (unicode).
    """
    arr = np.array(strings, dtype=dtype)

    def gen():
        # The original spelled this ``(yield arr,)`` — a parenthesized yield
        # of the one-tuple ``(arr,)``. Write it the conventional way.
        yield (arr,)

    data = ds.GeneratorDataset(gen, column_names=["col"])

    # The pipeline always hands strings back as bytes, so compare against
    # the 'S' view of the reference array regardless of the input dtype.
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(d[0], arr.astype('S'))
40
41
def test_generator():
    """Exercise compare() over byte ('S') and unicode ('U') string inputs,
    including empty strings, 1-D and 2-D shapes."""
    byte_cases = [
        ["ab"],
        ["", ""],
        [""],
        ["ab", ""],
        ["ab", "cde", "121"],
        [["ab", "cde", "121"], ["x", "km", "789"]],
        [["ab", "", "121"], ["", "km", "789"]],
    ]
    unicode_cases = [
        ["ab"],
        ["", ""],
        [""],
        ["ab", ""],
        ["", ""],
        ["", "ab"],
        ["ab", "cde", "121"],
        [["ab", "cde", "121"], ["x", "km", "789"]],
        [["ab", "", "121"], ["", "km", "789"]],
    ]
    for strings in byte_cases:
        compare(strings)
    for strings in unicode_cases:
        compare(strings, dtype='U')
59
60
# Reference arrays: the TFRecord/MindRecord tests below assert that the data
# read from ../data/dataset/testTextTFRecord and testTextMindRecord matches
# these values exactly.

# One English sentence per record.
line = np.array(["This is a text file.",
                 "Be happy every day.",
                 "Good luck to everyone."])

# Per-record token lists; deliberately include an empty string and a
# non-ASCII entry ("女").
words = np.array([["This", "text", "file", "a"],
                  ["Be", "happy", "day", "b"],
                  ["女", "", "everyone", "c"]])

# One Chinese sentence per record (multi-byte UTF-8 content).
chinese = np.array(["今天天气太好了我们一起去外面玩吧",
                    "男默女泪",
                    "江州市长江大桥参加了长江大桥的通车仪式"])
72
73
def test_batching_strings():
    """Batch a string column and check the batch content as bytes."""
    def generate_rows():
        for sentence in chinese:
            yield (np.array(sentence),)

    pipeline = ds.GeneratorDataset(generate_rows, column_names=["col"])
    pipeline = pipeline.batch(2, drop_remainder=True)

    # drop_remainder leaves a single batch of 2; it must equal the first
    # two sentences encoded back to bytes.
    for row in pipeline.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(row[0], to_bytes(chinese[0:2]))
84
85
def test_map():
    """map() with a Python callable that splits one string into tokens."""
    def generate_rows():
        yield (np.array(["ab cde 121"], dtype='S'),)

    pipeline = ds.GeneratorDataset(generate_rows, column_names=["col"])

    def split(b):
        # Decode bytes -> str, whitespace-split, return a fresh array.
        return np.array(to_str(b).item().split())

    pipeline = pipeline.map(operations=split, input_columns=["col"])
    expected = np.array(["ab", "cde", "121"], dtype='S')
    for row in pipeline.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(row[0], expected)
101
102
def test_map2():
    """map() with a numpy vectorized string op (np.char.upper)."""
    def generate_rows():
        yield (np.array(["ab cde 121"], dtype='S'),)

    pipeline = ds.GeneratorDataset(generate_rows, column_names=["col"])

    def upper(b):
        return np.char.upper(b)

    pipeline = pipeline.map(operations=upper, input_columns=["col"])
    expected = np.array(["AB CDE 121"], dtype='S')
    for row in pipeline.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(row[0], expected)
117
118
def test_tfrecord1():
    """Read string columns from a TFRecord via a schema built from type-name strings."""
    schema = ds.Schema()
    schema.add_column("line", "string", [])
    schema.add_column("words", "string", [-1])
    schema.add_column("chinese", "string", [])

    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=schema)

    for i, row in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for column, expected in (("line", line[i]), ("words", words[i]), ("chinese", chinese[i])):
            assert row[column].shape == expected.shape
            np.testing.assert_array_equal(expected, to_str(row[column]))
134
135
def test_tfrecord2():
    """Same as test_tfrecord1 but with the schema loaded from a JSON file."""
    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False,
                              schema='../data/dataset/testTextTFRecord/datasetSchema.json')
    for i, row in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for column, expected in (("line", line[i]), ("words", words[i]), ("chinese", chinese[i])):
            assert row[column].shape == expected.shape
            np.testing.assert_array_equal(expected, to_str(row[column]))
146
147
def test_tfrecord3():
    """Read string columns with an explicit [-1, 2] shape for the words column."""
    schema = ds.Schema()
    schema.add_column("line", mstype.string, [])
    schema.add_column("words", mstype.string, [-1, 2])
    schema.add_column("chinese", mstype.string, [])

    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=schema)

    for i, row in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # The 4-token words row comes back reshaped to 2x2 under this schema.
        reshaped_words = words[i].reshape([2, 2])
        for column, expected in (("line", line[i]), ("words", reshaped_words), ("chinese", chinese[i])):
            assert row[column].shape == expected.shape
            np.testing.assert_array_equal(expected, to_str(row[column]))
163
164
def create_text_mindrecord():
    """One-off helper (not a test): write the reference sentences to a
    mindrecord file with two string fields, "english" and "chinese".

    Kept here to document how testTextMindRecord/test.mindrecord was
    generated for test_mindrecord().
    """
    # method to create mindrecord with string data, used to generate testTextMindRecord/test.mindrecord
    from mindspore.mindrecord import FileWriter

    mindrecord_file_name = "test.mindrecord"
    data = [{"english": "This is a text file.",
             "chinese": "今天天气太好了我们一起去外面玩吧"},
            {"english": "Be happy every day.",
             "chinese": "男默女泪"},
            {"english": "Good luck to everyone.",
             "chinese": "江州市长江大桥参加了长江大桥的通车仪式"},
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"english": {"type": "string"},
              "chinese": {"type": "string"},
              }
    writer.add_schema(schema)
    writer.write_raw_data(data)
    writer.commit()
184
185
def test_mindrecord():
    """Read the string fields back from the prepared mindrecord file."""
    data = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord", shuffle=False)

    for i, row in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for column, expected in (("english", line[i]), ("chinese", chinese[i])):
            assert row[column].shape == expected.shape
            np.testing.assert_array_equal(expected, to_str(row[column]))
194
195
196# The following tests cases were copied from test_pad_batch but changed to strings instead
197
198
# this generator function yields two columns of numpy string arrays
# col1d: ["0"], ["1"], ["2"], ["3"]
# col2d: [["100"],["200"]], [["101"],["201"]], [["102"],["202"]], [["103"],["203"]]
def gen_2cols(num):
    """Yield `num` rows of (1-element 1-D, 2x1 2-D) string columns."""
    for idx in range(num):
        col1d = np.array([str(idx)])
        col2d = np.array([[str(idx + 100)], [str(idx + 200)]])
        yield (col1d, col2d)
205
206
# this generator function yields one string column of variable shapes
# col: ["0"], ["0","1"], ["0","1","2"], ["0","1","2","3"]
def gen_var_col(num):
    """Yield `num` one-column rows whose length grows from 1 up to `num`."""
    for length in range(1, num + 1):
        yield (np.array([str(value) for value in range(length)]),)
212
213
# this generator function yields two string columns of variable shapes
# col1: ["0"], ["0","1"], ["0","1","2"], ["0","1","2","3"]
# col2: ["100"], ["100","101"], ["100","101","102"], ["100","101","102","103"]
def gen_var_cols(num):
    """Yield `num` two-column rows of growing length; col2 values are col1 + 100."""
    for length in range(1, num + 1):
        first = np.array([str(value) for value in range(length)])
        second = np.array([str(value + 100) for value in range(length)])
        yield (first, second)
220
221
# this generator function yields two 2-D string columns of variable shapes
# col1: [["0"]], [["0","1"]], [["0","1","2"]], [["0","1","2","3"]]
# col2: [["100"]], [["100","101"]], [["100","101","102"]], [["100","101","102","103"]]
def gen_var_cols_2d(num):
    """Yield `num` rows of two 1xN string columns with N growing from 1 to `num`."""
    for length in range(1, num + 1):
        first = np.array([[str(value) for value in range(length)]])
        second = np.array([[str(value + 100) for value in range(length)]])
        yield (first, second)
228
229
def test_batch_padding_01():
    """Pad both string columns to fixed shapes with explicit byte pad values."""
    dataset = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
    pad_spec = {"col2d": ([2, 2], b"-2"), "col1d": ([2], b"-1")}
    dataset = dataset.batch(batch_size=2, drop_remainder=False, pad_info=pad_spec)
    dataset = dataset.repeat(2)
    expected_1d = [[b"0", b"-1"], [b"1", b"-1"]]
    expected_2d = [[[b"100", b"-2"], [b"200", b"-2"]],
                   [[b"101", b"-2"], [b"201", b"-2"]]]
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(expected_1d, batch["col1d"])
        np.testing.assert_array_equal(expected_2d, batch["col2d"])
238
239
def test_batch_padding_02():
    """Pad only col2d, to shape [1, 2], using the empty string as filler."""
    dataset = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
    dataset = dataset.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], "")})
    dataset = dataset.repeat(2)
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal([[b"0"], [b"1"]], batch["col1d"])
        np.testing.assert_array_equal([[[b"100", b""]], [[b"101", b""]]], batch["col2d"])
247
248
def test_batch_padding_03():
    """Automatic padding (pad shape None) of a variable-length string column.

    Fix: the expected value for the first batch used the integers ``[0, 1]``
    where the pipeline produces byte strings. Since repeat(2) replays the
    same two batches, res[0] must equal res[2], which already used
    ``[b"0", b"1"]``.
    """
    data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"])
    data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, "PAD_VALUE")})  # pad automatically
    data1 = data1.repeat(2)
    res = dict()
    for ind, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        res[ind] = data["col"].copy()
    # Two batches per epoch, two epochs: res[0]==res[2], res[1]==res[3].
    np.testing.assert_array_equal(res[0], [[b"0", b"PAD_VALUE"], [b"0", b"1"]])
    np.testing.assert_array_equal(res[1], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
    np.testing.assert_array_equal(res[2], [[b"0", b"PAD_VALUE"], [b"0", b"1"]])
    np.testing.assert_array_equal(res[3], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
260
261
def test_batch_padding_04():
    """Empty pad_info pads every column automatically with empty strings."""
    dataset = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"])
    dataset = dataset.batch(batch_size=2, drop_remainder=False, pad_info={})  # pad automatically
    dataset = dataset.repeat(2)
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(batch["col1"], [[b"0", b""], [b"0", b"1"]])
        np.testing.assert_array_equal(batch["col2"], [[b"100", b""], [b"100", b"101"]])
269
270
def test_batch_padding_05():
    """Mix a partially fixed pad shape ([2, None]) with a fully automatic one (None)."""
    dataset = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"])
    dataset = dataset.batch(batch_size=3, drop_remainder=False,
                            pad_info={"col2": ([2, None], "-2"), "col1": (None, "-1")})  # pad automatically
    expected_col1 = [[[b"0", b"-1", b"-1"]], [[b"0", b"1", b"-1"]], [[b"0", b"1", b"2"]]]
    expected_col2 = [[[b"100", b"-2", b"-2"], [b"-2", b"-2", b"-2"]],
                     [[b"100", b"101", b"-2"], [b"-2", b"-2", b"-2"]],
                     [[b"100", b"101", b"102"], [b"-2", b"-2", b"-2"]]]
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(batch["col1"], expected_col1)
        np.testing.assert_array_equal(batch["col2"], expected_col2)
282
283
if __name__ == '__main__':
    # Run every test in the original hand-picked order.
    all_cases = (
        test_generator,
        test_basic,
        test_batching_strings,
        test_map,
        test_map2,
        test_tfrecord1,
        test_tfrecord2,
        test_tfrecord3,
        test_mindrecord,
        test_batch_padding_01,
        test_batch_padding_02,
        test_batch_padding_03,
        test_batch_padding_04,
        test_batch_padding_05,
    )
    for case in all_cases:
        case()
299