• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import mindspore.dataset as ds
16from mindspore import log as logger
17from util import save_and_check_dict
18
# Directory holding the TFRecord test data and its schema. These are relative
# paths, presumably resolved against the directory the tests are launched from.
_TF_ALL_TYPES_DIR = "../data/dataset/testTFTestAllTypes"
DATA_DIR = [_TF_ALL_TYPES_DIR + "/test.data"]
SCHEMA_DIR = _TF_ALL_TYPES_DIR + "/datasetSchema.json"
# Forwarded as generate_golden= to save_and_check_dict; presumably True
# regenerates the golden .npz files instead of checking them (confirm in util).
GENERATE_GOLDEN = False
22
23
def test_2ops_repeat_shuffle():
    """
    Test Repeat then Shuffle.

    Builds TFRecordDataset (unshuffled) -> repeat(2) -> shuffle(buffer_size=5)
    with the global dataset seed fixed to 0, then hands the pipeline to
    save_and_check_dict with the golden file "test_2ops_repeat_shuffle.npz".
    The previous global seed is restored afterwards so it does not leak into
    other tests.
    """
    logger.info("Test Repeat then Shuffle")
    # define parameters
    repeat_count = 2
    buffer_size = 5
    seed = 0

    # ds.config.set_seed changes process-wide state; remember the current value
    # so it can be put back whether or not the check below succeeds.
    original_seed = ds.config.get_seed()
    try:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.repeat(repeat_count)
        ds.config.set_seed(seed)
        data1 = data1.shuffle(buffer_size=buffer_size)

        filename = "test_2ops_repeat_shuffle.npz"
        # Iteration happens inside this call, so the seed must still be set here.
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        ds.config.set_seed(original_seed)
42
43
def test_2ops_shuffle_repeat():
    """
    Test Shuffle then Repeat.

    Builds TFRecordDataset (unshuffled) -> shuffle(buffer_size=5) -> repeat(2)
    with the global dataset seed fixed to 0, then hands the pipeline to
    save_and_check_dict with the golden file "test_2ops_shuffle_repeat.npz".
    The previous global seed is restored afterwards so it does not leak into
    other tests.
    """
    logger.info("Test Shuffle then Repeat")
    # define parameters
    repeat_count = 2
    buffer_size = 5
    seed = 0

    # ds.config.set_seed changes process-wide state; remember the current value
    # so it can be put back whether or not the check below succeeds.
    original_seed = ds.config.get_seed()
    try:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        ds.config.set_seed(seed)
        data1 = data1.shuffle(buffer_size=buffer_size)
        data1 = data1.repeat(repeat_count)

        filename = "test_2ops_shuffle_repeat.npz"
        # Iteration happens inside this call, so the seed must still be set here.
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        ds.config.set_seed(original_seed)
62
63
def test_2ops_repeat_batch():
    """
    Check the pipeline TFRecordDataset -> repeat -> batch.

    The unshuffled dataset is repeated twice and then grouped into batches of
    five with drop_remainder=True. The resulting pipeline is handed to
    save_and_check_dict together with the golden file name
    "test_2ops_repeat_batch.npz".
    """
    logger.info("Test Repeat then Batch")
    times, per_batch = 2, 5

    pipeline = (
        ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        .repeat(times)
        .batch(per_batch, drop_remainder=True)
    )

    save_and_check_dict(pipeline, "test_2ops_repeat_batch.npz",
                        generate_golden=GENERATE_GOLDEN)
80
81
def test_2ops_batch_repeat():
    """
    Check the pipeline TFRecordDataset -> batch -> repeat.

    The unshuffled dataset is grouped into batches of five with
    drop_remainder=True, and the batched stream is then repeated twice. The
    resulting pipeline is handed to save_and_check_dict together with the
    golden file name "test_2ops_batch_repeat.npz".
    """
    logger.info("Test Batch then Repeat")
    times, per_batch = 2, 5

    pipeline = (
        ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        .batch(per_batch, drop_remainder=True)
        .repeat(times)
    )

    save_and_check_dict(pipeline, "test_2ops_batch_repeat.npz",
                        generate_golden=GENERATE_GOLDEN)
98
99
def test_2ops_batch_shuffle():
    """
    Test Batch then Shuffle.

    Builds TFRecordDataset (unshuffled) -> batch(2, drop_remainder=True) ->
    shuffle(buffer_size=5) with the global dataset seed fixed to 0, then hands
    the pipeline to save_and_check_dict with the golden file
    "test_2ops_batch_shuffle.npz". The previous global seed is restored
    afterwards so it does not leak into other tests.
    """
    logger.info("Test Batch then Shuffle")
    # define parameters
    buffer_size = 5
    seed = 0
    batch_size = 2

    # ds.config.set_seed changes process-wide state; remember the current value
    # so it can be put back whether or not the check below succeeds.
    original_seed = ds.config.get_seed()
    try:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.batch(batch_size, drop_remainder=True)
        ds.config.set_seed(seed)
        data1 = data1.shuffle(buffer_size=buffer_size)

        filename = "test_2ops_batch_shuffle.npz"
        # Iteration happens inside this call, so the seed must still be set here.
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        ds.config.set_seed(original_seed)
118
119
def test_2ops_shuffle_batch():
    """
    Test Shuffle then Batch.

    Builds TFRecordDataset (unshuffled) -> shuffle(buffer_size=5) ->
    batch(2, drop_remainder=True) with the global dataset seed fixed to 0, then
    hands the pipeline to save_and_check_dict with the golden file
    "test_2ops_shuffle_batch.npz". The previous global seed is restored
    afterwards so it does not leak into other tests.
    """
    logger.info("Test Shuffle then Batch")
    # define parameters
    buffer_size = 5
    seed = 0
    batch_size = 2

    # ds.config.set_seed changes process-wide state; remember the current value
    # so it can be put back whether or not the check below succeeds.
    original_seed = ds.config.get_seed()
    try:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        ds.config.set_seed(seed)
        data1 = data1.shuffle(buffer_size=buffer_size)
        data1 = data1.batch(batch_size, drop_remainder=True)

        filename = "test_2ops_shuffle_batch.npz"
        # Iteration happens inside this call, so the seed must still be set here.
        save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        ds.config.set_seed(original_seed)
138
139
if __name__ == '__main__':
    # When executed directly, run each test in the order it is defined above.
    for _case in (
            test_2ops_repeat_shuffle,
            test_2ops_shuffle_repeat,
            test_2ops_repeat_batch,
            test_2ops_batch_repeat,
            test_2ops_batch_shuffle,
            test_2ops_shuffle_batch,
    ):
        _case()
147