• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from builtins import range, super
16import time
17
18import pytest
19
20from mindspore import context
21from mindspore import log as logger
22from mindspore.dataset.callback import DSCallback, WaitedDSCallback
23from mindspore.train import Model, Callback
24
25import mindspore.dataset as ds
26import mindspore.nn as nn
27
# Module-wide execution context used by every test below: graph mode on the CPU target.
context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
29
30
class BaseCallback(DSCallback):
    """
    Dataset callback base that records the events it observes.

    Each recorded entry is a tuple ``(event, cb_ids)`` where ``event`` is the
    string ``<name>_<cur_epoch_num>_<cur_step_num_in_epoch>_<cur_step_num>``
    and ``cb_ids`` lists the ids of the callbacks that reported that event.

    Args:
        step_size: forwarded unchanged to ``DSCallback.__init__``.
        events: list that collects the recorded entries. The very same list
            object is kept (even when it is empty) so several callbacks can
            share one log. When omitted, a private list is created.
        cb_id: identifier stored in the ``cb_ids`` list of each entry.
    """

    def __init__(self, step_size=1, events=None, cb_id=0):
        super().__init__(step_size)
        # Use an identity check rather than truthiness: an empty list passed
        # by the caller must still be the one that receives the events.
        self.events = [] if events is None else events
        self.cb_id = cb_id

    def append(self, event_name, ds_run_context):
        """
        Record ``event_name`` together with the counters of ``ds_run_context``.

        If an entry with the same event string already exists, this callback's
        id is added to that entry; otherwise a new entry is appended.
        """
        fields = (event_name, ds_run_context.cur_epoch_num,
                  ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num)
        event = '_'.join(str(field) for field in fields)
        for entry in self.events:
            if entry[0] == event:
                entry[1].append(self.cb_id)
                break
        else:
            # no existing entry matched: start a new one for this event
            self.events.append((event, [self.cb_id]))
50
51
class Begin(BaseCallback):
    """Records a "begin" event whenever ds_begin is invoked."""

    def ds_begin(self, ds_run_context):
        self.append(event_name="begin", ds_run_context=ds_run_context)
55
56
class EpochBegin(BaseCallback):
    """Records an "epoch_begin" event whenever ds_epoch_begin is invoked."""

    def ds_epoch_begin(self, ds_run_context):
        self.append(event_name="epoch_begin", ds_run_context=ds_run_context)
60
61
class EpochEnd(BaseCallback):
    """Records an "epoch_end" event whenever ds_epoch_end is invoked."""

    def ds_epoch_end(self, ds_run_context):
        self.append(event_name="epoch_end", ds_run_context=ds_run_context)
65
66
class StepBegin(BaseCallback):
    """Records a "step_begin" event whenever ds_step_begin is invoked."""

    def ds_step_begin(self, ds_run_context):
        self.append(event_name="step_begin", ds_run_context=ds_run_context)
70
71
class StepEnd(BaseCallback):
    """Records a "step_end" event whenever ds_step_end is invoked."""

    def ds_step_end(self, ds_run_context):
        self.append(event_name="step_end", ds_run_context=ds_run_context)
75
76
class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
    """Callback combining all five recording hooks: begin, epoch begin/end and step begin/end."""
