• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from util import save_and_check_dict, save_and_check_md5
16
17import mindspore.dataset as ds
18from mindspore import log as logger
19
# Dataset in DIR_1 has 5 rows and 5 columns
DATA_DIR_1 = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
SCHEMA_DIR_1 = "../data/dataset/testTFBert5Rows1/datasetSchema.json"
# Dataset in DIR_2 has 5 rows and 2 columns
DATA_DIR_2 = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
SCHEMA_DIR_2 = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
# Dataset in DIR_3 has 3 rows and 2 columns
DATA_DIR_3 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_3 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# Dataset in DIR_4 has 5 rows and 7 columns
DATA_DIR_4 = ["../data/dataset/testTFBert5Rows/5TFDatas.data"]
SCHEMA_DIR_4 = "../data/dataset/testTFBert5Rows/datasetSchema.json"

# Forwarded as the generate_golden argument of save_and_check_dict and
# save_and_check_md5 by the tests below.
# NOTE(review): presumably True regenerates the golden result files instead of
# comparing against them -- confirm in util before changing.
GENERATE_GOLDEN = False
34
35
def test_zip_01():
    """
    Test zip: zip 2 datasets with the same number of rows, where the first
    dataset has fewer columns than the second.

    The zipped output is checked with save_and_check_dict against
    "zip_01_result.npz".
    """
    logger.info("test_zip_01")
    ds.config.set_seed(1)
    narrow = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    wide = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    # Note: zipped dataset has 5 rows and 7 columns
    zipped = ds.zip((narrow, wide))
    save_and_check_dict(zipped, "zip_01_result.npz", generate_golden=GENERATE_GOLDEN)
48
49
def test_zip_02():
    """
    Test zip: zip 2 datasets where the first has fewer rows than the second
    and both have the same number of columns.

    The zipped output is checked with save_and_check_md5 against
    "zip_02_result.npz".
    """
    logger.info("test_zip_02")
    ds.config.set_seed(1)
    shorter = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    longer = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    # Note: zipped dataset has 3 rows and 4 columns
    zipped = ds.zip((shorter, longer))
    save_and_check_md5(zipped, "zip_02_result.npz", generate_golden=GENERATE_GOLDEN)
62
63
def test_zip_03():
    """
    Test zip: zip 2 datasets where the first has more rows and more columns
    than the second.

    The zipped output is checked with save_and_check_md5 against
    "zip_03_result.npz".
    """
    logger.info("test_zip_03")
    ds.config.set_seed(1)
    larger = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    smaller = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    # Note: zipped dataset has 3 rows and 7 columns
    zipped = ds.zip((larger, smaller))
    save_and_check_md5(zipped, "zip_03_result.npz", generate_golden=GENERATE_GOLDEN)
76
77
def test_zip_04():
    """
    Test zip: zip more than 2 datasets (three inputs).

    The zipped output is checked with save_and_check_md5 against
    "zip_04_result.npz".
    """
    logger.info("test_zip_04")
    ds.config.set_seed(1)
    first = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    second = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    third = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    # Note: zipped dataset has 3 rows and 9 columns
    zipped = ds.zip((first, second, third))
    save_and_check_md5(zipped, "zip_04_result.npz", generate_golden=GENERATE_GOLDEN)
91
92
def test_zip_05():
    """
    Test zip: zip datasets after renaming columns of the second dataset.

    Both inputs are created with shuffle=True. "input_ids" and "segment_ids"
    of the second dataset are renamed before zipping, and the result is
    checked with save_and_check_dict against "zip_05_result.npz".
    """
    logger.info("test_zip_05")
    ds.config.set_seed(1)
    base = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=True)
    renamed = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=True)

    # Apply the two renames back to back, in the same order as before.
    renamed = renamed.rename(
        input_columns="input_ids", output_columns="new_input_ids"
    ).rename(
        input_columns="segment_ids", output_columns="new_segment_ids"
    )

    # Note: zipped dataset has 5 rows and 9 columns
    zipped = ds.zip((base, renamed))
    save_and_check_dict(zipped, "zip_05_result.npz", generate_golden=GENERATE_GOLDEN)
109
110
def test_zip_06():
    """
    Test zip: zip datasets with renamed columns, then repeat the zipped
    dataset twice.

    Both inputs are created with shuffle=False. The result is checked with
    save_and_check_dict against "zip_06_result.npz".
    """
    logger.info("test_zip_06")
    ds.config.set_seed(1)
    base = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=False)
    renamed = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)

    # Apply the two renames back to back, in the same order as before.
    renamed = renamed.rename(
        input_columns="input_ids", output_columns="new_input_ids"
    ).rename(
        input_columns="segment_ids", output_columns="new_segment_ids"
    )

    # Note: resultant dataset has 10 rows and 9 columns
    repeated = ds.zip((base, renamed)).repeat(2)
    save_and_check_dict(repeated, "zip_06_result.npz", generate_golden=GENERATE_GOLDEN)
