# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset pipeline callbacks (``DSCallback`` and ``WaitedDSCallback``)."""
from builtins import range, super
import time

import pytest

from mindspore import context
from mindspore import log as logger
from mindspore.dataset.callback import DSCallback, WaitedDSCallback
from mindspore.train import Model
from mindspore.train.callback import Callback

import mindspore.dataset as ds
import mindspore.nn as nn

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class BaseCallback(DSCallback):
    """Dataset callback that records each hook invocation into an event list.

    Every recorded entry is a ``(event_string, cb_ids)`` tuple. ``event_string`` is
    ``"<name>_<cur_epoch_num>_<cur_step_num_in_epoch>_<cur_step_num>"``. ``cb_ids``
    lists the ``cb_id`` of every callback that reported that same event string, in
    the order they reported it.

    Args:
        step_size (int): Forwarded to ``DSCallback``.
        events (list, optional): List that entries are appended to; it may be shared
            between several callbacks. A fresh private list is used when omitted.
        cb_id (int): Identifier recorded alongside each event.
    """

    def __init__(self, step_size=1, events=None, cb_id=0):
        super().__init__(step_size)
        # Keeping None here would make the first hook crash inside `append`.
        self.events = [] if events is None else events
        self.cb_id = cb_id

    def append(self, event_name, ds_run_context):
        """Record `event_name` at the position described by `ds_run_context`."""
        fields = [event_name, ds_run_context.cur_epoch_num,
                  ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
        event = '_'.join(str(field) for field in fields)
        # Merge into the first matching entry, so callbacks sharing one event list
        # produce a single entry per event that carries every reporting cb_id.
        for recorded_event, cb_ids in self.events:
            if recorded_event == event:
                cb_ids.append(self.cb_id)
                return
        self.events.append((event, [self.cb_id]))


class Begin(BaseCallback):
    """Records only the ``ds_begin`` hook."""

    def ds_begin(self, ds_run_context):
        self.append("begin", ds_run_context)


class EpochBegin(BaseCallback):
    """Records only the ``ds_epoch_begin`` hook."""

    def ds_epoch_begin(self, ds_run_context):
        self.append("epoch_begin", ds_run_context)


class EpochEnd(BaseCallback):
    """Records only the ``ds_epoch_end`` hook."""

    def ds_epoch_end(self, ds_run_context):
        self.append("epoch_end", ds_run_context)


class StepBegin(BaseCallback):
    """Records only the ``ds_step_begin`` hook."""

    def ds_step_begin(self, ds_run_context):
        self.append("step_begin", ds_run_context)


class StepEnd(BaseCallback):
    """Records only the ``ds_step_end`` hook."""

    def ds_step_end(self, ds_run_context):
        self.append("step_end", ds_run_context)


class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
    """Records all five hooks by combining the single-hook recorders above."""


def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    """Build the event list a ``MyDSCallback`` is expected to record.

    Args:
        epoch_num (int): Number of epochs iterated.
        step_num (int): Rows produced by the source dataset per repeat.
        step_size (int): Step hooks are expected only on 0-based steps ``s``
            with ``s % step_size == 0``.
        map_num (int): Number of callbacks sharing the list; each entry lists
            cb_ids ``0 .. map_num - 1``.
        repeat (int): Repeat count applied to the dataset.

    Returns:
        list: ``(event_string, cb_ids)`` tuples in the format used by
        ``BaseCallback.append``.
    """
    events = []
    cb_id = list(range(map_num))
    steps_per_epoch = step_num * repeat

    def append(name, e, s):
        # e and s are 0-based; the recorded values are 1-based, plus a global step.
        fields = [name, e + 1, s + 1, e * steps_per_epoch + s + 1]
        events.append(('_'.join(str(field) for field in fields), cb_id))

    events.append(("begin_0_0_0", cb_id))
    for e in range(epoch_num):
        # s == -1 gives step-in-epoch 0 and the global step reached so far.
        append("epoch_begin", e, -1)
        for s in range(steps_per_epoch):
            if s % step_size == 0:
                append("step_begin", e, s)
                append("step_end", e, s)
        append("epoch_end", e, steps_per_epoch - 1)
    return events


def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
    """Run a one-map pipeline with a single ``MyDSCallback`` and check its events."""
    events = []

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    my_cb = MyDSCallback(step_size=step_size, events=events)

    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    if repeat != 1:
        if repeat % 2 == 0 and repeat != 2:
            # Split an even repeat around an extra map op, so the callback
            # sits below a repeat/map/repeat chain.
            data = data.repeat(2)
            data = data.map(operations=(lambda x: x))
            data = data.repeat(repeat // 2)
        else:
            data = data.repeat(repeat)
    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
    assert expected_events == events


def build_test_case_2cbs(epochs, steps):
    """Attach two callbacks with separate event lists to one map op and check both."""
    events1 = []
    events2 = []
    my_cb1 = MyDSCallback(events=events1)
    my_cb2 = MyDSCallback(events=events2)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps)
    assert expected_events == events1
    assert expected_events == events2


def build_test_case_2maps(epochs, steps):
    """Attach callbacks with cb_id 0 and 1 to two chained map ops sharing one list.

    NOTE(review): no test in this module currently calls this helper.
    """
    events = []
    my_cb1 = MyDSCallback(events=events, cb_id=0)
    my_cb2 = MyDSCallback(events=events, cb_id=1)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.map(operations=(lambda x: x), callbacks=my_cb2)

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, map_num=2)

    # The leading "begin" entry is excluded from the comparison and from the
    # cb_id ordering check below.
    assert expected_events[1:] == events[1:]

    for entry in events:
        assert len(entry) == 2
        name, cb_ids = entry
        if name != "begin_0_0_0":
            assert cb_ids[0] == 0
            assert cb_ids[1] == 1


