• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from builtins import range, super
16import time
17
18import pytest
19
20from mindspore import context
21from mindspore import log as logger
22from mindspore.dataset.callback import DSCallback, WaitedDSCallback
23from mindspore.train import Model
24from mindspore.train.callback import Callback
25
26import mindspore.dataset as ds
27import mindspore.nn as nn
28
# Every test in this module runs in graph mode on the CPU backend.
context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
30
31
class BaseCallback(DSCallback):
    """Dataset callback base class that logs each notification into a list.

    Args:
        step_size (int): Forwarded unchanged to ``DSCallback.__init__``.
        events (list, optional): Log that receives ``(event_key, [cb_id, ...])``
            tuples. Several callbacks may share one list. When omitted, a
            private empty list is created so ``append`` is always usable.
        cb_id (int): Identifier stored next to every event this callback logs.
    """

    def __init__(self, step_size=1, events=None, cb_id=0):
        super().__init__(step_size)
        # Test against None explicitly: an empty list handed in by the caller
        # is falsy but must remain the shared log object.
        self.events = [] if events is None else events
        self.cb_id = cb_id

    def append(self, event_name, ds_run_context):
        """Log ``event_name`` together with the counters read from ``ds_run_context``.

        The key has the form ``<name>_<epoch>_<step in epoch>_<global step>``.
        If the key is already present in ``self.events`` this callback's id is
        added to that entry, otherwise a new ``(key, [cb_id])`` entry is appended.
        """
        fields = (event_name, ds_run_context.cur_epoch_num,
                  ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num)
        key = '_'.join(str(field) for field in fields)
        for recorded_key, cb_ids in self.events:
            if recorded_key == key:
                # Only the first matching entry is extended.
                cb_ids.append(self.cb_id)
                return
        self.events.append((key, [self.cb_id]))
51
52
class Begin(BaseCallback):
    """Callback that logs a "begin" event from ``ds_begin``."""

    def ds_begin(self, ds_run_context):
        self.append(event_name="begin", ds_run_context=ds_run_context)
56
57
class EpochBegin(BaseCallback):
    """Callback that logs an "epoch_begin" event from ``ds_epoch_begin``."""

    def ds_epoch_begin(self, ds_run_context):
        self.append(event_name="epoch_begin", ds_run_context=ds_run_context)
61
62
class EpochEnd(BaseCallback):
    """Callback that logs an "epoch_end" event from ``ds_epoch_end``."""

    def ds_epoch_end(self, ds_run_context):
        self.append(event_name="epoch_end", ds_run_context=ds_run_context)
66
67
class StepBegin(BaseCallback):
    """Callback that logs a "step_begin" event from ``ds_step_begin``."""

    def ds_step_begin(self, ds_run_context):
        self.append(event_name="step_begin", ds_run_context=ds_run_context)
71
72
class StepEnd(BaseCallback):
    """Callback that logs a "step_end" event from ``ds_step_end``."""

    def ds_step_end(self, ds_run_context):
        self.append(event_name="step_end", ds_run_context=ds_run_context)
76
77
class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
    """Callback combining all five single-event recorders, so it logs every event kind."""
80
81
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    """Build the event log the recording callbacks are expected to produce.

    Args:
        epoch_num (int): Number of epochs to generate events for.
        step_num (int): Number of steps per epoch before repetition.
        step_size (int): Only steps whose zero-based index is a multiple of
            this value get step_begin/step_end entries.
        map_num (int): Number of callback ids (``0 .. map_num - 1``) attached
            to every entry.
        repeat (int): Multiplier applied to ``step_num`` for each epoch.

    Returns:
        list: ``(event_key, cb_ids)`` tuples, starting with ``"begin_0_0_0"``.
        Every tuple refers to the same ``cb_ids`` list object.
    """
    ids = list(range(map_num))
    steps_per_epoch = step_num * repeat

    def entry(tag, epoch_idx, step_idx):
        # Epoch and step numbers are one-based in the key; the last field is the global step.
        fields = (tag, epoch_idx + 1, step_idx + 1, epoch_idx * steps_per_epoch + step_idx + 1)
        return '_'.join(map(str, fields)), ids

    expected = [("begin_0_0_0", ids)]
    for epoch_idx in range(epoch_num):
        # A step index of -1 yields step-in-epoch 0 for the epoch_begin key.
        expected.append(entry("epoch_begin", epoch_idx, -1))
        for step_idx in range(steps_per_epoch):
            if step_idx % step_size != 0:
                continue
            expected.append(entry("step_begin", epoch_idx, step_idx))
            expected.append(entry("step_end", epoch_idx, step_idx))
        expected.append(entry("epoch_end", epoch_idx, steps_per_epoch - 1))
    return expected