79
80
def verify_events(events, epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    '''
    Make sure that the events are in correct order.
    * begin is the first
    * epoch x begin before epoch x end
    * epoch x end before epoch x+1 begin
    * step x begin before step x end
    * step x begin before step x+1 begin
    * step x end before step x+1 end

    Args:
        events: sequence of entries whose first item is the event string
            "<name>_<epoch>_<step_in_epoch>_<global_step>".
        epoch_num, step_num, step_size, map_num, repeat: accepted for
            signature compatibility; the ordering checks do not use them.

    Raises:
        AssertionError: if any ordering rule is violated, if no events were
            recorded, or if epoch begin/end events are not paired.
    '''
    assert events, "no callback events were recorded"
    assert events[0][0] == "begin_0_0_0"

    # Epoch events must alternate begin/end, with the epoch number increasing
    # by one after each begin/end pair.
    epoch_fields = [e[0].split('_') for e in events if 'epoch' in e[0]]
    assert len(epoch_fields) % 2 == 0, "epoch_begin/epoch_end events are not paired"
    for i, fields in enumerate(epoch_fields):
        assert fields[2] == str(i // 2 + 1)
        assert fields[1] == ("begin" if i % 2 == 0 else "end")

    # Step events are reduced to (begin|end, global step number).
    step_fields = (e[0].split('_') for e in events if 'step' in e[0])
    steps = [(fields[1], fields[-1]) for fields in step_fields]

    begun = set()
    max_step = 0
    for kind, step in steps:
        if step in begun:
            # a step that has already begun may only be followed by its end
            assert kind == 'end'
        else:
            assert kind == 'begin'
            begun.add(step)
            # global step numbers of begin events must be strictly increasing
            assert int(step) > max_step
            max_step = int(step)
120
121
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    """
    Build the list of (event, cb_ids) entries expected from a pipeline run.

    Event strings have the form "<name>_<epoch>_<step_in_epoch>_<global_step>".
    Every entry refers to the same cb_ids list, which is list(range(map_num)).
    Step events are produced only for steps whose zero-based index is a
    multiple of step_size.
    """
    shared_ids = list(range(map_num))
    rows_per_epoch = step_num * repeat

    def entry(name, epoch, step):
        # epoch and step are zero-based here; the event string is one-based
        global_step = epoch * step_num * repeat + step + 1
        return (f"{name}_{epoch + 1}_{step + 1}_{global_step}", shared_ids)

    expected = [("begin_0_0_0", shared_ids)]
    for epoch in range(epoch_num):
        expected.append(entry("epoch_begin", epoch, -1))
        for step in range(rows_per_epoch):
            if step % step_size != 0:
                continue
            expected.append(entry("step_begin", epoch, step))
            expected.append(entry("step_end", epoch, step))
        expected.append(entry("epoch_end", epoch, rows_per_epoch - 1))
    return expected
140
141
def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
    """
    Run a Map pipeline with one MyDSCallback and check the recorded events.

    Args:
        epochs: number of epochs to iterate.
        steps: number of rows in the source dataset.
        step_size: step_size given to the callback.
        repeat: total repeat count applied after the Map op.

    The recorded events must satisfy verify_events and, compared as sorted
    lists of event strings, must equal the output of generate_expected.
    """
    events = []

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    my_cb = MyDSCallback(step_size=step_size, events=events)

    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    if repeat != 1:
        if repeat % 2 == 0 and repeat != 2:
            # even repeat counts above 2 are split into two repeat ops
            # (2 and repeat // 2) with a callback-free map in between
            data = data.repeat(2)
            data = data.map(operations=(lambda x: x))
            data = data.repeat(repeat // 2)
        else:
            data = data.repeat(repeat)
    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = [e[0] for e in generate_expected(epochs, steps, step_size, 1, repeat)]
    # repeat goes by keyword: passed positionally it would be bound to
    # verify_events' map_num parameter instead.
    verify_events(events, epochs, steps, step_size, repeat=repeat)
    assert sorted(expected_events) == sorted(e[0] for e in events)
170
171
def build_test_case_2cbs(epochs, steps):
    """
    Run one Map op carrying two MyDSCallback instances, each with its own
    event list, and check that both lists match generate_expected once sorted.
    """
    recorded = ([], [])
    callbacks = [MyDSCallback(events=log) for log in recorded]

    dataset = ds.NumpySlicesDataset(list(range(1, steps + 1)), shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=callbacks)

    iterator = dataset.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in iterator:
            pass

    expected = sorted(generate_expected(epochs, steps))
    for log in recorded:
        assert expected == sorted(log)
194
195
def build_test_case_2maps(epochs, steps):
    """
    Run two chained Map ops whose callbacks (cb_id 0 and 1) share one event
    list, then check the entries after the first against generate_expected
    and the cb_id order of every entry other than "begin_0_0_0".
    """
    shared_events = []
    first_cb = MyDSCallback(events=shared_events, cb_id=0)
    second_cb = MyDSCallback(events=shared_events, cb_id=1)

    dataset = ds.NumpySlicesDataset(list(range(1, steps + 1)), shuffle=False)
    for callback in (first_cb, second_cb):
        dataset = dataset.map(operations=(lambda x: x), callbacks=callback)

    iterator = dataset.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in iterator:
            pass

    expected = generate_expected(epochs, steps, map_num=2)

    # the first entry is left out of the comparison
    assert expected[1:] == shared_events[1:]

    for entry in shared_events:
        assert len(entry) == 2
        name, cb_ids = entry
        if name == "begin_0_0_0":
            continue
        assert cb_ids[0] == 0
        assert cb_ids[1] == 1
222
223
def test_callbacks_all_methods():
    """
    Feature: Callback
    Description: Test Map op with 1 callback for every combination of 1-3 epochs and 1-4 steps
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_all_methods")

    for epochs in (1, 2, 3):
        for steps in (1, 2, 3, 4):
            build_test_case_1cb(epochs, steps)
246
247
def test_callbacks_var_step_size():
    """
    Feature: Callback
    Description: Test Map op with 1 callback with step_size=2 for every combination of 1-3 epochs and 2-4 steps
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_var_step_size")

    for epochs in (1, 2, 3):
        for steps in (2, 3, 4):
            build_test_case_1cb(epochs, steps, 2)
267
268
def test_callbacks_all_2cbs():
    """
    Feature: Callback
    Description: Test Map op with 2 callbacks over 4 epochs with 1 to 4 steps
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_all_2cbs")

    for steps in (1, 2, 3, 4):
        build_test_case_2cbs(4, steps)
281
282
class MyWaitedCallback(WaitedDSCallback):
    """
    WaitedDSCallback that appends "ds_epoch_begin_<epoch>_<step>" and
    "ds_step_begin_<epoch>_<step>" strings to the given events list, using
    cur_epoch_num and cur_step_num of the dataset run context.
    """

    def __init__(self, events, step_size=1):
        super().__init__(step_size)
        self.events = events

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        epoch = ds_run_context.cur_epoch_num
        step = ds_run_context.cur_step_num
        self.events.append(f"ds_epoch_begin_{epoch}_{step}")

    def sync_step_begin(self, train_run_context, ds_run_context):
        epoch = ds_run_context.cur_epoch_num
        step = ds_run_context.cur_step_num
        self.events.append(f"ds_step_begin_{epoch}_{step}")
295
296
class MyMSCallback(Callback):
    """
    Training callback that appends "ms_epoch_end_<epoch>_<step>" and
    "ms_step_end_<epoch>_<step>" strings to the given events list, using
    cur_epoch_num and cur_step_num from run_context.original_args().
    """

    def __init__(self, events):
        self.events = events

    def epoch_end(self, run_context):
        params = run_context.original_args()
        epoch, step = params.cur_epoch_num, params.cur_step_num
        self.events.append(f"ms_epoch_end_{epoch}_{step}")

    def step_end(self, run_context):
        params = run_context.original_args()
        epoch, step = params.cur_epoch_num, params.cur_step_num
        self.events.append(f"ms_step_end_{epoch}_{step}")
310
311
class Net(nn.Cell):
    """Pass-through cell: construct returns its first input and ignores the second."""

    def construct(self, x, y):
        return x
315
316
def test_callbacks_non_sink():
    """
    Feature: Callback
    Description: Test a waited dataset callback together with a training callback, dataset_sink_mode=False
    Expectation: Recorded events are equal to the expected synced sequence
    """
    logger.info("test_callbacks_non_sink")

    events = []
    waited_cb = MyWaitedCallback(events, 1)
    train_cb = MyMSCallback(events)

    values = [1, 2, 3, 4]
    dataset = ds.NumpySlicesDataset((values, values), column_names=["c1", "c2"], shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=waited_cb)

    model = Model(Net())
    model.train(2, dataset, dataset_sink_mode=False, callbacks=[train_cb, waited_cb])

    expected_synced_events = [
        # epoch 1
        'ms_step_end_1_1', 'ds_step_begin_1_2',
        'ms_step_end_1_2', 'ds_step_begin_1_3',
        'ms_step_end_1_3', 'ds_step_begin_1_4',
        'ms_step_end_1_4',
        'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
        # epoch 2
        'ds_step_begin_2_5', 'ms_step_end_2_5',
        'ds_step_begin_2_6', 'ms_step_end_2_6',
        'ds_step_begin_2_7', 'ms_step_end_2_7',
        'ds_step_begin_2_8', 'ms_step_end_2_8',
        'ms_epoch_end_2_8',
    ]

    # only the first 18 recorded events are compared
    assert events[:len(expected_synced_events)] == expected_synced_events
344
345
def test_callbacks_non_sink_batch_size2():
    """
    Feature: Callback
    Description: Test callbacks with dataset_sink_mode=False when batch(2) follows the Map op and step_size is 2
    Expectation: Recorded events are equal to the expected synced sequence
    """
    logger.info("test_callbacks_non_sink_batch_size2")

    events = []
    waited_cb = MyWaitedCallback(events, 2)
    train_cb = MyMSCallback(events)

    values = [1, 2, 3, 4]
    dataset = ds.NumpySlicesDataset((values, values), column_names=["c1", "c2"], shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=waited_cb)
    dataset = dataset.batch(2)

    model = Model(Net())
    model.train(2, dataset, dataset_sink_mode=False, callbacks=[train_cb, waited_cb])

    expected_synced_events = [
        # epoch 1
        'ms_step_end_1_1', 'ds_step_begin_1_3',
        'ms_step_end_1_2',
        'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
        # epoch 2
        'ds_step_begin_2_5', 'ms_step_end_2_3',
        'ds_step_begin_2_7', 'ms_step_end_2_4',
        'ms_epoch_end_2_4',
    ]

    # only the first 10 recorded events are compared
    assert events[:len(expected_synced_events)] == expected_synced_events
373
374
def test_callbacks_non_sink_mismatch_size():
    """
    Feature: Callback
    Description: Test callbacks with dataset_sink_mode=False in train with mismatch size
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_non_sink_mismatch_size")
    default_timeout = ds.config.get_callback_timeout()
    ds.config.set_callback_timeout(1)

    # The timeout is a global dataset setting, so it is restored in a finally
    # clause: a failing assertion must not leak the 1 second value into the
    # tests that run afterwards.
    try:
        events = []
        my_cb1 = MyWaitedCallback(events, 2)
        my_cb2 = MyMSCallback(events)
        arr = [1, 2, 3, 4]
        data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb1)
        # callback step_size is 2 while the batch size is 3
        data = data.batch(3)
        net = Net()
        model = Model(net)
        with pytest.raises(Exception) as err:
            model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
        assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
    finally:
        ds.config.set_callback_timeout(default_timeout)
399
400
def test_callbacks_validations():
    """
    Feature: Callback
    Description: Test the callbacks param of Map op with invalid arguments
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_validations")

    def new_dataset():
        return ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)

    # a non-callback value passed directly
    with pytest.raises(Exception) as err:
        dataset = new_dataset()
        dataset.map(operations=(lambda x: x), callbacks=0)
    assert "Argument callbacks with value 0 is not " in str(err.value)

    # a non-callback value inside the callbacks list
    with pytest.raises(Exception) as err:
        valid_cb = MyDSCallback()
        dataset = new_dataset()
        dataset.map(operations=(lambda x: x), callbacks=[valid_cb, 0])
    assert "Argument callbacks[1] with value 0 is not " in str(err.value)

    # a DSCallback subclass that overrides no method
    with pytest.raises(Exception) as err:
        class BadCB(DSCallback):
            """DSCallback subclass that defines nothing of its own."""

        empty_cb = BadCB()

        dataset = new_dataset()
        dataset = dataset.map(operations=(lambda x: x), callbacks=empty_cb)
        for _ in dataset:
            pass
    assert "Inheriting Callback class without overriding any methods" in str(err.value)
431
432
def test_callbacks_sink_simulation():
    """
    Feature: Callback
    Description: Test callbacks under sink simulation, driving step_end/epoch_end of the waited callback by hand
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callback_sink_simulation")

    num_epochs = 2
    steps_per_epoch = 4
    events = []
    waited_cb = MyWaitedCallback(events, 1)

    dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    dataset = dataset.map(operations=(lambda x: x), callbacks=waited_cb)
    dataset = dataset.device_que()
    dataset.send(num_epochs=num_epochs)

    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            time.sleep(0.5)
            # record the simulated training event first, then notify the callback
            global_step = epoch * steps_per_epoch + step + 1
            events.append(f"ms_step_end_{epoch + 1}_{global_step}")
            waited_cb.step_end(run_context=0)
        events.append(f"ms_epoch_end_{epoch + 1}_{(epoch + 1) * steps_per_epoch}")
        waited_cb.epoch_end(run_context=0)

    expected_synced_events = [
        # epoch 1
        'ms_step_end_1_1', 'ds_step_begin_1_2',
        'ms_step_end_1_2', 'ds_step_begin_1_3',
        'ms_step_end_1_3', 'ds_step_begin_1_4',
        'ms_step_end_1_4',
        'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
        # epoch 2
        'ds_step_begin_2_5', 'ms_step_end_2_5',
        'ds_step_begin_2_6', 'ms_step_end_2_6',
        'ds_step_begin_2_7', 'ms_step_end_2_7',
        'ds_step_begin_2_8', 'ms_step_end_2_8',
        'ms_epoch_end_2_8',
    ]

    assert events == expected_synced_events
463
464
def test_callbacks_repeat():
    """
    Feature: Callback
    Description: Test Map op with 1 callback followed by repeat, over several epochs/step_size/repeat combinations
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_repeat")

    # (epochs, steps, step_size, repeat), run in this order
    cases = (
        (2, 2, 1, 2),
        (2, 2, 1, 3),
        (2, 2, 2, 3),
        (3, 2, 4, 3),

        (2, 2, 1, 2),
        (2, 2, 1, 4),
        (2, 2, 2, 8),
        (3, 2, 4, 16),
    )
    for epochs, steps, step_size, repeat in cases:
        build_test_case_1cb(epochs=epochs, steps=steps, step_size=step_size, repeat=repeat)
482
483
def test_callbacks_exceptions():
    """
    Feature: Callback
    Description: Test Map op with a callback whose ds_begin raises RuntimeError
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_exceptions")

    class BadCB(DSCallback):
        def ds_begin(self, ds_run_context):
            raise RuntimeError("Bad begin")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=BadCB())
        for _ in data:
            pass
    # The message check sits outside the raises block: inside it, the
    # statement after the raising loop is never reached and err.value is
    # not yet available.
    assert "RuntimeError: Bad begin" in str(err.value)
497
498
def test_callbacks_train_end():
    """
    Feature: Callback
    Description: Test callback end op under sink simulation
    Expectation: Runs successfully
    """
    # log this test's own name (the message used to name the sink simulation test)
    logger.info("test_callbacks_train_end")
    # No asserts are needed, just test there is no deadlock or exceptions
    events = []
    epochs = 2

    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=[my_cb])
    data = data.device_que()
    data.send(num_epochs=epochs)
    # pause after send, call end on the callback, then pause again before returning
    time.sleep(0.5)
    my_cb.end(run_context={})
    time.sleep(0.5)