def test_callbacks_all_methods():
    """Every DS hook fires with the expected counters for 1-3 epochs of 1-4 steps."""
    logger.info("test_callbacks_all_methods")

    build_test_case_1cb(1, 1)
    build_test_case_1cb(1, 2)
    build_test_case_1cb(1, 3)
    build_test_case_1cb(1, 4)

    build_test_case_1cb(2, 1)
    build_test_case_1cb(2, 2)
    build_test_case_1cb(2, 3)
    build_test_case_1cb(2, 4)

    build_test_case_1cb(3, 1)
    build_test_case_1cb(3, 2)
    build_test_case_1cb(3, 3)
    build_test_case_1cb(3, 4)


def test_callbacks_var_step_size():
    """Step hooks honour a step_size of 2."""
    logger.info("test_callbacks_var_step_size")

    build_test_case_1cb(1, 2, 2)
    build_test_case_1cb(1, 3, 2)
    build_test_case_1cb(1, 4, 2)

    build_test_case_1cb(2, 2, 2)
    build_test_case_1cb(2, 3, 2)
    build_test_case_1cb(2, 4, 2)

    build_test_case_1cb(3, 2, 2)
    build_test_case_1cb(3, 3, 2)
    build_test_case_1cb(3, 4, 2)


def test_callbacks_all_2cbs():
    """Two callbacks on the same map op each see the full event sequence."""
    logger.info("test_callbacks_all_2cbs")

    build_test_case_2cbs(4, 1)
    build_test_case_2cbs(4, 2)
    build_test_case_2cbs(4, 3)
    build_test_case_2cbs(4, 4)


class MyWaitedCallback(WaitedDSCallback):
    """Waited callback that logs its synchronized epoch/step-begin hooks as strings."""

    def __init__(self, events, step_size=1):
        super().__init__(step_size)
        self.events = events

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)

    def sync_step_begin(self, train_run_context, ds_run_context):
        event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)


class MyMSCallback(Callback):
    """Training callback that logs epoch/step-end events into the same list."""

    def __init__(self, events):
        self.events = events

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)


class Net(nn.Cell):
    """Identity network: returns its first input unchanged."""

    def construct(self, x, y):
        return x


def test_callbacks_non_sink():
    """Waited DS hooks interleave with training step/epoch ends in non-sink mode."""
    logger.info("test_callbacks_non_sink")

    events = []
    my_cb1 = MyWaitedCallback(events, 1)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)

    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events[:18] == expected_synced_events


def test_callbacks_non_sink_batch_size2():
    """Waited DS hooks with step_size 2 line up with batches of 2 in non-sink mode."""
    logger.info("test_callbacks_non_sink_batch_size2")

    events = []
    my_cb1 = MyWaitedCallback(events, 2)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.batch(2)
    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
                              'ms_step_end_1_2',
                              'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
                              'ms_step_end_2_4', 'ms_epoch_end_2_4']

    assert events[:10] == expected_synced_events


def test_callbacks_non_sink_mismatch_size():
    """With step_size 2 and batch size 3, the waited step-begin hook times out."""
    logger.info("test_callbacks_non_sink_mismatch_size")
    default_timeout = ds.config.get_callback_timeout()
    ds.config.set_callback_timeout(1)
    # Restore the global timeout even when an assertion fails or no exception is
    # raised, so the 1-second value cannot leak into tests that run afterwards.
    try:
        events = []
        my_cb1 = MyWaitedCallback(events, 2)
        my_cb2 = MyMSCallback(events)
        arr = [1, 2, 3, 4]
        data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb1)
        data = data.batch(3)
        net = Net()
        model = Model(net)
        with pytest.raises(Exception) as err:
            model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
        assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
    finally:
        ds.config.set_callback_timeout(default_timeout)