100
101
def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
    """Run a pipeline with one recording callback and compare its log with the expected one.

    Args:
        epochs (int): Number of epochs to iterate.
        steps (int): Number of rows in the source dataset.
        step_size (int): Step size handed to the callback.
        repeat (int): Total repeat count applied after the map carrying the callback.
    """
    recorded = []
    callback = MyDSCallback(step_size=step_size, events=recorded)

    pipeline = ds.NumpySlicesDataset(list(range(1, steps + 1)), shuffle=False)
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=callback)

    # Even repeat counts other than 2 are split into two repeat ops with a map in between.
    split_repeat = repeat % 2 == 0 and repeat != 2
    if repeat == 1:
        pass
    elif split_repeat:
        pipeline = pipeline.repeat(2)
        pipeline = pipeline.map(operations=(lambda x: x))
        pipeline = pipeline.repeat(repeat // 2)
    else:
        pipeline = pipeline.repeat(repeat)

    iterator = pipeline.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in iterator:
            pass

    expected = generate_expected(epochs, steps, step_size, 1, repeat)
    assert expected == recorded
125
126
def build_test_case_2cbs(epochs, steps):
    """Attach two recording callbacks to one map op and check both logs match the expected one.

    Args:
        epochs (int): Number of epochs to iterate.
        steps (int): Number of rows in the source dataset.
    """
    first_log, second_log = [], []
    callbacks = [MyDSCallback(events=first_log), MyDSCallback(events=second_log)]

    pipeline = ds.NumpySlicesDataset(list(range(1, steps + 1)), shuffle=False)
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=callbacks)

    iterator = pipeline.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in iterator:
            pass

    expected = generate_expected(epochs, steps)
    assert expected == first_log
    assert expected == second_log
146
147
def build_test_case_2maps(epochs, steps):
    """Chain two map ops whose callbacks (ids 0 and 1) share one log, then check the log.

    Args:
        epochs (int): Number of epochs to iterate.
        steps (int): Number of rows in the source dataset.
    """
    shared_log = []
    first_cb = MyDSCallback(events=shared_log, cb_id=0)
    second_cb = MyDSCallback(events=shared_log, cb_id=1)

    pipeline = ds.NumpySlicesDataset(list(range(1, steps + 1)), shuffle=False)
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=first_cb)
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=second_cb)

    iterator = pipeline.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in iterator:
            pass

    expected = generate_expected(epochs, steps, map_num=2)

    # The first ("begin") entry is left out of the exact comparison.
    assert expected[1:] == shared_log[1:]

    for entry in shared_log:
        assert len(entry) == 2
        key, cb_ids = entry
        if key == "begin_0_0_0":
            continue
        # For every other event, callback 0 must have logged before callback 1.
        assert cb_ids[0] == 0
        assert cb_ids[1] == 1
174
175
def test_callbacks_all_methods():
    """Run the single-callback case for 1-3 epochs combined with 1-4 steps."""
    logger.info("test_callbacks_all_methods")

    for epochs in (1, 2, 3):
        for steps in (1, 2, 3, 4):
            build_test_case_1cb(epochs, steps)
193
194
def test_callbacks_var_step_size():
    """Run the single-callback case with step size 2 for 1-3 epochs and 2-4 steps."""
    logger.info("test_callbacks_var_step_size")

    for epochs in (1, 2, 3):
        for steps in (2, 3, 4):
            build_test_case_1cb(epochs, steps, 2)
209
210
def test_callbacks_all_2cbs():
    """Run the two-callbacks-on-one-map case for 4 epochs and 1-4 steps."""
    logger.info("test_callbacks_all_2cbs")

    for steps in (1, 2, 3, 4):
        build_test_case_2cbs(4, steps)
