# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from contextlib import contextmanager

import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict

DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
GENERATE_GOLDEN = False


@contextmanager
def _global_seed(seed):
    """
    Temporarily set the global dataset seed, restoring the previous value on exit.

    ds.config.set_seed mutates process-wide configuration; without restoring it,
    every test that runs after a shuffle test in the same process would inherit
    this seed. The pipeline must be fully built and iterated inside the `with`
    body so the temporary seed is in effect while the shuffle op runs.

    Args:
        seed (int): seed to install for the duration of the context.
    """
    original_seed = ds.config.get_seed()
    ds.config.set_seed(seed)
    try:
        yield
    finally:
        ds.config.set_seed(original_seed)


def test_2ops_repeat_shuffle():
    """
    Test Repeat then Shuffle: repeat the dataset, then shuffle the repeated
    stream with a fixed seed, and compare against the golden .npz file.
    """
    logger.info("Test Repeat then Shuffle")
    # define parameters
    repeat_count = 2
    buffer_size = 5
    seed = 0

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    # Keep the seed in effect through iteration in save_and_check_dict.
    with _global_seed(seed):
        data1 = data1.shuffle(buffer_size=buffer_size)

        filename = "test_2ops_repeat_shuffle.npz"
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_2ops_shuffle_repeat():
    """
    Test Shuffle then Repeat: shuffle the dataset with a fixed seed, then
    repeat it, and compare against the golden .npz file.
    """
    logger.info("Test Shuffle then Repeat")
    # define parameters
    repeat_count = 2
    buffer_size = 5
    seed = 0

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    # Keep the seed in effect through iteration in save_and_check_dict.
    with _global_seed(seed):
        data1 = data1.shuffle(buffer_size=buffer_size)
        data1 = data1.repeat(repeat_count)

        filename = "test_2ops_shuffle_repeat.npz"
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_2ops_repeat_batch():
    """
    Test Repeat then Batch: repeat the dataset, then batch it (dropping any
    incomplete final batch), and compare against the golden .npz file.
    """
    logger.info("Test Repeat then Batch")
    # define parameters
    repeat_count = 2
    batch_size = 5

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    data1 = data1.batch(batch_size, drop_remainder=True)

    filename = "test_2ops_repeat_batch.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_2ops_batch_repeat():
    """
    Test Batch then Repeat: batch the dataset (dropping any incomplete final
    batch), then repeat it, and compare against the golden .npz file.
    """
    logger.info("Test Batch then Repeat")
    # define parameters
    repeat_count = 2
    batch_size = 5

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.batch(batch_size, drop_remainder=True)
    data1 = data1.repeat(repeat_count)

    filename = "test_2ops_batch_repeat.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_2ops_batch_shuffle():
    """
    Test Batch then Shuffle: batch the dataset (dropping any incomplete final
    batch), then shuffle the batches with a fixed seed, and compare against
    the golden .npz file.
    """
    logger.info("Test Batch then Shuffle")
    # define parameters
    buffer_size = 5
    seed = 0
    batch_size = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.batch(batch_size, drop_remainder=True)
    # Keep the seed in effect through iteration in save_and_check_dict.
    with _global_seed(seed):
        data1 = data1.shuffle(buffer_size=buffer_size)

        filename = "test_2ops_batch_shuffle.npz"
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_2ops_shuffle_batch():
    """
    Test Shuffle then Batch: shuffle the dataset with a fixed seed, then batch
    it (dropping any incomplete final batch), and compare against the golden
    .npz file.
    """
    logger.info("Test Shuffle then Batch")
    # define parameters
    buffer_size = 5
    seed = 0
    batch_size = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    # Keep the seed in effect through iteration in save_and_check_dict.
    with _global_seed(seed):
        data1 = data1.shuffle(buffer_size=buffer_size)
        data1 = data1.batch(batch_size, drop_remainder=True)

        filename = "test_2ops_shuffle_batch.npz"
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


if __name__ == '__main__':
    test_2ops_repeat_shuffle()
    test_2ops_shuffle_repeat()
    test_2ops_repeat_batch()
    test_2ops_batch_repeat()
    test_2ops_batch_shuffle()
    test_2ops_shuffle_batch()