• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import mindspore.dataset as ds
17from mindspore import log as logger
18from util import save_and_check_dict
19
# TFRecord fixture read by every test in this file; it contains 12 rows.
DATA_DIR = [
    "../data/dataset/testTFTestAllTypes/test.data",
]
# Forwarded as generate_golden to save_and_check_dict; presumably True
# regenerates the golden .npz files instead of comparing against them.
GENERATE_GOLDEN = False
23
24
def test_shuffle_01():
    """
    Test shuffle with a buffer (5) smaller than the dataset (12 rows).

    The seeded, shuffled pipeline output is checked against the golden
    file shuffle_01_result.npz.
    """
    logger.info("test_shuffle_01")
    shuffle_buffer, shuffle_seed = 5, 1
    golden_file = "shuffle_01_result.npz"

    # Read with Shuffle.FILES, fix the global seed, then add the row shuffle.
    dataset = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(shuffle_seed)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    save_and_check_dict(dataset, golden_file, generate_golden=GENERATE_GOLDEN)
41
42
def test_shuffle_02():
    """
    Test shuffle with a buffer exactly as large as the dataset (12 rows).

    Output is compared with the golden file shuffle_02_result.npz.
    """
    logger.info("test_shuffle_02")
    rows_in_buffer = 12
    rng_seed = 1

    pipeline = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    # Seed is fixed after building the reader and before adding shuffle().
    ds.config.set_seed(rng_seed)
    pipeline = pipeline.shuffle(buffer_size=rows_in_buffer)

    save_and_check_dict(pipeline, "shuffle_02_result.npz",
                        generate_golden=GENERATE_GOLDEN)
59
60
def test_shuffle_03():
    """
    Test shuffle with the minimum buffer size (2) on a dataset with more than 2 rows.

    buffer_size is passed positionally here. Output is compared with
    shuffle_03_result.npz.
    """
    logger.info("test_shuffle_03")
    min_buffer = 2
    fixed_seed = 1
    expected_file = "shuffle_03_result.npz"

    source = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(fixed_seed)
    source = source.shuffle(min_buffer)

    save_and_check_dict(source, expected_file, generate_golden=GENERATE_GOLDEN)
77
78
def test_shuffle_04():
    """
    Test shuffle where buffer size and row count are both 2.

    The reader is limited to 2 samples (num_samples=2) and uses its default
    shuffle setting. Output is compared with shuffle_04_result.npz.
    """
    logger.info("test_shuffle_04")
    two = 2
    seed_value = 1

    limited = ds.TFRecordDataset(DATA_DIR, num_samples=two)
    ds.config.set_seed(seed_value)
    limited = limited.shuffle(buffer_size=two)

    save_and_check_dict(limited, "shuffle_04_result.npz", generate_golden=GENERATE_GOLDEN)
95
96
def test_shuffle_05():
    """
    Test shuffle with a buffer (13) larger than the dataset (12 rows).

    Output is compared with the golden file shuffle_05_result.npz.
    """
    logger.info("test_shuffle_05")
    oversized_buffer = 13
    golden_name = "shuffle_05_result.npz"

    data = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(1)
    data = data.shuffle(buffer_size=oversized_buffer)

    save_and_check_dict(data, golden_name, generate_golden=GENERATE_GOLDEN)
113
114
def test_shuffle_06():
    """
    Test shuffle: with set seed, both datasets

    Two identically built pipelines, created after the same global seed is
    set, must yield the same rows in the same order.

    The number of compared rows is also checked. zip() stops at the shorter
    iterator, so without this check an empty or truncated pipeline would
    pass vacuously.
    """
    logger.info("test_shuffle_06")
    # define parameters
    buffer_size = 13
    seed = 1
    expected_rows = 12  # test.data contains 12 rows

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    data2 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(buffer_size=buffer_size)

    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_equal(item1, item2)
        num_iter += 1
    # Shuffling must neither drop nor duplicate rows.
    assert num_iter == expected_rows
135
136
def test_shuffle_exception_01():
    """
    Test shuffle exception: buffer_size<0

    shuffle(buffer_size=-1) must raise an error whose message states the
    valid interval [2, 2147483647]. The test fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_01")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(buffer_size=-1)
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle(buffer_size=-1) did not raise an exception")
153
154
def test_shuffle_exception_02():
    """
    Test shuffle exception: buffer_size=0

    shuffle(buffer_size=0) must raise an error whose message states the
    valid interval [2, 2147483647]. The test fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_02")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(buffer_size=0)
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle(buffer_size=0) did not raise an exception")
171
172
def test_shuffle_exception_03():
    """
    Test shuffle exception: buffer_size=1

    shuffle(buffer_size=1) is just below the minimum of 2 and must raise an
    error whose message states the valid interval [2, 2147483647]. The test
    fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_03")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(buffer_size=1)
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle(buffer_size=1) did not raise an exception")
189
190
def test_shuffle_exception_05():
    """
    Test shuffle exception: Missing mandatory buffer_size input parameter

    Calling shuffle() with no arguments must raise an error whose message
    mentions buffer_size. The test fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_05")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle()
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "buffer_size" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle() without buffer_size did not raise an exception")
207
208
def test_shuffle_exception_06():
    """
    Test shuffle exception: buffer_size wrong type, boolean value False

    shuffle(buffer_size=False) must raise an error whose message mentions
    buffer_size. The test fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_06")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(buffer_size=False)
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "buffer_size" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle(buffer_size=False) did not raise an exception")
225
226
def test_shuffle_exception_07():
    """
    Test shuffle exception: buffer_size wrong type, boolean value True

    shuffle(buffer_size=True) must raise an error whose message mentions
    buffer_size. The test fails if nothing is raised.
    """
    logger.info("test_shuffle_exception_07")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(buffer_size=True)
        # Iterating runs the pipeline, so an error deferred to execution is caught too.
        sum([1 for _ in data1])
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "buffer_size" in str(e)
    else:
        # Without this branch the test silently passed when no error was raised.
        raise AssertionError("shuffle(buffer_size=True) did not raise an exception")
243
244
if __name__ == '__main__':
    # Run every test in declaration order when executed as a script.
    _ALL_TESTS = (
        test_shuffle_01,
        test_shuffle_02,
        test_shuffle_03,
        test_shuffle_04,
        test_shuffle_05,
        test_shuffle_06,
        test_shuffle_exception_01,
        test_shuffle_exception_02,
        test_shuffle_exception_03,
        test_shuffle_exception_05,
        test_shuffle_exception_06,
        test_shuffle_exception_07,
    )
    for run_test in _ALL_TESTS:
        run_test()
    logger.info('\n')
259