218
219
class MyWaitedCallback(WaitedDSCallback):
    """Waited dataset callback that logs epoch-begin and step-begin notifications.

    Args:
        events (list): Log that receives one string per notification.
        step_size (int): Forwarded unchanged to ``WaitedDSCallback.__init__``.
    """

    def __init__(self, events, step_size=1):
        super().__init__(step_size)
        self.events = events

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """Log ``ds_epoch_begin_<epoch>_<global step>``."""
        epoch, step = ds_run_context.cur_epoch_num, ds_run_context.cur_step_num
        self.events.append(f"ds_epoch_begin_{epoch}_{step}")

    def sync_step_begin(self, train_run_context, ds_run_context):
        """Log ``ds_step_begin_<epoch>_<global step>``."""
        epoch, step = ds_run_context.cur_epoch_num, ds_run_context.cur_step_num
        self.events.append(f"ds_step_begin_{epoch}_{step}")
232
233
class MyMSCallback(Callback):
    """Training callback that logs epoch-end and step-end notifications.

    Args:
        events (list): Log that receives one string per notification.
    """

    def __init__(self, events):
        self.events = events

    def epoch_end(self, run_context):
        """Log ``ms_epoch_end_<epoch>_<step>`` from the run context's original args."""
        params = run_context.original_args()
        epoch, step = params.cur_epoch_num, params.cur_step_num
        self.events.append(f"ms_epoch_end_{epoch}_{step}")

    def step_end(self, run_context):
        """Log ``ms_step_end_<epoch>_<step>`` from the run context's original args."""
        params = run_context.original_args()
        epoch, step = params.cur_epoch_num, params.cur_step_num
        self.events.append(f"ms_step_end_{epoch}_{step}")
247
248
class Net(nn.Cell):
    """Minimal cell used to drive ``Model.train``: returns its first input and ignores the second."""

    def construct(self, x, y):
        return x
252
253
def test_callbacks_non_sink():
    """Train for two epochs without dataset sinking and check the order of the logged events."""
    logger.info("test_callbacks_non_sink")

    recorded = []
    waited_cb = MyWaitedCallback(recorded, 1)
    model_cb = MyMSCallback(recorded)
    values = [1, 2, 3, 4]
    dataset = ds.NumpySlicesDataset((values, values), column_names=["c1", "c2"], shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=waited_cb)

    model = Model(Net())
    model.train(2, dataset, dataset_sink_mode=False, callbacks=[model_cb, waited_cb])

    expected_synced_events = [
        'ms_step_end_1_1', 'ds_step_begin_1_2',
        'ms_step_end_1_2', 'ds_step_begin_1_3',
        'ms_step_end_1_3', 'ds_step_begin_1_4',
        'ms_step_end_1_4',
        'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
        'ds_step_begin_2_5', 'ms_step_end_2_5',
        'ds_step_begin_2_6', 'ms_step_end_2_6',
        'ds_step_begin_2_7', 'ms_step_end_2_7',
        'ds_step_begin_2_8', 'ms_step_end_2_8',
        'ms_epoch_end_2_8',
    ]

    # Only the first 18 logged events are compared.
    assert recorded[:18] == expected_synced_events
276
277
def test_callbacks_non_sink_batch_size2():
    """Same as the non-sink test, but with callback step size 2 and the dataset batched by 2."""
    logger.info("test_callbacks_non_sink_batch_size2")

    recorded = []
    waited_cb = MyWaitedCallback(recorded, 2)
    model_cb = MyMSCallback(recorded)
    values = [1, 2, 3, 4]
    dataset = ds.NumpySlicesDataset((values, values), column_names=["c1", "c2"], shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=waited_cb)
    dataset = dataset.batch(2)

    model = Model(Net())
    model.train(2, dataset, dataset_sink_mode=False, callbacks=[model_cb, waited_cb])

    expected_synced_events = [
        'ms_step_end_1_1', 'ds_step_begin_1_3',
        'ms_step_end_1_2',
        'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
        'ds_step_begin_2_5', 'ms_step_end_2_3',
        'ds_step_begin_2_7', 'ms_step_end_2_4',
        'ms_epoch_end_2_4',
    ]

    # Only the first 10 logged events are compared.
    assert recorded[:10] == expected_synced_events