def test_callbacks_validations():
    """Invalid callback arguments and a callback overriding no hook are rejected."""
    logger.info("test_callbacks_validations")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=0)
    assert "Argument callbacks with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        my_cb1 = MyDSCallback()
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
    assert "Argument callbacks[1] with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        class BadCB(DSCallback):
            pass

        my_cb = BadCB()

        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb)
        for _ in data:
            pass
    assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)


def test_callbacks_sink_simulation():
    """Drive a waited callback by hand (as sink mode would) and check the interleaving."""
    logger.info("test_callbacks_sink_simulation")

    events = []
    epochs = 2
    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    data = data.to_device()
    data.send(num_epochs=epochs)
    for e in range(epochs):
        for s in range(4):
            time.sleep(0.5)
            events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
            my_cb.step_end(run_context=0)
        events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
        my_cb.epoch_end(run_context=0)
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events == expected_synced_events


def test_callbacks_repeat():
    """Callback counters stay correct when the dataset is repeated."""
    logger.info("test_callbacks_repeat")

    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)

    # Even repeats: repeat=2 takes the plain repeat() branch of build_test_case_1cb,
    # while larger even values exercise the split repeat/map/repeat pipeline.
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4)
    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8)
    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16)


def test_callbacks_exceptions():
    """An exception raised inside a DS hook surfaces to the iterating caller."""
    logger.info("test_callbacks_exceptions")

    class BadCB(DSCallback):
        def ds_begin(self, ds_run_context):
            raise RuntimeError("Bad begin")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=BadCB())
        for _ in data:
            pass
    assert "RuntimeError: Bad begin" in str(err.value)


def test_callbacks_train_end():
    """Calling ``end`` on a waited callback mid-send does not deadlock or raise."""
    logger.info("test_callbacks_train_end")
    # No asserts are needed, just test there is no deadlock or exceptions
    events = []
    epochs = 2

    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=[my_cb])
    data = data.to_device()
    data.send(num_epochs=epochs)
    time.sleep(0.5)
    my_cb.end(run_context={})
    time.sleep(0.5)


def test_callbacks_one_cb():
    """Single-hook callbacks spread over three map ops record only their own hooks."""
    logger.info("test_callbacks_one_cb")

    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    events1 = []
    events2 = []
    events3 = []
    my_begin = Begin(events=events1, cb_id=1)
    my_epoch_begin = EpochBegin(events=events2, cb_id=2)
    my_epoch_end = EpochEnd(events=events3, cb_id=3)
    my_step_begin = StepBegin(events=events3, cb_id=3)
    my_step_end = StepEnd(events=events2, cb_id=2)

    data = data.map(operations=(lambda x: x), callbacks=my_begin)
    data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end])
    data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin])

    itr = data.create_tuple_iterator(num_epochs=2)
    for _ in range(2):
        for _ in itr:
            pass
    expected_events1 = [('begin_0_0_0', [1])]
    expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
                        ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]),
                        ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]),
                        ('step_end_2_4_8', [2])]
    expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]),
                        ('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
                        ('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
                        ('epoch_end_2_4_8', [3])]
    assert events1 == expected_events1
    assert events2 == expected_events2
    assert events3 == expected_events3


def test_clear_callback():
    """Getter passes (dataset size, output shapes) must not trigger callbacks."""
    logger.info("test_clear_callback")

    # this test case will test that callback is removed for get_dataset_size and output_shape/type
    class FlagCallback(DSCallback):
        def __init__(self):
            super().__init__(step_size=1)
            self.flag = False
            self.row_cnt = 0

        def ds_begin(self, ds_run_context):
            # if callback isn't removed in getter pass, this function will be called
            self.flag = True

        def ds_step_begin(self, ds_run_context):
            self.row_cnt += 1

    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    cb = FlagCallback()
    # make sure variables are properly initialized before testing
    assert not cb.flag and cb.row_cnt == 0
    data = data.map(operations=(lambda x: x), callbacks=cb)
    assert data.get_dataset_size() == 4
    assert data.output_shapes() == [[]]
    # make sure callback is never called by checking flag and row_cnt
    assert not cb.flag and cb.row_cnt == 0
    for _ in data.create_dict_iterator(num_epochs=1):
        pass
    # this ensures that the callback is indeed called
    assert cb.flag and cb.row_cnt == 4


if __name__ == '__main__':
    test_callbacks_all_2cbs()
    test_callbacks_all_methods()
    test_callbacks_exceptions()
    test_callbacks_repeat()
    test_callbacks_sink_simulation()
    test_callbacks_validations()
    test_callbacks_var_step_size()
    test_callbacks_non_sink_batch_size2()
    test_callbacks_non_sink()
    test_callbacks_one_cb()
    test_callbacks_non_sink_mismatch_size()
    test_callbacks_train_end()
    test_clear_callback()