# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the MindSpore dataset ``zip`` operation.

Covers row/column combinations of zipped TFRecord datasets, renamed-column
and repeat interactions, and the error paths of ``ds.zip``.
"""
import mindspore.dataset as ds
from mindspore import log as logger

from util import save_and_check_dict, save_and_check_md5

# Dataset in DIR_1 has 5 rows and 5 columns
DATA_DIR_1 = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
SCHEMA_DIR_1 = "../data/dataset/testTFBert5Rows1/datasetSchema.json"
# Dataset in DIR_2 has 5 rows and 2 columns
DATA_DIR_2 = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
SCHEMA_DIR_2 = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
# Dataset in DIR_3 has 3 rows and 2 columns
DATA_DIR_3 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_3 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# Dataset in DIR_4 has 5 rows and 7 columns
DATA_DIR_4 = ["../data/dataset/testTFBert5Rows/5TFDatas.data"]
SCHEMA_DIR_4 = "../data/dataset/testTFBert5Rows/datasetSchema.json"

# Set True to regenerate the golden result files instead of checking.
GENERATE_GOLDEN = False


def test_zip_01():
    """
    Test zip: zip 2 datasets, #rows-data1 == #rows-data2, #cols-data1 < #cols-data2
    """
    logger.info("test_zip_01")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    data2 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    dataz = ds.zip((data1, data2))
    # Note: zipped dataset has 5 rows and 7 columns
    filename = "zip_01_result.npz"
    save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_02():
    """
    Test zip: zip 2 datasets, #rows-data1 < #rows-data2, #cols-data1 == #cols-data2
    """
    logger.info("test_zip_02")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    dataz = ds.zip((data1, data2))
    # Note: zipped dataset has 3 rows and 4 columns
    filename = "zip_02_result.npz"
    save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_03():
    """
    Test zip: zip 2 datasets, #rows-data1 > #rows-data2, #cols-data1 > #cols-data2
    """
    logger.info("test_zip_03")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    data2 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    dataz = ds.zip((data1, data2))
    # Note: zipped dataset has 3 rows and 7 columns
    filename = "zip_03_result.npz"
    save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_04():
    """
    Test zip: zip >2 datasets
    """
    logger.info("test_zip_04")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
    data3 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
    dataz = ds.zip((data1, data2, data3))
    # Note: zipped dataset has 3 rows and 9 columns
    filename = "zip_04_result.npz"
    save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_05():
    """
    Test zip: zip dataset with renamed columns
    """
    logger.info("test_zip_05")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=True)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=True)

    # Rename data2's columns so they don't collide with data1's on zip.
    data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
    data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")

    dataz = ds.zip((data1, data2))
    # Note: zipped dataset has 5 rows and 9 columns
    filename = "zip_05_result.npz"
    save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_06():
    """
    Test zip: zip dataset with renamed columns and repeat zipped dataset
    """
    logger.info("test_zip_06")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=False)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)

    # Rename data2's columns so they don't collide with data1's on zip.
    data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
    data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")

    dataz = ds.zip((data1, data2))
    dataz = dataz.repeat(2)
    # Note: resultant dataset has 10 rows and 9 columns
    filename = "zip_06_result.npz"
    save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)


def test_zip_exception_01():
    """
    Test zip: zip same datasets
    """
    logger.info("test_zip_exception_01")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        dataz = ds.zip((data1, data1))

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_02():
    """
    Test zip: zip datasets with duplicate column name
    """
    logger.info("test_zip_exception_02")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    data2 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4)

    try:
        dataz = ds.zip((data1, data2))

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_03():
    """
    Test zip: zip with tuple of 1 dataset
    """
    logger.info("test_zip_exception_03")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        # NOTE(review): (data1) is NOT a 1-tuple — a real 1-tuple would be
        # (data1,). Kept as-is since the original test passes a bare dataset;
        # confirm whether the non-tuple argument is the intended error input.
        dataz = ds.zip((data1))
        dataz = dataz.repeat(2)

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_04():
    """
    Test zip: zip with empty tuple of datasets
    """
    logger.info("test_zip_exception_04")

    try:
        dataz = ds.zip(())
        dataz = dataz.repeat(2)

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_05():
    """
    Test zip: zip with non-tuple of 2 datasets
    """
    logger.info("test_zip_exception_05")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)

    try:
        dataz = ds.zip(data1, data2)

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_06():
    """
    Test zip: zip with non-tuple of 1 dataset
    """
    logger.info("test_zip_exception_06")
    data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        dataz = ds.zip(data1)

        num_iter = 0
        for _, item in enumerate(dataz.create_dict_iterator(num_epochs=1, output_numpy=True)):
            logger.info("item[input_mask] is {}".format(item["input_mask"]))
            num_iter += 1
        logger.info("Number of data in zipped dataz: {}".format(num_iter))

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_07():
    """
    Test zip: zip with string as parameter
    """
    logger.info("test_zip_exception_07")

    try:
        dataz = ds.zip(('dataset1', 'dataset2'))

        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        # Reaching this point means no exception was raised -> test must fail.
        assert False
    except AssertionError:
        # Bug fix: the broad `except Exception` below previously swallowed
        # the AssertionError as well, so this test passed unconditionally.
        # Re-raise so a missing exception actually fails the test.
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))

    try:
        data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
        dataz = data.zip(('dataset1',))

        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        # Reaching this point means no exception was raised -> test must fail.
        assert False
    except AssertionError:
        # Same fix as above: do not swallow the test failure itself.
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))


if __name__ == '__main__':
    test_zip_01()
    test_zip_02()
    test_zip_03()
    test_zip_04()
    test_zip_05()
    test_zip_06()
    test_zip_exception_01()
    test_zip_exception_02()
    test_zip_exception_03()
    test_zip_exception_04()
    test_zip_exception_05()
    test_zip_exception_06()
    test_zip_exception_07()