300
301
def test_callbacks_non_sink_mismatch_size():
    """Check that training times out when the callback step size (2) and the batch size (3) differ.

    The callback timeout is lowered to 1 second for the duration of the test
    and is always restored afterwards, even when training or the assertion
    fails, so the shortened timeout cannot leak into other tests.
    """
    logger.info("test_callbacks_non_sink_mismatch_size")
    default_timeout = ds.config.get_callback_timeout()
    ds.config.set_callback_timeout(1)

    try:
        events = []
        my_cb1 = MyWaitedCallback(events, 2)
        my_cb2 = MyMSCallback(events)
        arr = [1, 2, 3, 4]
        data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb1)
        data = data.batch(3)
        net = Net()
        model = Model(net)
        with pytest.raises(Exception) as err:
            model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
        assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
    finally:
        # Restore the global setting regardless of the test outcome.
        ds.config.set_callback_timeout(default_timeout)
321
322
def test_callbacks_validations():
    """Check the error messages raised for invalid ``callbacks`` arguments of ``map``."""
    logger.info("test_callbacks_validations")

    # Case 1: an integer passed directly as the callbacks argument.
    with pytest.raises(Exception) as scalar_err:
        dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        dataset.map(operations=(lambda x: x), callbacks=0)
    assert "Argument callbacks with value 0 is not " in str(scalar_err.value)

    # Case 2: an integer as the second element of a callbacks list.
    with pytest.raises(Exception) as element_err:
        valid_cb = MyDSCallback()
        dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        dataset.map(operations=(lambda x: x), callbacks=[valid_cb, 0])
    assert "Argument callbacks[1] with value 0 is not " in str(element_err.value)

    # Case 3: a DSCallback subclass that overrides none of the callback methods.
    with pytest.raises(Exception) as override_err:
        class BadCB(DSCallback):
            """Callback subclass without any overridden method."""

        empty_cb = BadCB()

        dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        dataset = dataset.map(operations=(lambda x: x), callbacks=empty_cb)
        for _ in dataset:
            pass
    assert "Provided Callback class did not override any of the 6 callback methods." in str(override_err.value)
348
349
def test_callbacks_sink_simulation():
    """Simulate sink-mode training by hand and check the order of the logged events.

    The dataset is sent to the device for two epochs, and the training side
    is imitated by this test: for each of the 4 steps per epoch it sleeps,
    logs a step-end entry and calls the waited callback's ``step_end``; after
    each epoch it logs an epoch-end entry and calls ``epoch_end``.
    """
    logger.info("test_callbacks_sink_simulation")

    events = []
    epochs = 2
    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    data = data.to_device()
    data.send(num_epochs=epochs)
    for e in range(epochs):
        for s in range(4):
            # The pause comes before each simulated step end; the expected
            # ordering below was written against this pacing.
            time.sleep(0.5)
            events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
            my_cb.step_end(run_context=0)
        events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
        my_cb.epoch_end(run_context=0)
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events == expected_synced_events
375
376
def test_callbacks_repeat():
    """Run the single-callback case with a range of repeat counts."""
    logger.info("test_callbacks_repeat")

    # (epochs, steps, step_size, repeat), executed in this order.
    cases = (
        (2, 2, 1, 2),
        (2, 2, 1, 3),
        (2, 2, 2, 3),
        (3, 2, 4, 3),
        (2, 2, 1, 2),
        (2, 2, 1, 4),
        (2, 2, 2, 8),
        (3, 2, 4, 16),
    )
    for epochs, steps, step_size, repeat in cases:
        build_test_case_1cb(epochs=epochs, steps=steps, step_size=step_size, repeat=repeat)
389
390
def test_callbacks_exceptions():
    """Check that an exception raised inside a dataset callback surfaces when iterating the pipeline."""
    logger.info("test_callbacks_exceptions")

    class BadCB(DSCallback):
        """Callback whose ``ds_begin`` always raises."""

        def ds_begin(self, ds_run_context):
            raise RuntimeError("Bad begin")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=BadCB())
        for _ in data:
            pass
    # The message check must live outside the ``with`` block: statements after
    # the raising line inside the block never execute, and ``err.value`` is
    # only populated once the context manager has exited.
    assert "RuntimeError: Bad begin" in str(err.value)