518
519
def test_callbacks_one_cb():
    """
    Feature: Callback
    Description: Test single-hook callbacks (Begin, EpochBegin, EpochEnd, StepBegin, StepEnd) spread over three Map ops
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_one_cb")

    dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)

    # one event list per cb_id: list 2 and list 3 are each shared by two callbacks
    events1, events2, events3 = [], [], []
    begin_cb = Begin(events=events1, cb_id=1)
    epoch_begin_cb = EpochBegin(events=events2, cb_id=2)
    epoch_end_cb = EpochEnd(events=events3, cb_id=3)
    step_begin_cb = StepBegin(events=events3, cb_id=3)
    step_end_cb = StepEnd(events=events2, cb_id=2)

    dataset = dataset.map(operations=(lambda x: x), callbacks=begin_cb)
    dataset = dataset.map(operations=(lambda x: x), callbacks=[epoch_begin_cb, step_end_cb])
    dataset = dataset.map(operations=(lambda x: x), callbacks=[epoch_end_cb, step_begin_cb])

    iterator = dataset.create_tuple_iterator(num_epochs=2)
    for _ in range(2):
        for _ in iterator:
            pass

    expected_events1 = [('begin_0_0_0', [1])]
    expected_events2 = [
        ('epoch_begin_1_0_0', [2]),
        ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
        ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]),
        ('epoch_begin_2_0_4', [2]),
        ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]),
        ('step_end_2_3_7', [2]), ('step_end_2_4_8', [2]),
    ]
    expected_events3 = [
        ('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]),
        ('step_begin_1_3_3', [3]), ('step_begin_1_4_4', [3]),
        ('epoch_end_1_4_4', [3]),
        ('step_begin_2_1_5', [3]), ('step_begin_2_2_6', [3]),
        ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
        ('epoch_end_2_4_8', [3]),
    ]

    # compare as sorted lists, so recording order is not part of the check
    assert sorted(events1) == sorted(expected_events1)
    assert sorted(events2) == sorted(expected_events2)
    assert sorted(events3) == sorted(expected_events3)
564
565
def test_clear_callback():
    """
    Feature: Callback
    Description: Test callback is removed for get_dataset_size and output_shape/type
    Expectation: Callback stays untouched by the getters and is invoked during real iteration
    """
    logger.info("test_clear_callback")

    class FlagCallback(DSCallback):
        """Tracks whether ds_begin was called and how often ds_step_begin was called."""

        def __init__(self):
            super().__init__(step_size=1)
            self.flag = False
            self.row_cnt = 0

        def ds_begin(self, ds_run_context):
            # reached only if the callback was not removed in the getter pass
            self.flag = True

        def ds_step_begin(self, ds_run_context):
            self.row_cnt += 1

    dataset = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    tracker = FlagCallback()
    # sanity check of the initial state
    assert not tracker.flag
    assert tracker.row_cnt == 0

    dataset = dataset.map(operations=(lambda x: x), callbacks=tracker)
    assert dataset.get_dataset_size() == 4
    assert dataset.output_shapes() == [[]]
    # neither getter may have triggered the callback
    assert not tracker.flag
    assert tracker.row_cnt == 0

    for _ in dataset.create_dict_iterator(num_epochs=1):
        pass
    # real iteration must have triggered it, once per row for ds_step_begin
    assert tracker.flag
    assert tracker.row_cnt == 4
601
602
if __name__ == '__main__':
    # run every test of this module, in this order, when executed as a script
    for run_test in (
            test_callbacks_all_2cbs,
            test_callbacks_all_methods,
            test_callbacks_exceptions,
            test_callbacks_repeat,
            test_callbacks_sink_simulation,
            test_callbacks_validations,
            test_callbacks_var_step_size,
            test_callbacks_non_sink_batch_size2,
            test_callbacks_non_sink,
            test_callbacks_one_cb,
            test_callbacks_non_sink_mismatch_size,
            test_callbacks_train_end,
            test_clear_callback,
    ):
        run_test()
617