# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from builtins import range, super
import time

import pytest

from mindspore import context
from mindspore import log as logger
from mindspore.dataset.callback import DSCallback, WaitedDSCallback
from mindspore.train import Model, Callback

import mindspore.dataset as ds
import mindspore.nn as nn

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class BaseCallback(DSCallback):
    """
    Dataset callback base that records every event it sees into a list.

    Each entry of ``events`` is a tuple ``(event_string, [cb_id, ...])`` where
    ``event_string`` is ``<name>_<epoch>_<step_in_epoch>_<global_step>``.
    Several callbacks may share one ``events`` list; a callback that observes
    an already-recorded event string appends its ``cb_id`` to that entry.
    """

    def __init__(self, step_size=1, events=None, cb_id=0):
        super().__init__(step_size)
        # Fall back to a private list so that append() never iterates None.
        # A list passed in by the caller is kept as-is (same object) so that
        # it can be shared between callbacks and inspected by the test.
        self.events = [] if events is None else events
        self.cb_id = cb_id

    def append(self, event_name, ds_run_context):
        """Record `event_name` together with the counters of `ds_run_context`."""
        fields = [event_name, ds_run_context.cur_epoch_num,
                  ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
        event = '_'.join(str(field) for field in fields)
        for recorded_event, cb_ids in self.events:
            if recorded_event == event:
                # Same event already reported by another callback: add our id.
                cb_ids.append(self.cb_id)
                break
        else:
            self.events.append((event, [self.cb_id]))


class Begin(BaseCallback):
    """Records the "begin" event."""

    def ds_begin(self, ds_run_context):
        self.append("begin", ds_run_context)


class EpochBegin(BaseCallback):
    """Records the "epoch_begin" event."""

    def ds_epoch_begin(self, ds_run_context):
        self.append("epoch_begin", ds_run_context)


class EpochEnd(BaseCallback):
    """Records the "epoch_end" event."""

    def ds_epoch_end(self, ds_run_context):
        self.append("epoch_end", ds_run_context)


class StepBegin(BaseCallback):
    """Records the "step_begin" event."""

    def ds_step_begin(self, ds_run_context):
        self.append("step_begin", ds_run_context)


class StepEnd(BaseCallback):
    """Records the "step_end" event."""

    def ds_step_end(self, ds_run_context):
        self.append("step_end", ds_run_context)


class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
    """Callback that records all five dataset events."""


def verify_events(events, epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    """
    Make sure that the events are in correct order.
    * begin is the first
    * epoch x begin before epoch x end
    * epoch x end before epoch x+1 begin
    * step x begin before step x end
    * step x begin before step x+1 begin
    * step x end before step x+1 end

    Only `events` is inspected; the remaining parameters are accepted so the
    signature mirrors `generate_expected` but are not used by the checks.
    """
    assert events, "no events were recorded"
    assert events[0][0] == "begin_0_0_0"

    epoch_events = [e[0] for e in events if 'epoch' in e[0]]
    # Epoch events must come in (begin, end) pairs; checking this up front
    # gives an assertion failure instead of an IndexError on an odd count.
    assert len(epoch_events) % 2 == 0, "unpaired epoch begin/end event"
    for idx, name in enumerate(epoch_events):
        parts = name.split('_')
        e_type, cur_epoch = parts[1], parts[2]
        assert str(idx // 2 + 1) == cur_epoch
        assert e_type == ("begin" if idx % 2 == 0 else "end")

    # Reduce every step event to (begin|end, global step number as string).
    steps = [(e[0].split('_')[1], e[0].split('_')[-1]) for e in events if 'step' in e[0]]

    steps_map = {}
    max_step = 0
    for s_type, s_num in steps:
        if s_num in steps_map:
            assert steps_map[s_num] == 'begin'
            assert s_type == 'end'
        else:
            assert s_type == 'begin'
            steps_map[s_num] = 'begin'
            # A newly started step must be later than every step seen so far.
            assert int(s_num) > max_step
            max_step = max(max_step, int(s_num))


def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    """
    Build the list of ``(event_string, cb_ids)`` tuples expected from a run.

    Step events are generated only for steps whose zero-based index is a
    multiple of `step_size`. `cb_ids` is ``list(range(map_num))`` and the same
    list object is shared by every returned tuple.
    """
    events = []
    cb_id = list(range(map_num))

    def append(name, e, s):
        event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
        event = '_'.join([str(ev) for ev in event])
        events.append((event, cb_id))

    events.append(("begin_0_0_0", cb_id))
    for e in range(epoch_num):
        append("epoch_begin", e, -1)
        for s in range(step_num * repeat):
            if s % step_size == 0:
                append("step_begin", e, s)
                append("step_end", e, s)
        append("epoch_end", e, step_num * repeat - 1)
    return events


def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
    """Run a pipeline with one recording callback on a Map op and check its events."""
    events = []

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    my_cb = MyDSCallback(step_size=step_size, events=events)

    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    if repeat != 1:
        if repeat % 2 == 0 and repeat != 2:
            # Split the repeat into two Repeat ops with a Map in between.
            data = data.repeat(2)
            data = data.map(operations=(lambda x: x))
            data = data.repeat(repeat // 2)
        else:
            data = data.repeat(repeat)
    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
    expected_events = [e[0] for e in expected_events]
    # Pass `repeat` by keyword: positionally it would land in `map_num`.
    verify_events(events, epochs, steps, step_size, repeat=repeat)
    events = [e[0] for e in events]
    expected_events.sort()
    events.sort()
    assert expected_events == events


def build_test_case_2cbs(epochs, steps):
    """Run a pipeline with two recording callbacks on the same Map op and check both."""
    events1 = []
    events2 = []
    my_cb1 = MyDSCallback(events=events1)
    my_cb2 = MyDSCallback(events=events2)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps)
    expected_events.sort()
    events1.sort()
    events2.sort()
    assert expected_events == events1
    assert expected_events == events2


def build_test_case_2maps(epochs, steps):
    """Run a pipeline with two Map ops whose callbacks share one event list."""
    events = []
    my_cb1 = MyDSCallback(events=events, cb_id=0)
    my_cb2 = MyDSCallback(events=events, cb_id=1)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.map(operations=(lambda x: x), callbacks=my_cb2)

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, map_num=2)

    # The first ("begin") entry is excluded from the exact comparison.
    assert expected_events[1:] == events[1:]

    for event in events:
        assert len(event) == 2
        event, cb_ids = event
        if event != "begin_0_0_0":
            assert cb_ids[0] == 0
            assert cb_ids[1] == 1


def test_callbacks_all_methods():
    """
    Feature: Callback
    Description: Test Map op with 1 callback with various num_epochs and num_steps args combinations
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_all_methods")

    build_test_case_1cb(1, 1)
    build_test_case_1cb(1, 2)
    build_test_case_1cb(1, 3)
    build_test_case_1cb(1, 4)

    build_test_case_1cb(2, 1)
    build_test_case_1cb(2, 2)
    build_test_case_1cb(2, 3)
    build_test_case_1cb(2, 4)

    build_test_case_1cb(3, 1)
    build_test_case_1cb(3, 2)
    build_test_case_1cb(3, 3)
    build_test_case_1cb(3, 4)


def test_callbacks_var_step_size():
    """
    Feature: Callback
    Description: Test Map op with 1 callback with step_size=2 and various num_epochs and num_steps args combinations
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_var_step_size")

    build_test_case_1cb(1, 2, 2)
    build_test_case_1cb(1, 3, 2)
    build_test_case_1cb(1, 4, 2)

    build_test_case_1cb(2, 2, 2)
    build_test_case_1cb(2, 3, 2)
    build_test_case_1cb(2, 4, 2)

    build_test_case_1cb(3, 2, 2)
    build_test_case_1cb(3, 3, 2)
    build_test_case_1cb(3, 4, 2)


def test_callbacks_all_2cbs():
    """
    Feature: Callback
    Description: Test Map op with 2 callbacks with various num_epochs and num_steps args combinations
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_all_2cbs")

    build_test_case_2cbs(4, 1)
    build_test_case_2cbs(4, 2)
    build_test_case_2cbs(4, 3)
    build_test_case_2cbs(4, 4)


class MyWaitedCallback(WaitedDSCallback):
    """Waited dataset callback that records its sync_* invocations as strings."""

    def __init__(self, events, step_size=1):
        super().__init__(step_size)
        self.events = events

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)

    def sync_step_begin(self, train_run_context, ds_run_context):
        event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)


class MyMSCallback(Callback):
    """Training callback that records epoch_end and step_end as strings."""

    def __init__(self, events):
        self.events = events

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)


class Net(nn.Cell):
    """Cell whose construct returns its first input unchanged."""

    def construct(self, x, y):
        return x


def test_callbacks_non_sink():
    """
    Feature: Callback
    Description: Test callbacks with dataset_sink_mode=False in train
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_non_sink")

    events = []
    my_cb1 = MyWaitedCallback(events, 1)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)

    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events[:18] == expected_synced_events


def test_callbacks_non_sink_batch_size2():
    """
    Feature: Callback
    Description: Test callbacks with dataset_sink_mode=False in train after batch(2) is applied to the dataset
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_non_sink_batch_size2")

    events = []
    my_cb1 = MyWaitedCallback(events, 2)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.batch(2)
    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
                              'ms_step_end_1_2',
                              'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
                              'ms_step_end_2_4', 'ms_epoch_end_2_4']

    assert events[:10] == expected_synced_events


def test_callbacks_non_sink_mismatch_size():
    """
    Feature: Callback
    Description: Test callbacks with dataset_sink_mode=False in train with mismatch size
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_non_sink_mismatch_size")
    default_timeout = ds.config.get_callback_timeout()
    ds.config.set_callback_timeout(1)

    # The timeout is global configuration: restore it even when an assertion
    # below fails, so the shortened value cannot leak into other tests.
    try:
        events = []
        my_cb1 = MyWaitedCallback(events, 2)
        my_cb2 = MyMSCallback(events)
        arr = [1, 2, 3, 4]
        data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb1)
        data = data.batch(3)
        net = Net()
        model = Model(net)
        with pytest.raises(Exception) as err:
            model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
        assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
    finally:
        ds.config.set_callback_timeout(default_timeout)


