# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger


def gen():
    for i in range(100):
        yield (np.array(i),)
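

# Helper used as the sync_wait callback target throughout this file: sync_update
# forwards its `data` dict to `update`, which stores the reported loss, while
# `preprocess` is an identity map so the pipeline content stays inspectable.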
class Augment:
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: basic sync_wait/sync_update handshake in a dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
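    # sync_wait() inserts a blocking barrier keyed by condition_name: the
    # pipeline stalls until sync_update() is called on the same condition
    # (here, once per batch), and the data passed to sync_update() is
    # forwarded to the registered callback, aug.update.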
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert data["input"][0] == count
        count += batch_size
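        # Release the barrier for the next batch and route the current loss to
        # the callback; without this call the iterator would block here.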
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: shuffle placed before sync_wait in the pipeline
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
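    # A shuffle is legal upstream of sync_wait; adding one downstream of the
    # barrier raises a RuntimeError (see test_sync_exception_01 below).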
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators
    """
    logger.info("test_two_sync")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # first barrier: released once per batch via its callback
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
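
    # second, independent barrier: num_batch=2 lets two batches through per
    # release, so a sync_update is only needed on every other iteration
    # (see the loop below)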
    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")

    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: sync_wait across multiple epochs of the pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size, drop_remainder=True)
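
    # Three passes over the same pipeline: the loss is re-seeded before each
    # epoch, and sync_update must still release every batch within each epoch.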
    for _ in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert data["input"][0] == count
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: two pipelines, each with its own
    sync_wait barrier, iterated in lock-step
    """
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # 2nd dataset
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
    dataset2 = dataset2.map(operations=[aug.preprocess], input_columns=["input"])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)
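
    # Iterate both pipelines in lock-step; the two "policy" conditions are
    # independent (one per pipeline), so each barrier is released separately.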
    for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)


def test_sync_exception_01():
    """
    Test sync: adding a shuffle after sync_wait raises a RuntimeError
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
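
    # Likely because a shuffle could reorder rows across the barrier, the API
    # rejects any shuffle placed downstream of a sync_wait operator.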
    with pytest.raises(RuntimeError) as e:
        dataset.shuffle(shuffle_size)
    assert "No shuffle after sync operators" in str(e.value)


def test_sync_exception_02():
    """
    Test sync: reusing a condition name for a second sync_wait raises a RuntimeError
    """
    logger.info("test_sync_exception_02")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])

    with pytest.raises(RuntimeError) as e:
        dataset.sync_wait(num_batch=2, condition_name="every batch")
    assert "Condition name is already in use" in str(e.value)


def test_sync_exception_03():
    """
    Test sync: a negative num_batch in sync_wait raises a ValueError
    """
    logger.info("test_sync_exception_03")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # try to create a sync_wait barrier with num_batch < 0
    with pytest.raises(ValueError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
    assert "num_batch need to be greater than 0." in str(e.value)


def test_sync_exception_04():
    """
    Test sync: a negative num_batch in sync_update raises a RuntimeError
    """
    logger.info("test_sync_exception_04")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    count = 0
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
            data = {"loss": count}
            # try to release the barrier with num_batch < 0
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    assert "Sync_update batch size can only be positive" in str(e.value)


def test_sync_exception_05():
    """
    Test sync: sync_update with an unknown condition name raises a RuntimeError
    """
    logger.info("test_sync_exception_05")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    count = 0
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
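    # disable_sync() switches the sync barriers off so iteration can proceed
    # without releases; a sync_update against a condition name that was never
    # registered still raises.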
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            dataset.disable_sync()
            count += 1
            data = {"loss": count}
            dataset.sync_update(condition_name="every", data=data)
    assert "Condition name not found" in str(e.value)


def test_simple_sync_wait_empty_condition_name():
    """ Test sync_wait and sync_update with an empty condition_name ('') and no callback """
    logger.info("test_simple_sync_wait_empty_condition_name")
    batch_size = 10
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name='', num_batch=1)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="", data=data)


def test_sync_exception_06():
    """
    Test sync: a string num_batch in sync_wait raises a TypeError
    """
    logger.info("test_sync_exception_06")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # try to create a sync_wait barrier with a non-integer num_batch
    with pytest.raises(TypeError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update)
    assert "is not of type [<class 'int'>]" in str(e.value)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_exception_03()
    test_sync_exception_04()
    test_sync_exception_05()
    test_sync_exception_06()
    test_sync_epoch()
    test_multiple_iterators()
    test_simple_sync_wait_empty_condition_name()