1# Copyright 2020-2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 16Testing cache operation with non-mappable datasets 17""" 18import os 19import itertools 20import numpy as np 21import pytest 22import mindspore.common.dtype as mstype 23import mindspore.dataset as ds 24import mindspore.dataset.text as text 25import mindspore.dataset.vision as c_vision 26from mindspore import log as logger 27 28DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] 29SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" 30 31TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"] 32SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json" 33 34TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", 35 "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", 36 "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", 37 "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] 38TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" 39 40IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/" 41CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json' 42CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' 43TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" 44 45PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"] 46PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json" 47 48GENERATE_GOLDEN = False 49 50 51@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 52def test_cache_nomap_basic1(): 53 """ 54 Feature: DatasetCache op 55 Description: Test a RandomDataset (a non mappable dataset) with a Cache over it just after the leaf 56 Expectation: Output is equal to the expected output 57 """ 58 logger.info("Test cache nomap basic 1") 59 if "SESSION_ID" in os.environ: 60 session_id = int(os.environ['SESSION_ID']) 61 else: 62 raise RuntimeError("Testcase requires SESSION_ID environment variable") 63 64 schema = ds.Schema() 65 schema.add_column('image', de_type=mstype.uint8, 66 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 67 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 68 69 # create a cache. arbitrary session_id for now 70 some_cache = ds.DatasetCache(session_id=session_id, size=0) 71 72 # User-created sampler here 73 ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) 74 ds1 = ds1.repeat(4) 75 76 num_iter = 0 77 for data in ds1.create_dict_iterator(num_epochs=1): 78 logger.info("printing the label: {}".format(data["label"])) 79 num_iter += 1 80 81 logger.info("Number of data in ds1: {} ".format(num_iter)) 82 assert num_iter == 40 83 logger.info("test_cache_nomap_basic1 Ended.\n") 84 85 86@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 87def test_cache_nomap_basic2(): 88 """ 89 Feature: DatasetCache op 90 Description: Test RandomDataset (a non mappable dataset with num_samples) with a Cache over it just after the leaf 91 Expectation: Output is equal to the expected output 92 """ 93 logger.info("Test cache nomap basic 2") 94 if "SESSION_ID" in os.environ: 95 session_id = int(os.environ['SESSION_ID']) 96 else: 97 raise RuntimeError("Testcase requires SESSION_ID environment variable") 98 99 schema = ds.Schema() 100 schema.add_column('image', de_type=mstype.uint8, 101 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 102 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 103 104 # create a cache. arbitrary session_id for now 105 some_cache = ds.DatasetCache(session_id=session_id, size=0) 106 107 # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: 108 # num_samples, shuffle, num_shards, shard_id 109 # In this case, the presence of num_samples chooses a sampler. 110 ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache) 111 ds1 = ds1.repeat(2) 112 113 num_iter = 0 114 for data in ds1.create_dict_iterator(num_epochs=1): 115 logger.info("printing the label: {}".format(data["label"])) 116 num_iter += 1 117 118 logger.info("Number of data in ds1: {} ".format(num_iter)) 119 assert num_iter == 40 120 logger.info("test_cache_nomap_basic2 Ended.\n") 121 122 123@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 124def test_cache_nomap_basic3(): 125 """ 126 Feature: DatasetCache op 127 Description: Test a TFReaderDataset (a non mappable dataset) with a Cache over it just after the leaf 128 129 Repeat 130 | 131 Map(Decode) 132 | 133 Cache 134 | 135 TFReader 136 137 Expectation: Output is equal to the expected output 138 """ 139 logger.info("Test cache nomap basic 3") 140 if "SESSION_ID" in os.environ: 141 session_id = int(os.environ['SESSION_ID']) 142 else: 143 raise RuntimeError("Testcase requires SESSION_ID environment variable") 144 145 some_cache = ds.DatasetCache(session_id=session_id, size=0) 146 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) 147 decode_op = c_vision.Decode() 148 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 149 ds1 = ds1.repeat(4) 150 151 num_iter = 0 152 for _ in ds1.create_dict_iterator(num_epochs=1): 153 num_iter += 1 154 155 logger.info("Number of data in ds1: {} ".format(num_iter)) 156 assert num_iter == 12 157 158 # Contact the server to get the statistics 159 stat = some_cache.get_stat() 160 cache_sz = stat.avg_cache_sz 161 num_mem_cached = stat.num_mem_cached 162 num_disk_cached = stat.num_disk_cached 163 164 logger.info("Number of rows cached in memory: {}".format(num_mem_cached)) 165 logger.info("Number of rows spilled to disk: {}".format(num_disk_cached)) 166 logger.info("Average row cache size: {}".format(cache_sz)) 167 168 logger.info("test_cache_nomap_basic3 Ended.\n") 169 170 171@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 172def test_cache_nomap_basic4(): 173 """ 174 Feature: DatasetCache op 175 Description: Test a TFReaderDataset (a non mappable dataset) with a map Decode and Cache after it 176 Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf. 177 But, if there's a cache later, that shuffle becomes invalid and should be removed. 178 179 Repeat 180 | 181 Cache 182 | 183 Map(Decode) 184 | 185 TFReader 186 187 Expectation: Output is equal to the expected output 188 """ 189 logger.info("Test cache nomap basic 4") 190 if "SESSION_ID" in os.environ: 191 session_id = int(os.environ['SESSION_ID']) 192 else: 193 raise RuntimeError("Testcase requires SESSION_ID environment variable") 194 195 # This dataset has 3 records in it only 196 some_cache = ds.DatasetCache(session_id=session_id, size=0) 197 # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache 198 # in the picture. This causes a shuffle-injection over the TF. For clarify, this test will 199 # explicitly give the global option, even though it's the default in python. 200 # But, when caching is added in the ascendent tree above TF, we do global shuffling 201 # through the sampler over the cache, not by the shuffle op. In that case, tree prepare 202 # will remove the shuffle op that got injected by the initial tree creation. 203 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL) 204 decode_op = c_vision.Decode() 205 206 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 207 ds1 = ds1.repeat(4) 208 209 num_iter = 0 210 for _ in ds1.create_dict_iterator(num_epochs=1): 211 num_iter += 1 212 213 logger.info("Number of data in ds1: {} ".format(num_iter)) 214 assert num_iter == 12 215 logger.info("test_cache_nomap_basic4 Ended.\n") 216 217 218@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 219def test_cache_nomap_basic5(): 220 """ 221 Feature: DatasetCache op 222 Description: Test a TFReaderDataset (a non mappable dataset) with a Cache over it just after the leaf. 223 Same as test 3, but this one does not have Shuffle arg, causing TF to default to global 224 shuffle which attempts to inject a Shuffle operation. However, since there is a Cache 225 we do not need global shuffle, so the shuffle will not be built. It ends up being 226 identical to test basic 3, however we arrive at the same tree in different codepaths 227 (if there was no Cache, then the Shuffle is built) 228 229 Repeat 230 | 231 Map(Decode) 232 | 233 Cache 234 | 235 TFReader 236 237 Expectation: Output is equal to the expected output 238 """ 239 logger.info("Test cache nomap basic 5") 240 if "SESSION_ID" in os.environ: 241 session_id = int(os.environ['SESSION_ID']) 242 else: 243 raise RuntimeError("Testcase requires SESSION_ID environment variable") 244 245 # This dataset has 3 records in it only 246 some_cache = ds.DatasetCache(session_id=session_id, size=0) 247 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) 248 decode_op = c_vision.Decode() 249 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 250 ds1 = ds1.repeat(4) 251 252 num_iter = 0 253 for _ in ds1.create_dict_iterator(num_epochs=1): 254 num_iter += 1 255 256 logger.info("Number of data in ds1: {} ".format(num_iter)) 257 assert num_iter == 12 258 logger.info("test_cache_nomap_basic5 Ended.\n") 259 260 261@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 262def test_cache_nomap_basic6(): 263 """ 264 Feature: DatasetCache op 265 Description: Test a TFReaderDataset (a non mappable dataset) with a Cache over it just after the leaf 266 In this one, the TFReaderDataset will be given sharding configuration, however since a Cache is 267 used, the tree prepare should undo the sharding configuration and instead, a distributed 268 sampler will be chosen with the same shard config. 269 270 Repeat 271 | 272 Map(Decode) 273 | 274 Cache 275 | 276 TFReader 277 278 Expectation: Output is equal to the expected output 279 """ 280 logger.info("Test cache nomap basic 6") 281 if "SESSION_ID" in os.environ: 282 session_id = int(os.environ['SESSION_ID']) 283 else: 284 raise RuntimeError("Testcase requires SESSION_ID environment variable") 285 286 # This dataset has 3 records in it only 287 some_cache = ds.DatasetCache(session_id=session_id, size=0) 288 289 # With only 3 records shard into 3, we expect only 1 record returned for this shard 290 # However, the sharding will be done by the sampler, not by the tf record leaf node 291 # In this case, it is a row-based sharding, not the file-based sharding that would happen if 292 # there was not any cache. 293 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache) 294 decode_op = c_vision.Decode() 295 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 296 ds1 = ds1.repeat(4) 297 298 num_iter = 0 299 for _ in ds1.create_dict_iterator(num_epochs=1): 300 num_iter += 1 301 302 logger.info("Number of data in ds1: {} ".format(num_iter)) 303 assert num_iter == 4 304 logger.info("test_cache_nomap_basic6 Ended.\n") 305 306 307@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 308def test_cache_nomap_basic7(): 309 """ 310 Feature: DatasetCache op 311 Description: Test a TFReaderDataset (a non mappable dataset) that uses global shuffle, and is Cached followed by 312 Map. In this one, the TFReaderDataset with global shuffle might want to inject a Shuffle op over top of the 313 TFReaderDataset, but since a Cache is given, it will choose not to. 314 315 Repeat 316 | 317 Map(Decode) 318 | 319 cache 320 | 321 TFReader 322 323 Expectation: Output is equal to the expected output 324 """ 325 logger.info("Test cache nomap basic 7") 326 if "SESSION_ID" in os.environ: 327 session_id = int(os.environ['SESSION_ID']) 328 else: 329 raise RuntimeError("Testcase requires SESSION_ID environment variable") 330 331 some_cache = ds.DatasetCache(session_id=session_id, size=0) 332 333 # This dataset has 3 records in it only 334 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) 335 decode_op = c_vision.Decode() 336 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 337 ds1 = ds1.repeat(4) 338 339 num_iter = 0 340 for _ in ds1.create_dict_iterator(num_epochs=1): 341 num_iter += 1 342 343 logger.info("Number of data in ds1: {} ".format(num_iter)) 344 assert num_iter == 12 345 logger.info("test_cache_nomap_basic7 Ended.\n") 346 347 348@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 349def test_cache_nomap_basic8(): 350 """ 351 Feature: DatasetCache op 352 Description: Test Cache as root node 353 354 Cache 355 | 356 TFReader 357 358 Expectation: Output is equal to the expected output 359 """ 360 logger.info("Test cache basic 8") 361 if "SESSION_ID" in os.environ: 362 session_id = int(os.environ['SESSION_ID']) 363 else: 364 raise RuntimeError("Testcase requires SESSION_ID environment variable") 365 some_cache = ds.DatasetCache(session_id=session_id, size=0) 366 367 # This dataset has 3 records in it only 368 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 369 num_iter = 0 370 for _ in ds1.create_dict_iterator(num_epochs=1): 371 logger.info("get data from dataset") 372 num_iter += 1 373 374 logger.info("Number of data in ds1: {} ".format(num_iter)) 375 assert num_iter == 3 376 logger.info('test_cache_basic8 Ended.\n') 377 378 379@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 380def test_cache_nomap_basic9(): 381 """ 382 Feature: DatasetCache op 383 Description: Testing the get_stat interface for getting some info from server but Cache is not created in pipeline 384 Expectation: Error is raised as expected 385 """ 386 logger.info("Test cache nomap basic 9") 387 if "SESSION_ID" in os.environ: 388 session_id = int(os.environ['SESSION_ID']) 389 else: 390 raise RuntimeError("Testcase requires SESSION_ID environment variable") 391 392 some_cache = ds.DatasetCache(session_id=session_id, size=0) 393 394 # Contact the server to get the statistics, this should fail because we have not used this cache in any pipeline 395 # so there will not be any cache to get stats on. 396 with pytest.raises(RuntimeError) as e: 397 stat = some_cache.get_stat() 398 cache_sz = stat.avg_cache_sz 399 logger.info("Average row cache size: {}".format(cache_sz)) 400 assert "Unexpected error" in str(e.value) 401 402 logger.info("test_cache_nomap_basic9 Ended.\n") 403 404 405@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 406def test_cache_nomap_allowed_share1(): 407 """ 408 Feature: DatasetCache op 409 Description: Test sharing the Cache between the following two trees: 410 411 Repeat Shuffle 412 | | 413 Cache Cache 414 | | 415 TFReader TFReader 416 417 Expectation: Output is equal to the expected output 418 """ 419 logger.info("Test cache nomap allowed share 1") 420 if "SESSION_ID" in os.environ: 421 session_id = int(os.environ['SESSION_ID']) 422 else: 423 raise RuntimeError("Testcase requires SESSION_ID environment variable") 424 425 ds.config.set_seed(1) 426 # This dataset has 3 records in it only 427 some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32) 428 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) 429 ds1 = ds1.repeat(4) 430 431 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) 432 ds2 = ds2.shuffle(buffer_size=2) 433 434 num_iter = 0 435 for _ in ds1.create_dict_iterator(num_epochs=1): 436 num_iter += 1 437 assert num_iter == 12 438 logger.info("Number of data in ds1: {} ".format(num_iter)) 439 440 num_iter = 0 441 for _ in ds2.create_dict_iterator(num_epochs=1): 442 num_iter += 1 443 assert num_iter == 3 444 logger.info("test_cache_nomap_allowed_share1 Ended.\n") 445 446 447@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 448def test_cache_nomap_allowed_share2(): 449 """ 450 Feature: DatasetCache op 451 Description: Test sharing the Cache between the following two trees (with Map Decode): 452 453 Repeat Shuffle 454 | | 455 Cache Cache 456 | | 457 Map(Decode) Map(Decode) 458 | | 459 TFReader TFReader 460 461 Expectation: Output is equal to the expected output 462 """ 463 logger.info("Test cache nomap allowed share 2") 464 if "SESSION_ID" in os.environ: 465 session_id = int(os.environ['SESSION_ID']) 466 else: 467 raise RuntimeError("Testcase requires SESSION_ID environment variable") 468 469 ds.config.set_seed(1) 470 # This dataset has 3 records in it only 471 some_cache = ds.DatasetCache(session_id=session_id, size=0) 472 decode_op = c_vision.Decode() 473 474 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 475 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 476 ds1 = ds1.repeat(4) 477 478 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 479 ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache) 480 ds2 = ds2.shuffle(buffer_size=2) 481 482 num_iter = 0 483 for _ in ds1.create_dict_iterator(num_epochs=1): 484 num_iter += 1 485 logger.info("Number of data in ds1: {} ".format(num_iter)) 486 assert num_iter == 12 487 488 num_iter = 0 489 for _ in ds2.create_dict_iterator(num_epochs=1): 490 num_iter += 1 491 assert num_iter == 3 492 logger.info("test_cache_nomap_allowed_share2 Ended.\n") 493 494 495@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 496def test_cache_nomap_allowed_share3(): 497 """ 498 Feature: DatasetCache op 499 Description: Test sharing the Cache between the following two trees (different shard ids): 500 501 Repeat Repeat 502 | | 503 Cache Cache 504 | | 505 TFReader(shard_id = 0) TFReader(shard_id = 1) 506 507 Expectation: Output is equal to the expected output 508 """ 509 logger.info("Test cache nomap allowed share 3") 510 if "SESSION_ID" in os.environ: 511 session_id = int(os.environ['SESSION_ID']) 512 else: 513 raise RuntimeError("Testcase requires SESSION_ID environment variable") 514 515 some_cache = ds.DatasetCache(session_id=session_id, size=0) 516 517 tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] 518 ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) 519 ds1 = ds1.repeat(4) 520 521 ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache) 522 ds2 = ds2.repeat(4) 523 524 num_iter = 0 525 for _ in ds1.create_dict_iterator(num_epochs=1): 526 num_iter += 1 527 logger.info("Number of data in ds1: {} ".format(num_iter)) 528 assert num_iter == 12 529 530 num_iter = 0 531 for _ in ds2.create_dict_iterator(num_epochs=1): 532 num_iter += 1 533 assert num_iter == 12 534 logger.info("test_cache_nomap_allowed_share3 Ended.\n") 535 536 537@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 538def test_cache_nomap_allowed_share4(): 539 """ 540 Feature: DatasetCache op 541 Description: Test sharing the Cache between the following two trees: 542 543 Cache Cache 544 | | 545 Map(Decode, num_parallel_workers=1) Map(Decode, num_parallel_workers=2) 546 | | 547 TFReader TFReader 548 549 Expectation: Output is equal to the expected output 550 """ 551 logger.info("Test cache nomap allowed share 4") 552 if "SESSION_ID" in os.environ: 553 session_id = int(os.environ['SESSION_ID']) 554 else: 555 raise RuntimeError("Testcase requires SESSION_ID environment variable") 556 557 # This dataset has 3 records in it only 558 some_cache = ds.DatasetCache(session_id=session_id, size=0) 559 decode_op = c_vision.Decode() 560 561 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 562 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1) 563 564 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 565 ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2) 566 567 num_iter = 0 568 for _ in ds1.create_dict_iterator(num_epochs=1): 569 num_iter += 1 570 logger.info("Number of data in ds1: {} ".format(num_iter)) 571 assert num_iter == 3 572 573 num_iter = 0 574 for _ in ds2.create_dict_iterator(num_epochs=1): 575 num_iter += 1 576 logger.info("Number of data in ds2: {} ".format(num_iter)) 577 assert num_iter == 3 578 579 logger.info("test_cache_nomap_allowed_share4 Ended.\n") 580 581 582@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 583def test_cache_nomap_disallowed_share1(): 584 """ 585 Feature: DatasetCache op 586 Description: Test sharing the Cache between the following two trees: 587 588 Cache Cache 589 | | 590 Map(Decode) Map(Rescale) 591 | | 592 TFReader TFReader 593 594 Expectation: Error is raised as expected 595 """ 596 logger.info("Test cache nomap disallowed share1") 597 if "SESSION_ID" in os.environ: 598 session_id = int(os.environ['SESSION_ID']) 599 else: 600 raise RuntimeError("Testcase requires SESSION_ID environment variable") 601 602 # This dataset has 3 records in it only 603 some_cache = ds.DatasetCache(session_id=session_id, size=0) 604 decode_op = c_vision.Decode() 605 rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) 606 607 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 608 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 609 610 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 611 ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache) 612 613 num_iter = 0 614 for _ in ds1.create_dict_iterator(num_epochs=1): 615 num_iter += 1 616 logger.info("Number of data in ds1: {} ".format(num_iter)) 617 assert num_iter == 3 618 619 with pytest.raises(RuntimeError) as e: 620 sum([1 for _ in ds2]) 621 assert "Cannot re-use a cache for a different tree!" in str(e.value) 622 623 logger.info("test_cache_nomap_disallowed_share1 Ended.\n") 624 625 626@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 627def test_cache_nomap_running_twice1(): 628 """ 629 Feature: DatasetCache op 630 Description: Test executing the same pipeline for twice (from Python), with Cache injected after Map 631 632 Repeat 633 | 634 Cache 635 | 636 Map(Decode) 637 | 638 TFRecord 639 640 Expectation: Output is equal to the expected output 641 """ 642 logger.info("Test cache nomap running twice 1") 643 if "SESSION_ID" in os.environ: 644 session_id = int(os.environ['SESSION_ID']) 645 else: 646 raise RuntimeError("Testcase requires SESSION_ID environment variable") 647 648 some_cache = ds.DatasetCache(session_id=session_id, size=0) 649 650 # This dataset has 3 records in it only 651 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 652 decode_op = c_vision.Decode() 653 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 654 ds1 = ds1.repeat(4) 655 656 num_iter = 0 657 for _ in ds1.create_dict_iterator(num_epochs=1): 658 num_iter += 1 659 logger.info("Number of data in ds1: {} ".format(num_iter)) 660 assert num_iter == 12 661 662 num_iter = 0 663 for _ in ds1.create_dict_iterator(num_epochs=1): 664 num_iter += 1 665 logger.info("Number of data in ds1: {} ".format(num_iter)) 666 assert num_iter == 12 667 668 logger.info("test_cache_nomap_running_twice1 Ended.\n") 669 670 671@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 672def test_cache_nomap_running_twice2(): 673 """ 674 Feature: DatasetCache op 675 Description: Test executing the same pipeline for twice (from shell), with Cache injected after leaf 676 677 Repeat 678 | 679 Map(Decode) 680 | 681 Cache 682 | 683 TFRecord 684 685 Expectation: Output is equal to the expected output 686 """ 687 logger.info("Test cache nomap running twice 2") 688 if "SESSION_ID" in os.environ: 689 session_id = int(os.environ['SESSION_ID']) 690 else: 691 raise RuntimeError("Testcase requires SESSION_ID environment variable") 692 693 some_cache = ds.DatasetCache(session_id=session_id, size=0) 694 695 # This dataset has 3 records in it only 696 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 697 decode_op = c_vision.Decode() 698 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 699 ds1 = ds1.repeat(4) 700 701 num_iter = 0 702 for _ in ds1.create_dict_iterator(num_epochs=1): 703 num_iter += 1 704 705 logger.info("Number of data in ds1: {} ".format(num_iter)) 706 assert num_iter == 12 707 logger.info("test_cache_nomap_running_twice2 Ended.\n") 708 709 710@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 711def test_cache_nomap_extra_small_size1(): 712 """ 713 Feature: DatasetCache op 714 Description: Test running pipeline with Cache of extra small size and spilling=True 715 716 Repeat 717 | 718 Map(Decode) 719 | 720 Cache 721 | 722 TFRecord 723 724 Expectation: Output is equal to the expected output 725 """ 726 logger.info("Test cache nomap extra small size 1") 727 if "SESSION_ID" in os.environ: 728 session_id = int(os.environ['SESSION_ID']) 729 else: 730 raise RuntimeError("Testcase requires SESSION_ID environment variable") 731 some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True) 732 733 # This dataset has 3 records in it only 734 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 735 decode_op = c_vision.Decode() 736 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 737 ds1 = ds1.repeat(4) 738 739 num_iter = 0 740 for _ in ds1.create_dict_iterator(num_epochs=1): 741 num_iter += 1 742 743 logger.info("Number of data in ds1: {} ".format(num_iter)) 744 assert num_iter == 12 745 logger.info("test_cache_nomap_extra_small_size1 Ended.\n") 746 747 748@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 749def test_cache_nomap_extra_small_size2(): 750 """ 751 Feature: DatasetCache op 752 Description: Test running pipeline with Cache of extra small size and spilling=False 753 754 Repeat 755 | 756 Cache 757 | 758 Map(Decode) 759 | 760 TFRecord 761 762 Expectation: Error is raised as expected 763 """ 764 logger.info("Test cache nomap extra small size 2") 765 if "SESSION_ID" in os.environ: 766 session_id = int(os.environ['SESSION_ID']) 767 else: 768 raise RuntimeError("Testcase requires SESSION_ID environment variable") 769 some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) 770 771 # This dataset has 3 records in it only 772 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 773 decode_op = c_vision.Decode() 774 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 775 ds1 = ds1.repeat(4) 776 777 with pytest.raises(RuntimeError) as e: 778 sum([1 for _ in ds1]) 779 assert "Out of memory" in str(e.value) 780 logger.info("test_cache_nomap_extra_small_size2 Ended.\n") 781 782 783@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 784def test_cache_nomap_parallel_pipeline1(shard): 785 """ 786 Feature: DatasetCache op 787 Description: Test running two parallel pipelines (sharing Cache) with Cache injected after leaf op 788 789 Repeat 790 | 791 Map(Decode) 792 | 793 Cache 794 | 795 TFReader 796 797 Expectation: Output is equal to the expected output 798 """ 799 logger.info("Test cache nomap parallel pipeline 1") 800 if "SESSION_ID" in os.environ: 801 session_id = int(os.environ['SESSION_ID']) 802 else: 803 raise RuntimeError("Testcase requires SESSION_ID environment variable") 804 some_cache = ds.DatasetCache(session_id=session_id, size=0) 805 806 # This dataset has 3 records in it only 807 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache) 808 decode_op = c_vision.Decode() 809 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 810 ds1 = ds1.repeat(4) 811 812 num_iter = 0 813 for _ in ds1.create_dict_iterator(num_epochs=1): 814 num_iter += 1 815 816 logger.info("Number of data in ds1: {} ".format(num_iter)) 817 assert num_iter == 4 818 logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n") 819 820 821@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 822def test_cache_nomap_parallel_pipeline2(shard): 823 """ 824 Feature: DatasetCache op 825 Description: Test running two parallel pipelines (sharing Cache) with Cache injected after Map op 826 827 Repeat 828 | 829 Cache 830 | 831 Map(Decode) 832 | 833 TFReader 834 835 Expectation: Output is equal to the expected output 836 """ 837 logger.info("Test cache nomap parallel pipeline 2") 838 if "SESSION_ID" in os.environ: 839 session_id = int(os.environ['SESSION_ID']) 840 else: 841 raise RuntimeError("Testcase requires SESSION_ID environment variable") 842 some_cache = ds.DatasetCache(session_id=session_id, size=0) 843 844 # This dataset has 3 records in it only 845 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard)) 846 decode_op = c_vision.Decode() 847 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 848 ds1 = ds1.repeat(4) 849 850 num_iter = 0 851 for _ in ds1.create_dict_iterator(num_epochs=1): 852 num_iter += 1 853 854 logger.info("Number of data in ds1: {} ".format(num_iter)) 855 assert num_iter == 4 856 logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n") 857 858 859@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 860def test_cache_nomap_parallel_workers(): 861 """ 862 Feature: DatasetCache op 863 Description: Test Cache with num_parallel_workers > 1 set for Map op and leaf op 864 865 Repeat 866 | 867 Map(Decode) 868 | 869 Cache 870 | 871 TFReader 872 873 Expectation: Output is equal to the expected output 874 """ 875 logger.info("Test cache nomap parallel workers") 876 if "SESSION_ID" in os.environ: 877 session_id = int(os.environ['SESSION_ID']) 878 else: 879 raise RuntimeError("Testcase requires SESSION_ID environment variable") 880 some_cache = ds.DatasetCache(session_id=session_id, size=0) 881 882 # This dataset has 3 records in it only 883 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4) 884 decode_op = c_vision.Decode() 885 ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache) 886 ds1 = ds1.repeat(4) 887 888 num_iter = 0 889 for _ in ds1.create_dict_iterator(num_epochs=1): 890 num_iter += 1 891 892 logger.info("Number of data in ds1: {} ".format(num_iter)) 893 assert num_iter == 12 894 logger.info("test_cache_nomap_parallel_workers Ended.\n") 895 896 897@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 898def test_cache_nomap_server_workers_1(): 899 """ 900 Feature: DatasetCache op 901 Description: Start Cache server with --workers 1 and then test Cache function 902 903 Repeat 904 | 905 Cache 906 | 907 Map(Decode) 908 | 909 TFRecord 910 911 Expectation: Output is equal to the expected output 912 """ 913 logger.info("Test cache nomap server workers 1") 914 if "SESSION_ID" in os.environ: 915 session_id = int(os.environ['SESSION_ID']) 916 else: 917 raise RuntimeError("Testcase requires SESSION_ID environment variable") 918 919 some_cache = ds.DatasetCache(session_id=session_id, size=0) 920 921 # This dataset has 3 records in it only 922 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 923 decode_op = c_vision.Decode() 924 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 925 ds1 = ds1.repeat(4) 926 927 num_iter = 0 928 for _ in ds1.create_dict_iterator(num_epochs=1): 929 num_iter += 1 930 931 logger.info("Number of data in ds1: {} ".format(num_iter)) 932 assert num_iter == 12 933 logger.info("test_cache_nomap_server_workers_1 Ended.\n") 934 935 936@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 937def test_cache_nomap_server_workers_100(): 938 """ 939 Feature: DatasetCache op 940 Description: Start Cache server with --workers 100 and then test Cache function 941 942 Repeat 943 | 944 Map(Decode) 945 | 946 Cache 947 | 948 TFRecord 949 950 Expectation: Output is equal to the expected output 951 """ 952 logger.info("Test cache nomap server workers 100") 953 if "SESSION_ID" in os.environ: 954 session_id = int(os.environ['SESSION_ID']) 955 else: 956 raise RuntimeError("Testcase requires SESSION_ID environment variable") 957 958 some_cache = ds.DatasetCache(session_id=session_id, size=0) 959 960 # This dataset has 3 records in it only 961 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 962 decode_op = c_vision.Decode() 963 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 964 ds1 = ds1.repeat(4) 965 966 num_iter = 0 967 for _ in ds1.create_dict_iterator(num_epochs=1): 968 num_iter += 1 969 970 logger.info("Number of data in ds1: {} ".format(num_iter)) 971 assert num_iter == 12 972 logger.info("test_cache_nomap_server_workers_100 Ended.\n") 973 974 975@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 976def test_cache_nomap_num_connections_1(): 977 """ 978 Feature: DatasetCache op 979 Description: Test setting num_connections=1 in DatasetCache 980 981 Repeat 982 | 983 Cache 984 | 985 Map(Decode) 986 | 987 TFRecord 988 989 Expectation: Output is equal to the expected output 990 """ 991 logger.info("Test cache nomap num_connections 1") 992 if "SESSION_ID" in os.environ: 993 session_id = int(os.environ['SESSION_ID']) 994 else: 995 raise RuntimeError("Testcase requires SESSION_ID environment variable") 996 997 some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1) 998 999 # This dataset has 3 records in it only 1000 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1001 decode_op = c_vision.Decode() 1002 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1003 ds1 = ds1.repeat(4) 1004 1005 num_iter = 0 1006 for _ in ds1.create_dict_iterator(num_epochs=1): 1007 num_iter += 1 1008 1009 logger.info("Number of data in ds1: {} ".format(num_iter)) 1010 assert num_iter == 12 1011 logger.info("test_cache_nomap_num_connections_1 Ended.\n") 1012 1013 1014@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1015def test_cache_nomap_num_connections_100(): 1016 """ 1017 Feature: DatasetCache op 1018 Description: Test setting num_connections=100 in DatasetCache 1019 1020 Repeat 1021 | 1022 Map(Decode) 1023 | 1024 Cache 1025 | 1026 TFRecord 1027 1028 Expectation: Output is equal to the expected output 1029 """ 1030 logger.info("Test cache nomap num_connections 100") 1031 if "SESSION_ID" in os.environ: 1032 session_id = int(os.environ['SESSION_ID']) 1033 else: 1034 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1035 1036 some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100) 1037 1038 # This dataset has 3 records in it only 1039 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 1040 decode_op = c_vision.Decode() 1041 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1042 ds1 = ds1.repeat(4) 1043 1044 num_iter = 0 1045 for _ in ds1.create_dict_iterator(num_epochs=1): 1046 num_iter += 1 1047 1048 logger.info("Number of data in ds1: {} ".format(num_iter)) 1049 assert num_iter == 12 1050 logger.info("test_cache_nomap_num_connections_100 Ended.\n") 1051 1052 1053@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1054def test_cache_nomap_prefetch_size_1(): 1055 """ 1056 Feature: DatasetCache op 1057 Description: Test setting prefetch_size=1 in DatasetCache 1058 1059 Repeat 1060 | 1061 Cache 1062 | 1063 Map(Decode) 1064 | 1065 TFRecord 1066 1067 Expectation: Output is equal to the expected output 1068 """ 1069 logger.info("Test cache nomap prefetch_size 1") 1070 if "SESSION_ID" in os.environ: 1071 session_id = int(os.environ['SESSION_ID']) 1072 else: 1073 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1074 1075 some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) 1076 1077 # This dataset has 3 records in it only 1078 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1079 decode_op = c_vision.Decode() 1080 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1081 ds1 = ds1.repeat(4) 1082 1083 num_iter = 0 1084 for _ in ds1.create_dict_iterator(num_epochs=1): 1085 num_iter += 1 1086 1087 logger.info("Number of data in ds1: {} ".format(num_iter)) 1088 assert num_iter == 12 1089 logger.info("test_cache_nomap_prefetch_size_1 Ended.\n") 1090 1091 1092@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1093def test_cache_nomap_prefetch_size_100(): 1094 """ 1095 Feature: DatasetCache op 1096 Description: Test setting prefetch_size=100 in DatasetCache 1097 1098 Repeat 1099 | 1100 Map(Decode) 1101 | 1102 Cache 1103 | 1104 TFRecord 1105 1106 Expectation: Output is equal to the expected output 1107 """ 1108 logger.info("Test cache nomap prefetch_size 100") 1109 if "SESSION_ID" in os.environ: 1110 session_id = int(os.environ['SESSION_ID']) 1111 else: 1112 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1113 1114 some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) 1115 1116 # This dataset has 3 records in it only 1117 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 1118 decode_op = c_vision.Decode() 1119 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1120 ds1 = ds1.repeat(4) 1121 1122 num_iter = 0 1123 for _ in ds1.create_dict_iterator(num_epochs=1): 1124 num_iter += 1 1125 1126 logger.info("Number of data in ds1: {} ".format(num_iter)) 1127 assert num_iter == 12 1128 logger.info("test_cache_nomap_prefetch_size_100 Ended.\n") 1129 1130 1131@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1132def test_cache_nomap_device_que(): 1133 """ 1134 Feature: DatasetCache op 1135 Description: Test Cache with device_que 1136 1137 DeviceQueue 1138 | 1139 EpochCtrl 1140 | 1141 Repeat 1142 | 1143 Map(Decode) 1144 | 1145 Cache 1146 | 1147 TFReader 1148 1149 Expectation: Output is equal to the expected output 1150 """ 1151 logger.info("Test cache nomap device_que") 1152 if "SESSION_ID" in os.environ: 1153 session_id = int(os.environ['SESSION_ID']) 1154 else: 1155 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1156 1157 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1158 1159 # This dataset has 3 records in it only 1160 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1161 decode_op = c_vision.Decode() 1162 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1163 ds1 = ds1.repeat(4) 1164 ds1 = ds1.device_que() 1165 ds1.send() 1166 1167 logger.info("test_cache_nomap_device_que Ended.\n") 1168 1169 1170@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1171def test_cache_nomap_session_destroy(): 1172 """ 1173 Feature: DatasetCache op 1174 Description: Test executing cache_admin -d while the pipeline is running 1175 1176 Repeat 1177 | 1178 Cache 1179 | 1180 RandomDataset 1181 1182 Expectation: Error is raised as expected 1183 """ 1184 logger.info("Test cache nomap session destroy") 1185 if "SESSION_ID" in os.environ: 1186 session_id = int(os.environ['SESSION_ID']) 1187 else: 1188 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1189 1190 schema = ds.Schema() 1191 schema.add_column('image', de_type=mstype.uint8, 1192 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 1193 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 1194 1195 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1196 1197 # User-created sampler here 1198 ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) 1199 ds1 = ds1.repeat() 1200 1201 with pytest.raises(RuntimeError) as e: 1202 num_iter = 0 1203 for _ in ds1.create_dict_iterator(num_epochs=1): 1204 num_iter += 1 1205 assert "Unexpected error" in str(e.value) 1206 1207 logger.info("test_cache_nomap_session_destroy Ended.\n") 1208 1209 1210@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1211def test_cache_nomap_server_stop(): 1212 """ 1213 Feature: DatasetCache op 1214 Description: Test executing cache_admin --stop while the pipeline is running 1215 1216 Repeat 1217 | 1218 Cache 1219 | 1220 RandomDataset 1221 1222 Expectation: Error is raised as expected 1223 """ 1224 logger.info("Test cache nomap server stop") 1225 if "SESSION_ID" in os.environ: 1226 session_id = int(os.environ['SESSION_ID']) 1227 else: 1228 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1229 1230 schema = ds.Schema() 1231 schema.add_column('image', de_type=mstype.uint8, 1232 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 1233 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 1234 1235 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1236 1237 # User-created sampler here 1238 ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache) 1239 ds1 = ds1.repeat() 1240 1241 with pytest.raises(RuntimeError) as e: 1242 num_iter = 0 1243 for _ in ds1.create_dict_iterator(num_epochs=1): 1244 num_iter += 1 1245 assert "Network error. Cache server with port 50052 is unreachable. Make sure the server is running." in \ 1246 str(e.value) 1247 1248 logger.info("test_cache_nomap_server_stop Ended.\n") 1249 1250 1251@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1252def test_cache_nomap_interrupt_and_rerun(): 1253 """ 1254 Feature: DatasetCache op 1255 Description: Test interrupt a running pipeline and then re-use the same Cache to run another pipeline 1256 1257 Cache 1258 | 1259 RandomDataset 1260 1261 Expectation: Error is raised after the interrupt then putput is equal to the expected output after the rerun 1262 """ 1263 logger.info("Test cache nomap interrupt and rerun") 1264 if "SESSION_ID" in os.environ: 1265 session_id = int(os.environ['SESSION_ID']) 1266 else: 1267 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1268 1269 schema = ds.Schema() 1270 schema.add_column('image', de_type=mstype.uint8, 1271 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 1272 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 1273 1274 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1275 1276 # User-created sampler here 1277 ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache) 1278 iter1 = ds1.create_dict_iterator(num_epochs=-1) 1279 1280 num_iter = 0 1281 with pytest.raises(AttributeError) as e: 1282 for _ in iter1: 1283 num_iter += 1 1284 if num_iter == 10: 1285 iter1.stop() 1286 assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value) 1287 1288 num_epoch = 2 1289 iter2 = ds1.create_dict_iterator(num_epochs=num_epoch) 1290 epoch_count = 0 1291 for _ in range(num_epoch): 1292 num_iter = 0 1293 for _ in iter2: 1294 num_iter += 1 1295 logger.info("Number of data in ds1: {} ".format(num_iter)) 1296 assert num_iter == 10000 1297 epoch_count += 1 1298 1299 cache_stat = some_cache.get_stat() 1300 assert cache_stat.num_mem_cached == 10000 1301 1302 logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n") 1303 1304 1305@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1306def test_cache_nomap_epoch_ctrl1(): 1307 """ 1308 Feature: DatasetCache op 1309 Description: Test using two-loops method to run several epochs 1310 1311 Map(Decode) 1312 | 1313 Cache 1314 | 1315 TFRecord 1316 1317 Expectation: Output is equal to the expected output 1318 """ 1319 logger.info("Test cache nomap epoch ctrl1") 1320 if "SESSION_ID" in os.environ: 1321 session_id = int(os.environ['SESSION_ID']) 1322 else: 1323 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1324 1325 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1326 1327 # This dataset has 3 records in it only 1328 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 1329 decode_op = c_vision.Decode() 1330 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1331 1332 num_epoch = 5 1333 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1334 1335 epoch_count = 0 1336 for _ in range(num_epoch): 1337 row_count = 0 1338 for _ in iter1: 1339 row_count += 1 1340 logger.info("Number of data in ds1: {} ".format(row_count)) 1341 assert row_count == 3 1342 epoch_count += 1 1343 assert epoch_count == num_epoch 1344 logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n") 1345 1346 1347@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1348def test_cache_nomap_epoch_ctrl2(): 1349 """ 1350 Feature: DatasetCache op 1351 Description: Test using two-loops method with infinite epochs 1352 1353 Cache 1354 | 1355 Map(Decode) 1356 | 1357 TFRecord 1358 1359 Expectation: Output is equal to the expected output 1360 """ 1361 logger.info("Test cache nomap epoch ctrl2") 1362 if "SESSION_ID" in os.environ: 1363 session_id = int(os.environ['SESSION_ID']) 1364 else: 1365 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1366 1367 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1368 1369 # This dataset has 3 records in it only 1370 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1371 decode_op = c_vision.Decode() 1372 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1373 1374 num_epoch = 5 1375 # iter1 will always assume there is a next epoch and never shutdown 1376 iter1 = ds1.create_dict_iterator(num_epochs=-1) 1377 1378 epoch_count = 0 1379 for _ in range(num_epoch): 1380 row_count = 0 1381 for _ in iter1: 1382 row_count += 1 1383 logger.info("Number of data in ds1: {} ".format(row_count)) 1384 assert row_count == 3 1385 epoch_count += 1 1386 assert epoch_count == num_epoch 1387 1388 # manually stop the iterator 1389 iter1.stop() 1390 logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n") 1391 1392 1393@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1394def test_cache_nomap_epoch_ctrl3(): 1395 """ 1396 Feature: DatasetCache op 1397 Description: Test using two-loops method with infinite epochs over Repeat op 1398 1399 Repeat 1400 | 1401 Map(Decode) 1402 | 1403 Cache 1404 | 1405 TFRecord 1406 1407 Expectation: Output is equal to the expected output 1408 """ 1409 logger.info("Test cache nomap epoch ctrl3") 1410 if "SESSION_ID" in os.environ: 1411 session_id = int(os.environ['SESSION_ID']) 1412 else: 1413 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1414 1415 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1416 1417 # This dataset has 3 records in it only 1418 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 1419 decode_op = c_vision.Decode() 1420 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1421 ds1 = ds1.repeat(2) 1422 1423 num_epoch = 5 1424 # iter1 will always assume there is a next epoch and never shutdown 1425 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1426 1427 epoch_count = 0 1428 for _ in range(num_epoch): 1429 row_count = 0 1430 for _ in iter1: 1431 row_count += 1 1432 logger.info("Number of data in ds1: {} ".format(row_count)) 1433 assert row_count == 6 1434 epoch_count += 1 1435 assert epoch_count == num_epoch 1436 1437 # reply on garbage collector to destroy iter1 1438 1439 logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") 1440 1441 1442@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1443def test_cache_nomap_epoch_ctrl4(): 1444 """ 1445 Feature: DatasetCache op 1446 Description: Test using two-loops method with Repeat under Cache 1447 1448 Cache 1449 | 1450 Map(Decode) 1451 | 1452 Repeat 1453 | 1454 TFRecord 1455 1456 Expectation: Output is equal to the expected output 1457 """ 1458 logger.info("Test cache nomap epoch ctrl4") 1459 if "SESSION_ID" in os.environ: 1460 session_id = int(os.environ['SESSION_ID']) 1461 else: 1462 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1463 1464 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1465 1466 # This dataset has 3 records in it only 1467 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1468 ds1 = ds1.repeat(2) 1469 decode_op = c_vision.Decode() 1470 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1471 1472 num_epoch = 5 1473 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1474 1475 epoch_count = 0 1476 for _ in range(num_epoch): 1477 row_count = 0 1478 for _ in iter1: 1479 row_count += 1 1480 logger.info("Number of data in ds1: {} ".format(row_count)) 1481 assert row_count == 6 1482 epoch_count += 1 1483 assert epoch_count == num_epoch 1484 1485 logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n") 1486 1487 1488@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1489def test_cache_nomap_multiple_cache1(): 1490 """ 1491 Feature: DatasetCache op 1492 Description: Test multiple Cache in the same python script 1493 1494 Cache Cache 1495 | | 1496 Map(Decode) Map(Decode) 1497 | | 1498 TFRecord(train) TFRecord(eval) 1499 1500 Expectation: Output is equal to the expected output 1501 """ 1502 logger.info("Test cache nomap multiple cache 1") 1503 if "SESSION_ID" in os.environ: 1504 session_id = int(os.environ['SESSION_ID']) 1505 else: 1506 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1507 1508 train_cache = ds.DatasetCache(session_id=session_id, size=0) 1509 eval_cache = ds.DatasetCache(session_id=session_id, size=0) 1510 1511 # This dataset has 12 records in it 1512 train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) 1513 decode_op = c_vision.Decode() 1514 train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) 1515 1516 # This dataset has 3 records in it only 1517 eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1518 eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache) 1519 1520 num_epoch = 5 1521 train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch) 1522 eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch) 1523 1524 epoch_count = 0 1525 for _ in range(num_epoch): 1526 assert sum([1 for _ in train_iter]) == 12 1527 assert sum([1 for _ in eval_iter]) == 3 1528 epoch_count += 1 1529 assert epoch_count == num_epoch 1530 1531 logger.info("test_cache_nomap_multiple_cache1 Ended.\n") 1532 1533 1534@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1535def test_cache_nomap_multiple_cache2(): 1536 """ 1537 Feature: DatasetCache op 1538 Description: Test multiple Cache in the same Python script 1539 1540 Cache 1541 | 1542 Map(Decode) Cache 1543 | | 1544 TFRecord(image) TFRecord(text) 1545 1546 Expectation: Output is equal to the expected output 1547 """ 1548 logger.info("Test cache nomap multiple cache 2") 1549 if "SESSION_ID" in os.environ: 1550 session_id = int(os.environ['SESSION_ID']) 1551 else: 1552 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1553 1554 image_cache = ds.DatasetCache(session_id=session_id, size=0) 1555 text_cache = ds.DatasetCache(session_id=session_id, size=0) 1556 1557 # This dataset has 3 records in it only 1558 image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1559 decode_op = c_vision.Decode() 1560 image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) 1561 1562 # This dataset has 3 records in it only 1563 text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache) 1564 1565 num_epoch = 5 1566 image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) 1567 text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1568 1569 epoch_count = 0 1570 for _ in range(num_epoch): 1571 row_count = 0 1572 for _, _ in itertools.zip_longest(image_iter, text_iter): 1573 row_count += 1 1574 assert row_count == 3 1575 epoch_count += 1 1576 assert epoch_count == num_epoch 1577 1578 logger.info("test_cache_nomap_multiple_cache2 Ended.\n") 1579 1580 1581@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1582def test_cache_nomap_multiple_cache3(): 1583 """ 1584 Feature: DatasetCache op 1585 Description: Test multiple Cache in the same Python script 1586 1587 Cache Cache 1588 | | 1589 Map(Decode) Map(Decode) 1590 | | 1591 TFRecord ImageFolder 1592 1593 Expectation: Output is equal to the expected output 1594 """ 1595 1596 logger.info("Test cache nomap multiple cache 3") 1597 if "SESSION_ID" in os.environ: 1598 session_id = int(os.environ['SESSION_ID']) 1599 else: 1600 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1601 1602 tf_cache = ds.DatasetCache(session_id=session_id, size=0) 1603 image_cache = ds.DatasetCache(session_id=session_id, size=0) 1604 1605 # This dataset has 3 records in it only 1606 tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1607 decode_op = c_vision.Decode() 1608 tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache) 1609 1610 # This DATA_DIR only has 2 images in it 1611 image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR) 1612 image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache) 1613 1614 num_epoch = 5 1615 tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch) 1616 image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch) 1617 1618 epoch_count = 0 1619 for _ in range(num_epoch): 1620 assert sum([1 for _ in tf_iter]) == 3 1621 assert sum([1 for _ in image_iter]) == 2 1622 epoch_count += 1 1623 assert epoch_count == num_epoch 1624 1625 logger.info("test_cache_nomap_multiple_cache3 Ended.\n") 1626 1627 1628@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1629def test_cache_nomap_multiple_cache_train(): 1630 """ 1631 Feature: DatasetCache op 1632 Description: Test multi Cache in different Python scripts. 1633 Runs concurrently with test_cache_nomap_multiple_cache_eval 1634 1635 Cache 1636 | 1637 Map(Decode) 1638 | 1639 TFRecord(train) 1640 1641 Expectation: Output is equal to the expected output 1642 """ 1643 logger.info("Test cache nomap multiple cache train") 1644 if "SESSION_ID" in os.environ: 1645 session_id = int(os.environ['SESSION_ID']) 1646 else: 1647 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1648 1649 train_cache = ds.DatasetCache(session_id=session_id, size=0) 1650 1651 # This dataset has 12 records in it 1652 train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR) 1653 decode_op = c_vision.Decode() 1654 train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache) 1655 1656 num_epoch = 5 1657 train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch) 1658 1659 epoch_count = 0 1660 for _ in range(num_epoch): 1661 assert sum([1 for _ in train_iter]) == 12 1662 epoch_count += 1 1663 assert epoch_count == num_epoch 1664 1665 logger.info("test_cache_nomap_multiple_cache_train Ended.\n") 1666 1667 1668@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1669def test_cache_nomap_multiple_cache_eval(): 1670 """ 1671 Feature: DatasetCache op 1672 Description: Test multi Cache in different Python scripts. 1673 Runs concurrently with test_cache_nomap_multiple_cache_eval 1674 1675 Cache 1676 | 1677 Map(Decode) 1678 | 1679 TFRecord(eval) 1680 1681 Expectation: Output is equal to the expected output 1682 """ 1683 logger.info("Test cache nomap multiple cache eval") 1684 if "SESSION_ID" in os.environ: 1685 session_id = int(os.environ['SESSION_ID']) 1686 else: 1687 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1688 1689 eval_cache = ds.DatasetCache(session_id=session_id, size=0) 1690 1691 # This dataset only has 3 records in it 1692 eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1693 decode_op = c_vision.Decode() 1694 eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache) 1695 1696 num_epoch = 5 1697 eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch) 1698 1699 epoch_count = 0 1700 for _ in range(num_epoch): 1701 assert sum([1 for _ in eval_iter]) == 3 1702 epoch_count += 1 1703 assert epoch_count == num_epoch 1704 1705 logger.info("test_cache_nomap_multiple_cache_eval Ended.\n") 1706 1707 1708@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1709def test_cache_nomap_clue1(): 1710 """ 1711 Feature: DatasetCache op 1712 Description: Test CLUEDataset (a non mappable dataset) with a Cache over it just after the leaf 1713 In this one, the CLUEDataset will be given sharding configuration, however since a Cache is 1714 used, the tree prepare should undo the sharding configuration and instead, a distributed 1715 sampler will be chosen with the same shard config. 1716 1717 Cache 1718 | 1719 CLUE 1720 1721 Expectation: Output is equal to the expected output 1722 """ 1723 logger.info("Test cache nomap clue 1") 1724 if "SESSION_ID" in os.environ: 1725 session_id = int(os.environ['SESSION_ID']) 1726 else: 1727 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1728 1729 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1730 1731 # With only 3 records shard into 3, we expect only 1 record returned for this shard 1732 # However, the sharding will be done by the sampler, not by the clue leaf node 1733 # In this case, it is a row-based sharding, not the file-based sharding that would happen if 1734 # there was not any cache. 1735 ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache) 1736 1737 num_epoch = 4 1738 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1739 1740 epoch_count = 0 1741 for _ in range(num_epoch): 1742 assert sum([1 for _ in iter1]) == 1 1743 epoch_count += 1 1744 assert epoch_count == num_epoch 1745 1746 logger.info("test_cache_nomap_clue1 Ended.\n") 1747 1748 1749@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1750def test_cache_nomap_clue2(): 1751 """ 1752 Feature: DatasetCache op 1753 Description: Test CLUEDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given 1754 1755 Cache 1756 | 1757 Map(lambda x: x) 1758 | 1759 CLUE 1760 1761 Expectation: Output is equal to the expected output 1762 """ 1763 logger.info("Test cache nomap clue 2") 1764 if "SESSION_ID" in os.environ: 1765 session_id = int(os.environ['SESSION_ID']) 1766 else: 1767 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1768 1769 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1770 1771 ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) 1772 ds1 = ds1.map(vision.not_random(lambda x: x), ["label"], cache=some_cache) 1773 1774 num_epoch = 4 1775 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1776 1777 epoch_count = 0 1778 for _ in range(num_epoch): 1779 assert sum([1 for _ in iter1]) == 2 1780 epoch_count += 1 1781 assert epoch_count == num_epoch 1782 1783 logger.info("test_cache_nomap_clue2 Ended.\n") 1784 1785 1786@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1787def test_cache_nomap_csv1(): 1788 """ 1789 Feature: DatasetCache op 1790 Description: Test CSVDataset (a non mappable dataset) with a Cache over it just after the leaf 1791 In this one, the CSVDataset will be given sharding configuration, however since a Cache is 1792 used, the tree prepare should undo the sharding configuration and instead, a distributed 1793 sampler will be chosen with the same shard config. 1794 1795 Cache 1796 | 1797 CSV 1798 1799 Expectation: Output is equal to the expected output 1800 """ 1801 logger.info("Test cache nomap csv 1") 1802 if "SESSION_ID" in os.environ: 1803 session_id = int(os.environ['SESSION_ID']) 1804 else: 1805 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1806 1807 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1808 1809 # With only 3 records shard into 3, we expect only 1 record returned for this shard 1810 # However, the sharding will be done by the sampler, not by the clue leaf node 1811 # In this case, it is a row-based sharding, not the file-based sharding that would happen if 1812 # there was not any cache. 1813 ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], 1814 column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache) 1815 1816 num_epoch = 4 1817 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1818 1819 epoch_count = 0 1820 for _ in range(num_epoch): 1821 assert sum([1 for _ in iter1]) == 1 1822 epoch_count += 1 1823 assert epoch_count == num_epoch 1824 1825 logger.info("test_cache_nomap_csv1 Ended.\n") 1826 1827 1828@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1829def test_cache_nomap_csv2(): 1830 """ 1831 Feature: DatasetCache op 1832 Description: Test CSVDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given 1833 1834 Cache 1835 | 1836 Map(lambda x: x) 1837 | 1838 CSV 1839 1840 Expectation: Output is equal to the expected output 1841 """ 1842 logger.info("Test cache nomap csv 2") 1843 if "SESSION_ID" in os.environ: 1844 session_id = int(os.environ['SESSION_ID']) 1845 else: 1846 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1847 1848 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1849 1850 ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], 1851 column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) 1852 ds1 = ds1.map(vision.not_random(lambda x: x), ["col1"], cache=some_cache) 1853 1854 num_epoch = 4 1855 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1856 1857 epoch_count = 0 1858 for _ in range(num_epoch): 1859 assert sum([1 for _ in iter1]) == 2 1860 epoch_count += 1 1861 assert epoch_count == num_epoch 1862 1863 logger.info("test_cache_nomap_csv2 Ended.\n") 1864 1865 1866@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1867def test_cache_nomap_textfile1(): 1868 """ 1869 Feature: DatasetCache op 1870 Description: Test TextFileDataset (a non mappable dataset) with a Cache over it just after the leaf 1871 In this one, the text file dataset will be given sharding configuration, however since a Cache is 1872 used, the tree prepare should undo the sharding configuration and instead, a distributed 1873 sampler will be chosen with the same shard config. 1874 1875 Cache 1876 | 1877 TextFile 1878 1879 Expectation: Output is equal to the expected output 1880 """ 1881 logger.info("Test cache nomap textfile 1") 1882 if "SESSION_ID" in os.environ: 1883 session_id = int(os.environ['SESSION_ID']) 1884 else: 1885 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1886 1887 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1888 1889 # With only 3 records shard into 3, we expect only 1 record returned for this shard 1890 # However, the sharding will be done by the sampler, not by the clue leaf node 1891 # In this case, it is a row-based sharding, not the file-based sharding that would happen if 1892 # there was not any cache. 1893 ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache) 1894 1895 num_epoch = 4 1896 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1897 1898 epoch_count = 0 1899 for _ in range(num_epoch): 1900 assert sum([1 for _ in iter1]) == 1 1901 epoch_count += 1 1902 assert epoch_count == num_epoch 1903 1904 logger.info("test_cache_nomap_textfile1 Ended.\n") 1905 1906 1907@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1908def test_cache_nomap_textfile2(): 1909 """ 1910 Feature: DatasetCache op 1911 Description: Test TextFileDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given 1912 1913 Cache 1914 | 1915 Map(Tokenizer) 1916 | 1917 TextFile 1918 1919 Expectation: Output is equal to the expected output 1920 """ 1921 def my_tokenizer(line): 1922 words = line.split() 1923 if not words: 1924 return [""] 1925 return words 1926 1927 logger.info("Test cache nomap textfile 2") 1928 if "SESSION_ID" in os.environ: 1929 session_id = int(os.environ['SESSION_ID']) 1930 else: 1931 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1932 1933 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1934 1935 ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2) 1936 tokenizer = text.PythonTokenizer(my_tokenizer) 1937 ds1 = ds1.map(operations=tokenizer, cache=some_cache) 1938 1939 num_epoch = 4 1940 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1941 1942 epoch_count = 0 1943 for _ in range(num_epoch): 1944 assert sum([1 for _ in iter1]) == 2 1945 epoch_count += 1 1946 assert epoch_count == num_epoch 1947 1948 logger.info("test_cache_nomap_textfile2 Ended.\n") 1949 1950 1951@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1952def test_cache_nomap_nested_repeat(): 1953 """ 1954 Feature: DatasetCache op 1955 Description: Test Cache on pipeline with nested Repeat ops 1956 1957 Repeat 1958 | 1959 Cache 1960 | 1961 Map(Decode) 1962 | 1963 Repeat 1964 | 1965 TFRecord 1966 1967 Expectation: Output is equal to the expected output 1968 """ 1969 logger.info("Test cache nomap nested repeat") 1970 if "SESSION_ID" in os.environ: 1971 session_id = int(os.environ['SESSION_ID']) 1972 else: 1973 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1974 1975 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1976 1977 # This dataset has 3 records in it only 1978 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) 1979 decode_op = c_vision.Decode() 1980 ds1 = ds1.repeat(4) 1981 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 1982 ds1 = ds1.repeat(2) 1983 1984 num_iter = 0 1985 for _ in ds1.create_dict_iterator(num_epochs=1): 1986 logger.info("get data from dataset") 1987 num_iter += 1 1988 1989 logger.info("Number of data in ds1: {} ".format(num_iter)) 1990 assert num_iter == 24 1991 logger.info('test_cache_nomap_nested_repeat Ended.\n') 1992 1993 1994@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1995def test_cache_nomap_get_repeat_count(): 1996 """ 1997 Feature: DatasetCache op 1998 Description: Test get_repeat_count for a pipeline with Cache and nested repeat ops 1999 2000 Cache 2001 | 2002 Map(Decode) 2003 | 2004 Repeat 2005 | 2006 TFRecord 2007 2008 Expectation: Output is equal to the expected output 2009 """ 2010 logger.info("Test cache nomap get_repeat_count") 2011 if "SESSION_ID" in os.environ: 2012 session_id = int(os.environ['SESSION_ID']) 2013 else: 2014 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2015 2016 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2017 2018 # This dataset has 3 records in it only 2019 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) 2020 ds1 = ds1.repeat(4) 2021 decode_op = c_vision.Decode() 2022 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 2023 2024 repeat_count = ds1.get_repeat_count() 2025 logger.info("repeat_count: {}".format(repeat_count)) 2026 assert repeat_count == 4 2027 2028 num_iter = 0 2029 for _ in ds1.create_dict_iterator(num_epochs=1): 2030 logger.info("get data from dataset") 2031 num_iter += 1 2032 assert num_iter == 12 2033 2034 2035@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2036def test_cache_nomap_long_file_list(): 2037 """ 2038 Feature: DatasetCache op 2039 Description: Test Cache after TFRecord with a long list of files as arguments 2040 2041 Cache 2042 | 2043 TFRecord 2044 2045 Expectation: Error is raised as expected 2046 """ 2047 logger.info("Test cache nomap long file list") 2048 if "SESSION_ID" in os.environ: 2049 session_id = int(os.environ['SESSION_ID']) 2050 else: 2051 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2052 2053 some_cache = ds.DatasetCache(session_id=session_id, size=1) 2054 2055 ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"], 2056 cache=some_cache) 2057 2058 with pytest.raises(RuntimeError) as e: 2059 sum([1 for _ in ds1]) 2060 assert "Out of memory" in str(e.value) 2061 logger.info("test_cache_nomap_long_file_list Ended.\n") 2062 2063 2064@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2065def test_cache_nomap_failure1(): 2066 """ 2067 Feature: DatasetCache op 2068 Description: Test nested Cache 2069 2070 Repeat 2071 | 2072 Cache 2073 | 2074 Map(Decode) 2075 | 2076 Cache 2077 | 2078 TFRecord 2079 2080 Expectation: Error is raised as expected 2081 """ 2082 logger.info("Test cache nomap failure 1") 2083 if "SESSION_ID" in os.environ: 2084 session_id = int(os.environ['SESSION_ID']) 2085 else: 2086 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2087 2088 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2089 2090 # This dataset has 3 records in it only 2091 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) 2092 decode_op = c_vision.Decode() 2093 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 2094 ds1 = ds1.repeat(4) 2095 2096 with pytest.raises(RuntimeError) as e: 2097 ds1.get_batch_size() 2098 assert "Nested cache operations" in str(e.value) 2099 2100 with pytest.raises(RuntimeError) as e: 2101 num_iter = 0 2102 for _ in ds1.create_dict_iterator(num_epochs=1): 2103 num_iter += 1 2104 assert "Nested cache operations" in str(e.value) 2105 2106 assert num_iter == 0 2107 logger.info('test_cache_nomap_failure1 Ended.\n') 2108 2109 2110@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2111def test_cache_nomap_failure2(): 2112 """ 2113 Feature: DatasetCache op 2114 Description: Test Zip under Cache 2115 2116 Repeat 2117 | 2118 Cache 2119 | 2120 Map(Decode) 2121 | 2122 Zip 2123 | | 2124 Random Random 2125 2126 Expectation: Error is raised as expected 2127 """ 2128 logger.info("Test cache nomap failure 2") 2129 if "SESSION_ID" in os.environ: 2130 session_id = int(os.environ['SESSION_ID']) 2131 else: 2132 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2133 2134 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2135 2136 schema = ds.Schema() 2137 schema.add_column('image', de_type=mstype.uint8, 2138 shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) 2139 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 2140 2141 ds1 = ds.RandomDataset(schema=schema) 2142 ds2 = ds.RandomDataset(schema=schema) 2143 dsz = ds.zip((ds1, ds2)) 2144 decode_op = c_vision.Decode() 2145 dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache) 2146 dsz = dsz.repeat(4) 2147 2148 with pytest.raises(RuntimeError) as e: 2149 num_iter = 0 2150 for _ in dsz.create_dict_iterator(num_epochs=1): 2151 num_iter += 1 2152 assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value) 2153 2154 assert num_iter == 0 2155 logger.info('test_cache_nomap_failure2 Ended.\n') 2156 2157 2158@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2159def test_cache_nomap_failure3(): 2160 """ 2161 Feature: DatasetCache op 2162 Description: Test Batch under Cache 2163 2164 Repeat 2165 | 2166 Cache 2167 | 2168 Map(Resize) 2169 | 2170 Batch 2171 | 2172 Clue 2173 2174 Expectation: Error is raised as expected 2175 """ 2176 logger.info("Test cache nomap failure 3") 2177 if "SESSION_ID" in os.environ: 2178 session_id = int(os.environ['SESSION_ID']) 2179 else: 2180 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2181 2182 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2183 2184 ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train') 2185 ds1 = ds1.batch(2) 2186 resize_op = c_vision.Resize((224, 224)) 2187 ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) 2188 ds1 = ds1.repeat(4) 2189 2190 with pytest.raises(RuntimeError) as e: 2191 num_iter = 0 2192 for _ in ds1.create_dict_iterator(num_epochs=1): 2193 num_iter += 1 2194 assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value) 2195 2196 assert num_iter == 0 2197 logger.info('test_cache_nomap_failure3 Ended.\n') 2198 2199 2200@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2201def test_cache_nomap_failure4(): 2202 """ 2203 Feature: DatasetCache op 2204 Description: Test Filter under Cache 2205 2206 Repeat 2207 | 2208 Cache 2209 | 2210 Map(Decode) 2211 | 2212 Filter 2213 | 2214 CSV 2215 2216 Expectation: Error is raised as expected 2217 """ 2218 logger.info("Test cache nomap failure 4") 2219 if "SESSION_ID" in os.environ: 2220 session_id = int(os.environ['SESSION_ID']) 2221 else: 2222 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2223 2224 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2225 2226 ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], 2227 column_names=['col1', 'col2', 'col3', 'col4']) 2228 ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) 2229 2230 decode_op = c_vision.Decode() 2231 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 2232 ds1 = ds1.repeat(4) 2233 2234 with pytest.raises(RuntimeError) as e: 2235 num_iter = 0 2236 for _ in ds1.create_dict_iterator(num_epochs=1): 2237 num_iter += 1 2238 assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value) 2239 2240 assert num_iter == 0 2241 logger.info('test_cache_nomap_failure4 Ended.\n') 2242 2243 2244@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2245def test_cache_nomap_failure5(): 2246 """ 2247 Feature: DatasetCache op 2248 Description: Test Map containing Random operation under Cache 2249 2250 Repeat 2251 | 2252 Cache 2253 | 2254 Map(Decode, RandomCrop) 2255 | 2256 TextFile 2257 2258 Expectation: Error is raised as expected 2259 """ 2260 logger.info("Test cache nomap failure 5") 2261 if "SESSION_ID" in os.environ: 2262 session_id = int(os.environ['SESSION_ID']) 2263 else: 2264 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2265 2266 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2267 2268 data = ds.TextFileDataset(TEXT_FILE_DATA_DIR) 2269 random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) 2270 decode_op = c_vision.Decode() 2271 2272 data = data.map(input_columns=["image"], operations=decode_op) 2273 data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache) 2274 data = data.repeat(4) 2275 2276 with pytest.raises(RuntimeError) as e: 2277 num_iter = 0 2278 for _ in data.create_dict_iterator(num_epochs=1): 2279 num_iter += 1 2280 assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) 2281 2282 assert num_iter == 0 2283 logger.info('test_cache_nomap_failure5 Ended.\n') 2284 2285 2286@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2287def test_cache_nomap_pyfunc_lambda(): 2288 """ 2289 Feature: DatasetCache op 2290 Description: Test cache after Map op with a Python lambda function 2291 2292 Cache 2293 | 2294 Map(lambda function1, lambda function2) 2295 | 2296 TFRecord 2297 2298 Expectation: Only success if the lambda function is wrapped by 'pyvision.not_random', otherwise error is raised 2299 """ 2300 logger.info("Test cache nomap pyfunc lambda") 2301 if "SESSION_ID" in os.environ: 2302 session_id = int(os.environ['SESSION_ID']) 2303 else: 2304 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2305 2306 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2307 2308 # This dataset has 12 records in it 2309 data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False) 2310 transforms = [vision.not_random(lambda x: x + x), vision.not_random(lambda x: x - 1)] 2311 data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache) 2312 2313 num_iter = 0 2314 for _ in data1.create_dict_iterator(num_epochs=1): 2315 num_iter += 1 2316 assert num_iter == 12 2317 2318 other_cache = ds.DatasetCache(session_id=session_id, size=0) 2319 ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False) 2320 ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache) 2321 2322 with pytest.raises(RuntimeError) as e: 2323 num_iter = 0 2324 for _ in ds2.create_dict_iterator(num_epochs=1): 2325 num_iter += 1 2326 assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) 2327 logger.info("test_cache_nomap_pyfunc_lambda Ended.\n") 2328 2329 2330@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2331def test_cache_nomap_pyfunc_builtin(): 2332 """ 2333 Feature: DatasetCache op 2334 Description: Test Cache after Map op with a Python builtin PyFunc 2335 2336 Cache 2337 | 2338 Map([builtin pyfunc1, builtin pyfunc2]) 2339 | 2340 TFRecord 2341 2342 Expectation: Error will be raised if the builtin PyFunc containing Random op, otherwise runs successfully 2343 """ 2344 logger.info("Test cache nomap pyfunc builtin") 2345 if "SESSION_ID" in os.environ: 2346 session_id = int(os.environ['SESSION_ID']) 2347 else: 2348 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2349 2350 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2351 # This dataset has 3 records in it only 2352 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) 2353 ds1 = ds1.map(operations=[vision.Decode(), vision.ToTensor()], input_columns=["image"], cache=some_cache) 2354 2355 num_iter = 0 2356 for _ in ds1.create_dict_iterator(num_epochs=1): 2357 num_iter += 1 2358 assert num_iter == 3 2359 2360 other_cache = ds.DatasetCache(session_id=session_id, size=0) 2361 # This dataset has 3 records in it only 2362 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) 2363 ds2 = ds2.map(operations=[vision.Decode(), vision.RandomCrop(224), vision.ToTensor()], 2364 input_columns=["image"], cache=other_cache) 2365 2366 with pytest.raises(RuntimeError) as e: 2367 num_iter = 0 2368 for _ in ds2.create_dict_iterator(num_epochs=1): 2369 num_iter += 1 2370 assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) 2371 logger.info("test_cache_nomap_pyfunc_builtin Ended.\n") 2372 2373 2374@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2375def test_cache_nomap_pyfunc_function(): 2376 """ 2377 Feature: DatasetCache op 2378 Description: Test Cache after Map op with a Python customized function 2379 2380 Cache 2381 | 2382 Map([function1, function2]) 2383 | 2384 TFRecord 2385 2386 Expectation: Only success if the function is decorated with 'vision.not_random', otherwise an error will be raised 2387 """ 2388 @vision.not_random 2389 def not_random_func(x): 2390 return np.ones(x.shape, dtype=x.dtype) 2391 2392 def normal_func(x): 2393 return np.ones(x.shape, dtype=x.dtype) 2394 2395 logger.info("Test cache nomap pyfunc function") 2396 if "SESSION_ID" in os.environ: 2397 session_id = int(os.environ['SESSION_ID']) 2398 else: 2399 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2400 2401 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2402 # This dataset has 3 records in it only 2403 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) 2404 ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache) 2405 2406 num_iter = 0 2407 for _ in ds1.create_dict_iterator(num_epochs=1): 2408 num_iter += 1 2409 assert num_iter == 3 2410 2411 other_cache = ds.DatasetCache(session_id=session_id, size=0) 2412 # This dataset has 3 records in it only 2413 ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) 2414 ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache) 2415 2416 with pytest.raises(RuntimeError) as e: 2417 num_iter = 0 2418 for _ in ds2.create_dict_iterator(num_epochs=1): 2419 num_iter += 1 2420 assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value) 2421 logger.info("test_cache_nomap_pyfunc_function Ended.\n") 2422 2423 2424@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2425def test_cache_nomap_all_rows_cached(): 2426 """ 2427 Feature: DatasetCache op 2428 Description: Test making sure all rows are cached before we switch to the fetching phase 2429 2430 Cache 2431 | 2432 RandomDataset 2433 2434 Expectation: Output is equal to the expected output 2435 """ 2436 logger.info("Test cache nomap all rows cached") 2437 if "SESSION_ID" in os.environ: 2438 session_id = int(os.environ['SESSION_ID']) 2439 else: 2440 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2441 2442 schema = ds.Schema() 2443 schema.add_column('image', de_type=mstype.uint8, 2444 shape=[450, 450, 3]) 2445 schema.add_column('label', de_type=mstype.uint8, shape=[1]) 2446 2447 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2448 2449 # easier to reproduce the problem with 271 total rows 2450 num_total_rows = 271 2451 # User-created sampler here 2452 ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache) 2453 iter1 = ds1.create_dict_iterator(num_epochs=1) 2454 2455 num_iter = 0 2456 for _ in iter1: 2457 num_iter += 1 2458 logger.info("Number of data in ds1: {} ".format(num_iter)) 2459 assert num_iter == num_total_rows 2460 2461 cache_stat = some_cache.get_stat() 2462 assert cache_stat.num_mem_cached == num_total_rows 2463 2464 logger.info("test_cache_nomap_all_rows_cached Ended.\n") 2465 2466 2467@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2468def test_cache_nomap_dataset_size1(): 2469 """ 2470 Feature: DatasetCache op 2471 Description: Test get_dataset_size when Cache is injected directly after a non-mappable leaf 2472 2473 Cache 2474 | 2475 TFRecord 2476 2477 Expectation: Output is equal to the expected output 2478 """ 2479 logger.info("Test cache nomap dataset size 1") 2480 if "SESSION_ID" in os.environ: 2481 session_id = int(os.environ['SESSION_ID']) 2482 else: 2483 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2484 2485 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2486 2487 # This dataset has 3 records in it only 2488 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0, cache=some_cache) 2489 2490 dataset_size = ds1.get_dataset_size() 2491 assert dataset_size == 2 2492 2493 num_iter = 0 2494 for _ in ds1.create_dict_iterator(num_epochs=1): 2495 num_iter += 1 2496 2497 logger.info("Number of data in ds1: {} ".format(num_iter)) 2498 assert num_iter == dataset_size 2499 logger.info("test_cache_nomap_dataset_size1 Ended.\n") 2500 2501 2502@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2503def test_cache_nomap_dataset_size2(): 2504 """ 2505 Feature: DatasetCache op 2506 Description: Test get_dataset_size when Cache is injected after Map 2507 2508 Cache 2509 | 2510 Map(Decode) 2511 | 2512 TFRecord 2513 2514 Expectation: Output is equal to the expected output 2515 """ 2516 logger.info("Test cache nomap dataset size 2") 2517 if "SESSION_ID" in os.environ: 2518 session_id = int(os.environ['SESSION_ID']) 2519 else: 2520 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2521 2522 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2523 2524 # This dataset has 3 records in it only 2525 ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0) 2526 decode_op = c_vision.Decode() 2527 ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) 2528 2529 dataset_size = ds1.get_dataset_size() 2530 assert dataset_size == 2 2531 2532 num_iter = 0 2533 for _ in ds1.create_dict_iterator(num_epochs=1): 2534 num_iter += 1 2535 2536 logger.info("Number of data in ds1: {} ".format(num_iter)) 2537 assert num_iter == dataset_size 2538 logger.info("test_cache_nomap_dataset_size2 Ended.\n") 2539 2540 2541if __name__ == '__main__': 2542 # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' 2543 # since cache server is required to be brought up first 2544 test_cache_nomap_basic1() 2545 test_cache_nomap_basic2() 2546 test_cache_nomap_basic3() 2547 test_cache_nomap_basic4() 2548 test_cache_nomap_basic5() 2549 test_cache_nomap_basic6() 2550 test_cache_nomap_basic7() 2551 test_cache_nomap_basic8() 2552 test_cache_nomap_basic9() 2553 test_cache_nomap_allowed_share1() 2554 test_cache_nomap_allowed_share2() 2555 test_cache_nomap_allowed_share3() 2556 test_cache_nomap_allowed_share4() 2557 test_cache_nomap_disallowed_share1() 2558 test_cache_nomap_running_twice1() 2559 test_cache_nomap_running_twice2() 2560 test_cache_nomap_extra_small_size1() 2561 test_cache_nomap_extra_small_size2() 2562 test_cache_nomap_parallel_pipeline1(shard=0) 2563 test_cache_nomap_parallel_pipeline2(shard=1) 2564 test_cache_nomap_parallel_workers() 2565 test_cache_nomap_server_workers_1() 2566 test_cache_nomap_server_workers_100() 2567 test_cache_nomap_num_connections_1() 2568 test_cache_nomap_num_connections_100() 2569 test_cache_nomap_prefetch_size_1() 2570 test_cache_nomap_prefetch_size_100() 2571 test_cache_nomap_device_que() 2572 test_cache_nomap_session_destroy() 2573 test_cache_nomap_server_stop() 2574 test_cache_nomap_epoch_ctrl1() 2575 test_cache_nomap_epoch_ctrl2() 2576 test_cache_nomap_epoch_ctrl3() 2577 test_cache_nomap_epoch_ctrl4() 2578 test_cache_nomap_multiple_cache1() 2579 test_cache_nomap_multiple_cache2() 2580 test_cache_nomap_multiple_cache3() 2581 test_cache_nomap_multiple_cache_train() 2582 test_cache_nomap_multiple_cache_eval() 2583 test_cache_nomap_clue1() 2584 test_cache_nomap_clue2() 2585 test_cache_nomap_csv1() 2586 test_cache_nomap_csv2() 2587 test_cache_nomap_textfile1() 2588 test_cache_nomap_textfile2() 2589 test_cache_nomap_nested_repeat() 2590 test_cache_nomap_get_repeat_count() 2591 test_cache_nomap_long_file_list() 2592 test_cache_nomap_failure1() 2593 test_cache_nomap_failure2() 2594 test_cache_nomap_failure3() 2595 test_cache_nomap_failure4() 2596 test_cache_nomap_failure5() 2597 test_cache_nomap_pyfunc_lambda() 2598 test_cache_nomap_pyfunc_builtin() 2599 test_cache_nomap_pyfunc_function() 2600 test_cache_nomap_dataset_size1() 2601 test_cache_nomap_dataset_size2() 2602