def test_callbacks_validations():
    """
    Feature: Callback
    Description: Test callbacks param in Map op with invalid argument
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_validations")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=0)
    assert "Argument callbacks with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        my_cb1 = MyDSCallback()
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
    assert "Argument callbacks[1] with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        class BadCB(DSCallback):
            pass

        my_cb = BadCB()

        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb)
        for _ in data:
            pass
    assert "Inheriting Callback class without overriding any methods" in str(err.value)


def test_callbacks_sink_simulation():
    """
    Feature: Callback
    Description: Test callbacks under sink simulation
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_sink_simulation")

    events = []
    epochs = 2
    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    data = data.device_que()
    data.send(num_epochs=epochs)
    for e in range(epochs):
        for s in range(4):
            time.sleep(0.5)
            events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
            my_cb.step_end(run_context=0)
        events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
        my_cb.epoch_end(run_context=0)
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events == expected_synced_events


def test_callbacks_repeat():
    """
    Feature: Callback
    Description: Test Map op with 1 callback with various num_epochs, num_steps, step_size, and repeat args combinations
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_repeat")

    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)

    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4)
    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8)
    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16)


def test_callbacks_exceptions():
    """
    Feature: Callback
    Description: Test Map op with a callback whose ds_begin raises RuntimeError
    Expectation: Exception is raised as expected
    """
    logger.info("test_callbacks_exceptions")

    class BadCB(DSCallback):
        def ds_begin(self, ds_run_context):
            raise RuntimeError("Bad begin")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=BadCB())
        for _ in data:
            pass
    assert "RuntimeError: Bad begin" in str(err.value)


def test_callbacks_train_end():
    """
    Feature: Callback
    Description: Test callback end op under sink simulation
    Expectation: Runs successfully
    """
    logger.info("test_callbacks_train_end")
    # No asserts are needed, just test there is no deadlock or exceptions
    events = []
    epochs = 2

    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=[my_cb])
    data = data.device_que()
    data.send(num_epochs=epochs)
    time.sleep(0.5)
    my_cb.end(run_context={})
    time.sleep(0.5)


def test_callbacks_one_cb():
    """
    Feature: Callback
    Description: Test callbacks with Begin, EpochBegin, EpochEnd, StepBegin, and StepEnd as the args
    Expectation: Output is equal to the expected output
    """
    logger.info("test_callbacks_one_cb")

    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    events1 = []
    events2 = []
    events3 = []
    my_begin = Begin(events=events1, cb_id=1)
    my_epoch_begin = EpochBegin(events=events2, cb_id=2)
    my_epoch_end = EpochEnd(events=events3, cb_id=3)
    my_step_begin = StepBegin(events=events3, cb_id=3)
    my_step_end = StepEnd(events=events2, cb_id=2)

    data = data.map(operations=(lambda x: x), callbacks=my_begin)
    data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end])
    data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin])

    itr = data.create_tuple_iterator(num_epochs=2)
    for _ in range(2):
        for _ in itr:
            pass
    expected_events1 = [('begin_0_0_0', [1])]
    expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
                        ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]),
                        ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]),
                        ('step_end_2_4_8', [2])]
    expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]),
                        ('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
                        ('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
                        ('epoch_end_2_4_8', [3])]
    events1.sort()
    events2.sort()
    events3.sort()
    expected_events1.sort()
    expected_events2.sort()
    expected_events3.sort()
    assert events1 == expected_events1
    assert events2 == expected_events2
    assert events3 == expected_events3


def test_clear_callback():
    """
    Feature: Callback
    Description: Test callback is removed for get_dataset_size and output_shape/type
    Expectation: Output is equal to the expected output
    """
    logger.info("test_clear_callback")

    # this test case will test that callback is removed for get_dataset_size and output_shape/type
    class FlagCallback(DSCallback):
        def __init__(self):
            super().__init__(step_size=1)
            self.flag = False
            self.row_cnt = 0

        def ds_begin(self, ds_run_context):
            # if callback isn't removed in getter pass, this function will be called
            self.flag = True

        def ds_step_begin(self, ds_run_context):
            self.row_cnt += 1

    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    cb = FlagCallback()
    # make sure variables are properly initialized before testing
    assert not cb.flag and cb.row_cnt == 0
    data = data.map(operations=(lambda x: x), callbacks=cb)
    assert data.get_dataset_size() == 4
    assert data.output_shapes() == [[]]
    # make sure callback is never called by checking flag and row_cnt
    assert not cb.flag and cb.row_cnt == 0
    for _ in data.create_dict_iterator(num_epochs=1):
        pass
    # this ensures that the callback is indeed called
    assert cb.flag and cb.row_cnt == 4


if __name__ == '__main__':
    test_callbacks_all_2cbs()
    test_callbacks_all_methods()
    test_callbacks_exceptions()
    test_callbacks_repeat()
    test_callbacks_sink_simulation()
    test_callbacks_validations()
    test_callbacks_var_step_size()
    test_callbacks_non_sink_batch_size2()
    test_callbacks_non_sink()
    test_callbacks_one_cb()
    test_callbacks_non_sink_mismatch_size()
    test_callbacks_train_end()
    test_clear_callback()