• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Test Repeat Op
17"""
18import numpy as np
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.vision.c_transforms as vision
22from mindspore import log as logger
23from util import save_and_check_dict
24
# Test data covering many column types (per the directory name) — used by the
# plain repeat tests.
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

# Small image TFRecord dataset (3 images, per test asserts below) — used by the
# decode/resize/batch pipelines.
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# When True, save_and_check_dict regenerates the golden .npz files instead of
# comparing against them.
GENERATE_GOLDEN = False
32
33
def test_tf_repeat_01():
    """
    Test a simple repeat operation.
    """
    logger.info("Test Simple Repeat")
    # Repeat the whole dataset twice and compare the output rows against the
    # stored golden file.
    dataset = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    dataset = dataset.repeat(2)

    save_and_check_dict(dataset, "repeat_result.npz", generate_golden=GENERATE_GOLDEN)
48
49
def test_tf_repeat_02():
    """
    Test Infinite Repeat.
    """
    logger.info("Test Infinite Repeat")
    # repeat(-1) repeats indefinitely, so pull a fixed number of rows and stop.
    dataset = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    dataset = dataset.repeat(-1)

    rows_seen = 0
    for _ in dataset:
        rows_seen += 1
        if rows_seen == 100:
            break
    assert rows_seen == 100
68
69
def test_tf_repeat_03():
    """
    Test Repeat then Batch.
    """
    logger.info("Test Repeat then Batch")
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)

    # Decode and resize each image, repeat 22x, then batch by 32 with
    # drop_remainder=True so only full batches come out.
    pipeline_ops = [vision.Decode(),
                    vision.Resize((32, 32), interpolation=ds.transforms.vision.Inter.LINEAR)]
    for op in pipeline_ops:
        data1 = data1.map(operations=op, input_columns=["image"])
    data1 = data1.repeat(22)
    data1 = data1.batch(32, drop_remainder=True)

    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    logger.info("Number of tf data in data1: {}".format(num_iter))
    assert num_iter == 2
91
92
def test_tf_repeat_04():
    """
    Test a simple repeat operation with column list.
    """
    logger.info("Test Simple Repeat Column List")
    # Restrict the reader to two columns, repeat twice, and compare against
    # the golden file.
    dataset = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF,
                                 columns_list=["col_sint64", "col_sint32"],
                                 shuffle=False)
    dataset = dataset.repeat(2)

    save_and_check_dict(dataset, "repeat_list_result.npz", generate_golden=GENERATE_GOLDEN)
107
108
def generator():
    """Yield three rows; each row is a 1-tuple holding a one-element array."""
    for value in range(3):
        yield (np.array([value]),)
112
113
def test_nested_repeat1():
    """Nested repeat(2) then repeat(3): counts multiply, rows keep cycling 0..2."""
    logger.info("test_nested_repeat1")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2).repeat(3)

    for idx, row in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in data) == 2 * 3 * 3
124
125
def test_nested_repeat2():
    """Two repeat(1) ops are no-ops: the 3 rows pass through exactly once."""
    logger.info("test_nested_repeat2")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1).repeat(1)

    for idx, row in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in data) == 3
136
137
def test_nested_repeat3():
    """repeat(1) nested with repeat(2) yields 2 epochs of the 3 rows."""
    logger.info("test_nested_repeat3")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1).repeat(2)

    for idx, row in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in data) == 2 * 3
148
149
def test_nested_repeat4():
    """repeat(2) nested with repeat(1) yields 2 epochs of the 3 rows."""
    logger.info("test_nested_repeat4")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2).repeat(1)

    for idx, row in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in data) == 2 * 3
160
161
def test_nested_repeat5():
    """Batch before nested repeats: every batch is the full [[0], [1], [2]]."""
    logger.info("test_nested_repeat5")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(3)
    data = data.repeat(2).repeat(3)

    expected = np.asarray([[0], [1], [2]])
    for batch in data:
        np.testing.assert_array_equal(batch[0].asnumpy(), expected)

    assert sum(1 for _ in data) == 6
173
174
def test_nested_repeat6():
    """Batch between the repeats: every batch is still the full [[0], [1], [2]]."""
    logger.info("test_nested_repeat6")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.batch(3)
    data = data.repeat(3)

    expected = np.asarray([[0], [1], [2]])
    for batch in data:
        np.testing.assert_array_equal(batch[0].asnumpy(), expected)

    assert sum(1 for _ in data) == 6
186
187
def test_nested_repeat7():
    """Batch after both repeats: every batch is still the full [[0], [1], [2]]."""
    logger.info("test_nested_repeat7")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2).repeat(3)
    data = data.batch(3)

    expected = np.asarray([[0], [1], [2]])
    for batch in data:
        np.testing.assert_array_equal(batch[0].asnumpy(), expected)

    assert sum(1 for _ in data) == 6
199
200
def test_nested_repeat8():
    """Uneven batches (size 2 then 1) survive nested repeats, alternating in order."""
    logger.info("test_nested_repeat8")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(2, drop_remainder=False)
    data = data.repeat(2).repeat(3)

    full_batch = np.asarray([[0], [1]])
    tail_batch = np.asarray([[2]])
    for idx, batch in enumerate(data):
        expected = full_batch if idx % 2 == 0 else tail_batch
        np.testing.assert_array_equal(batch[0].asnumpy(), expected)

    assert sum(1 for _ in data) == 6 * 2
215
216
def test_nested_repeat9():
    """An infinite repeat nested under a finite one still cycles 0..2 forever."""
    logger.info("test_nested_repeat9")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat().repeat(3)

    # Stream is endless; sample the first few rows then bail out.
    for idx, row in enumerate(data):
        assert idx % 3 == row[0].asnumpy()[0]
        if idx == 10:
            break
227
228
def test_nested_repeat10():
    """A finite repeat nested under an infinite one still cycles 0..2 forever."""
    logger.info("test_nested_repeat10")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(3).repeat()

    # Stream is endless; sample the first few rows then bail out.
    for idx, row in enumerate(data):
        assert idx % 3 == row[0].asnumpy()[0]
        if idx == 10:
            break
239
240
def test_nested_repeat11():
    """Four stacked repeats multiply: 2 * 3 * 4 * 5 epochs of the 3 rows."""
    logger.info("test_nested_repeat11")
    data = ds.GeneratorDataset(generator, ["data"])
    for count in (2, 3, 4, 5):
        data = data.repeat(count)

    for idx, row in enumerate(data):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in data) == 2 * 3 * 4 * 5 * 3
253
254
def test_repeat_count1():
    """Repeat(4) before batch(2): 3 rows -> 12 rows -> 6 batches, and
    get_dataset_size agrees with the iterated count."""
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))

    data1 = data1.map(operations=vision.Decode(), input_columns=["image"])
    resize_op = vision.Resize((32, 32), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.repeat(4)
    data1 = data1.batch(2, drop_remainder=False)

    dataset_size = data1.get_dataset_size()
    logger.info("dataset repeat then batch's size is {}".format(dataset_size))
    num1_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))

    assert data1_size == 3
    assert dataset_size == num1_iter == 6
276
277
def test_repeat_count2():
    """Batch(2) before repeat(4): 3 rows -> 2 batches -> 8 batches, and
    get_dataset_size agrees with the iterated count."""
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))

    data1 = data1.map(operations=vision.Decode(), input_columns=["image"])
    resize_op = vision.Resize((32, 32), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.batch(2, drop_remainder=False)
    data1 = data1.repeat(4)

    dataset_size = data1.get_dataset_size()
    logger.info("dataset batch then repeat's size is {}".format(dataset_size))
    num1_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))

    assert data1_size == 3
    assert dataset_size == num1_iter == 8
299
300
def test_repeat_count0():
    """
    Test Repeat with invalid count 0.

    repeat(0) must raise a ValueError whose message mentions "count".
    """
    logger.info("Test Repeat with invalid count 0")
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    # Keep only the statement under test inside pytest.raises: a ValueError
    # from dataset construction must not be mistaken for the expected error.
    with pytest.raises(ValueError) as info:
        data1.repeat(0)
    assert "count" in str(info.value)
310
311
def test_repeat_countneg2():
    """
    Test Repeat with invalid count -2.

    repeat(-2) must raise a ValueError whose message mentions "count"
    (only -1 is the valid "infinite" sentinel).
    """
    logger.info("Test Repeat with invalid count -2")
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    # Keep only the statement under test inside pytest.raises: a ValueError
    # from dataset construction must not be mistaken for the expected error.
    with pytest.raises(ValueError) as info:
        data1.repeat(-2)
    assert "count" in str(info.value)
321
322
if __name__ == "__main__":
    # Allow running every test in this file directly, without pytest.
    test_tf_repeat_01()
    test_tf_repeat_02()
    test_tf_repeat_03()
    test_tf_repeat_04()
    test_nested_repeat1()
    test_nested_repeat2()
    test_nested_repeat3()
    test_nested_repeat4()
    test_nested_repeat5()
    test_nested_repeat6()
    test_nested_repeat7()
    test_nested_repeat8()
    test_nested_repeat9()
    test_nested_repeat10()
    test_nested_repeat11()
    test_repeat_count1()
    test_repeat_count2()
    test_repeat_count0()
    test_repeat_countneg2()
343