128
129
def test_zip_exception_01():
    """
    Test zip: zip a dataset with itself.

    Any exception raised while zipping or iterating is caught and logged.
    The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_01")
    source = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        zipped = ds.zip((source, source))

        row_count = 0
        for row in zipped.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(row["input_mask"]))
            row_count += 1
        logger.info("Number of data in zipped dataz: {}".format(row_count))

    except Exception as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
148
149
def test_zip_exception_02():
    """
    Test zip: zip datasets that share duplicate column names.

    Any exception raised while zipping or iterating is caught and logged.
    The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_02")
    left = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    right = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4)

    try:
        zipped = ds.zip((left, right))

        row_count = 0
        for row in zipped.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(row["input_mask"]))
            row_count += 1
        logger.info("Number of data in zipped dataz: {}".format(row_count))

    except Exception as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
169
170
def test_zip_exception_03():
    """
    Test zip: zip with tuple of 1 dataset

    Any exception raised while zipping, repeating or iterating is caught and
    logged. The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_03")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        # The trailing comma is required: "(data1)" is just data1 in
        # parentheses, not a one-element tuple, which would make this test
        # pass a bare dataset instead of the tuple it is meant to exercise.
        dataz = ds.zip((data1,))
        dataz = dataz.repeat(2)

        num_iter = 0
        for item in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
190
191
def test_zip_exception_04():
    """
    Test zip: zip with an empty tuple of datasets.

    Any exception raised while zipping, repeating or iterating is caught and
    logged. The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_04")

    try:
        repeated = ds.zip(()).repeat(2)

        row_count = 0
        for row in repeated.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(row["input_mask"]))
            row_count += 1
        logger.info("Number of data in zipped dataz: {}".format(row_count))

    except Exception as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
210
211
def test_zip_exception_05():
    """
    Test zip: pass 2 datasets as separate positional arguments instead of a
    tuple.

    Any exception raised while zipping or iterating is caught and logged.
    The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_05")
    left = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    right = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)

    try:
        zipped = ds.zip(left, right)

        row_count = 0
        for row in zipped.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(row["input_mask"]))
            row_count += 1
        logger.info("Number of data in zipped dataz: {}".format(row_count))

    except Exception as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
231
232
def test_zip_exception_06():
    """
    Test zip: pass a single dataset directly instead of a tuple.

    Any exception raised while zipping or iterating is caught and logged.
    The test only logs; it does not fail if no exception is raised.
    """
    logger.info("test_zip_exception_06")
    source = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        zipped = ds.zip(source)

        row_count = 0
        for row in zipped.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[input_mask] is {}".format(row["input_mask"]))
            row_count += 1
        logger.info("Number of data in zipped dataz: {}".format(row_count))

    except Exception as err:
        logger.info("Got an exception in DE: {}".format(str(err)))
251
252
def test_zip_exception_07():
    """
    Test zip: zip with string as parameter

    Exercises both ds.zip() and the Dataset.zip() method with strings in
    place of datasets. Each call must raise; the exception is logged. If a
    call and the following iteration complete without raising, the test
    fails with an AssertionError.

    Raises:
        AssertionError: if either zip call accepts string inputs.
    """
    logger.info("test_zip_exception_07")

    try:
        dataz = ds.zip(('dataset1', 'dataset2'))
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        # The failure is raised outside the try body: AssertionError is a
        # subclass of Exception, so raising it inside would be swallowed by
        # the handler above and the test could never fail.
        raise AssertionError("ds.zip() accepted strings in place of datasets")

    # Built outside the try block so a dataset construction problem is not
    # mistaken for the expected zip failure.
    data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    try:
        dataz = data.zip(('dataset1',))
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        raise AssertionError("Dataset.zip() accepted a string in place of a dataset")
281
if __name__ == '__main__':
    # When executed as a script, run every test in declaration order.
    for _test_case in (
            test_zip_01,
            test_zip_02,
            test_zip_03,
            test_zip_04,
            test_zip_05,
            test_zip_06,
            test_zip_exception_01,
            test_zip_exception_02,
            test_zip_exception_03,
            test_zip_exception_04,
            test_zip_exception_05,
            test_zip_exception_06,
            test_zip_exception_07,
    ):
        _test_case()
296