404
405
def test_callbacks_train_end():
    """Call the waited callback's ``end`` while the dataset is being sent to the device.

    No asserts are needed; the test only checks that there is no deadlock or exception.
    """
    logger.info("test_callbacks_train_end")
    events = []
    epochs = 2

    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=[my_cb])
    data = data.to_device()
    data.send(num_epochs=epochs)
    time.sleep(0.5)
    my_cb.end(run_context={})
    time.sleep(0.5)
420
421
def test_callbacks_one_cb():
    """Spread single-event callbacks over three map ops and check each of the three logs."""
    logger.info("test_callbacks_one_cb")

    begin_log, second_log, third_log = [], [], []

    pipeline = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    # First map: begin only (id 1).
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=Begin(events=begin_log, cb_id=1))
    # Second map: epoch begin and step end share one log (id 2).
    pipeline = pipeline.map(operations=(lambda x: x),
                            callbacks=[EpochBegin(events=second_log, cb_id=2),
                                       StepEnd(events=second_log, cb_id=2)])
    # Third map: epoch end and step begin share one log (id 3).
    pipeline = pipeline.map(operations=(lambda x: x),
                            callbacks=[EpochEnd(events=third_log, cb_id=3),
                                       StepBegin(events=third_log, cb_id=3)])

    iterator = pipeline.create_tuple_iterator(num_epochs=2)
    for _ in range(2):
        for _ in iterator:
            pass

    expected_begin = [('begin_0_0_0', [1])]
    expected_second = [
        ('epoch_begin_1_0_0', [2]),
        ('step_end_1_1_1', [2]),
        ('step_end_1_2_2', [2]),
        ('step_end_1_3_3', [2]),
        ('step_end_1_4_4', [2]),
        ('epoch_begin_2_0_4', [2]),
        ('step_end_2_1_5', [2]),
        ('step_end_2_2_6', [2]),
        ('step_end_2_3_7', [2]),
        ('step_end_2_4_8', [2]),
    ]
    expected_third = [
        ('step_begin_1_1_1', [3]),
        ('step_begin_1_2_2', [3]),
        ('step_begin_1_3_3', [3]),
        ('step_begin_1_4_4', [3]),
        ('epoch_end_1_4_4', [3]),
        ('step_begin_2_1_5', [3]),
        ('step_begin_2_2_6', [3]),
        ('step_begin_2_3_7', [3]),
        ('step_begin_2_4_8', [3]),
        ('epoch_end_2_4_8', [3]),
    ]
    assert begin_log == expected_begin
    assert second_log == expected_second
    assert third_log == expected_third
455
456
def test_clear_callback():
    """Check that the callback is not triggered by get_dataset_size / output_shapes, but is by iteration."""
    logger.info("test_clear_callback")

    class FlagCallback(DSCallback):
        """Callback that notes whether ``ds_begin`` ran and counts ``ds_step_begin`` calls."""

        def __init__(self):
            super().__init__(step_size=1)
            self.flag = False
            self.row_cnt = 0

        def ds_begin(self, ds_run_context):
            # Reached only if the callback is still attached when the pipeline runs.
            self.flag = True

        def ds_step_begin(self, ds_run_context):
            self.row_cnt += 1

    tracker = FlagCallback()
    # Sanity check of the initial state before anything runs.
    assert not tracker.flag
    assert tracker.row_cnt == 0

    pipeline = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    pipeline = pipeline.map(operations=(lambda x: x), callbacks=tracker)
    assert pipeline.get_dataset_size() == 4
    assert pipeline.output_shapes() == [[]]
    # Neither getter may have invoked the callback.
    assert not tracker.flag
    assert tracker.row_cnt == 0

    for _ in pipeline.create_dict_iterator(num_epochs=1):
        pass
    # Real iteration must invoke it: once at begin and once per row.
    assert tracker.flag
    assert tracker.row_cnt == 4
487
488
if __name__ == '__main__':
    # Run every test in a fixed order when the module is executed as a script.
    for run_test in (
            test_callbacks_all_2cbs,
            test_callbacks_all_methods,
            test_callbacks_exceptions,
            test_callbacks_repeat,
            test_callbacks_sink_simulation,
            test_callbacks_validations,
            test_callbacks_var_step_size,
            test_callbacks_non_sink_batch_size2,
            test_callbacks_non_sink,
            test_callbacks_one_cb,
            test_callbacks_non_sink_mismatch_size,
            test_callbacks_train_end,
            test_clear_callback,
    ):
        run_test()
503