# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""
from __future__ import absolute_import

import os
import stat
import time

import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore import nn
from mindspore import _checkparam as Validator
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore.communication.management import get_rank
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
from mindspore.train.callback._callback import Callback, set_cur_net
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.generator import Generator
from mindspore.common.api import _cell_graph_executor
from mindspore._c_expression import _collect_host_info


_cur_dir = os.getcwd()
SAVE_DIR = _cur_dir
_info_list = ["epoch_num", "step_num"]


def _get_dp_tp_from_redundancy(redundancy_tuple):
    """Get the data-parallel (dp) and tensor-parallel (tp) groups from a redundancy tuple."""
    dp = []
    tp = []
    for dp_value in redundancy_tuple:
        dp.append(list(dp_value))
    for i in range(len(redundancy_tuple[0])):
        tp.append([v[i] for v in redundancy_tuple])
    return dp, tp


def _get_dp_tp_from_layout(parameter_redundancy_dict):
    """Get the dp and tp groups from a parameter layout dict, using its longest redundancy tuple."""
    tp = []
    dp = []
    value_len = 0
    for _, value in parameter_redundancy_dict.items():
        if len(value) > value_len:
            value_len = len(value)
            dp, tp = _get_dp_tp_from_redundancy(value)
    return dp, tp


def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
    """Check whether a file with the same prefix already exists and, if so, bump the prefix suffix."""
    if callable(prefix) or callable(directory):
        return prefix
    files = os.listdir(directory)
    suffix_num = 0
    pre_len = len(prefix)
    for filename in files:
        name_ext = os.path.splitext(filename)
        if exception and filename[-16:] != "_breakpoint.ckpt":
            continue
        if not exception and (name_ext[-1] != ".ckpt" or filename[-16:] == "_breakpoint.ckpt"):
            continue
        # find files with the same prefix
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # record the max existing suffix + 1
            index = filename[pre_len:].find("-")
            if index == 0:
                suffix_num = max(suffix_num, 1)
            elif index != -1:
                num = filename[pre_len + 1:pre_len + index]
                if num.isdigit():
                    suffix_num = max(suffix_num, int(num) + 1)

    if suffix_num != 0:
        prefix = f'{prefix}_{suffix_num}'

    return prefix
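

# Illustrative sketches of the helpers above (hypothetical inputs, for reference only):
#
# _get_dp_tp_from_redundancy transposes the redundancy groups. For a 2x2 device layout:
#     _get_dp_tp_from_redundancy(((0, 1), (2, 3)))
#     # -> dp = [[0, 1], [2, 3]], tp = [[0, 2], [1, 3]]
#
# _chg_ckpt_file_name_if_same_exist bumps the prefix when files from an earlier run
# are present. With "net-1_2.ckpt" and "net_2-1_2.ckpt" already in `directory`:
#     _chg_ckpt_file_name_if_same_exist(directory, "net")
#     # -> "net_3", so new files are written as "net_3-{epoch}_{step}.ckpt"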


class CheckpointConfig:
    """
    The configuration of model checkpoint.

    Note:
        - During the training process, if the dataset is transmitted through the data channel,
          it is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
          Otherwise, the time to save the checkpoint may be biased.
          It is recommended to set only one save strategy and one keep strategy at the same time.
          If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
          `save_checkpoint_seconds` will be invalid.
          If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
          `keep_checkpoint_per_n_minutes` will be invalid.
        - The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: ``1`` .
        save_checkpoint_seconds (int): Seconds to save checkpoint.
            Can't be used with save_checkpoint_steps at the same time. Default: ``0`` .
        keep_checkpoint_max (int): Maximum number of checkpoint files that can be saved. Default: ``5`` .
        keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
            Can't be used with keep_checkpoint_max at the same time. Default: ``0`` .
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
            The integrated save function is only supported in the automatic parallel scene, not supported
            in manual parallel. Default: ``True`` .
        async_save (bool): Whether to save the checkpoint to a file asynchronously. Default: ``False`` .
        saved_network (Cell): Network to be saved in the checkpoint file. If the saved_network has no relation
            with the network in training, the initial value of saved_network will be saved. Default: ``None`` .
        append_info (list): The information saved to the checkpoint file. Supports "epoch_num", "step_num" and
            dict. The key of the dict must be str, the value of the dict must be one of int, float, bool, str,
            Parameter, Tensor or Generator. Default: ``None`` .
        enc_key (Union[None, bytes]): Byte-type key used for encryption. If the value is None, encryption
            is not required. Default: ``None`` .
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
            mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: ``'AES-GCM'`` .
        exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` .
        crc_check (bool): Whether to perform a crc32 calculation when saving the checkpoint and append the
            result to the end of the ckpt file. Default: ``False`` .
        kwargs (dict): Configuration options dictionary.

    Raises:
        ValueError: If an input parameter is not of the correct type.

    Examples:
        >>> from mindspore import nn
        >>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> config = CheckpointConfig(save_checkpoint_seconds=100, keep_checkpoint_per_n_minutes=5, saved_network=net)
        >>> config.save_checkpoint_steps
        1
        >>> config.save_checkpoint_seconds
        >>> config.keep_checkpoint_max
        5
        >>> config.keep_checkpoint_per_n_minutes
        >>> config.integrated_save
        True
        >>> config.async_save
        False
        >>> config.saved_network
        >>> config.enc_key
        >>> config.enc_mode
        'AES-GCM'
        >>> config.append_dict
        >>> config.get_checkpoint_policy
        >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """

    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 async_save=False,
                 saved_network=None,
                 append_info=None,
                 enc_key=None,
                 enc_mode='AES-GCM',
                 exception_save=False,
                 crc_check=False,
                 **kwargs):

        if save_checkpoint_steps is not None:
            save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
        if save_checkpoint_seconds is not None:
            save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
        if keep_checkpoint_max is not None:
            keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes is not None:
            keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)

        if saved_network is not None and not isinstance(saved_network, nn.Cell):
            raise TypeError(f"For 'CheckpointConfig', the type of 'saved_network' must be None or Cell, "
                            f"but got {str(type(saved_network))}.")

        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("For 'CheckpointConfig', the input arguments 'save_checkpoint_steps', "
                             "'save_checkpoint_seconds', "
                             "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")
        Validator.check_bool(exception_save)
        self.exception_save = exception_save

        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        else:
            if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                self._keep_checkpoint_max = 1

        self._integrated_save = Validator.check_bool(integrated_save)
        self._async_save = Validator.check_bool(async_save)
        self._saved_network = saved_network
        self._append_dict = self._handle_append_info(append_info)
        self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
        self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
        self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
        self._map_param_inc = kwargs.get('incremental', False)
        self.enable_redundance = kwargs.get('enable_redundance', False)
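
    # A quick sketch of the precedence rules enforced above (hypothetical values):
    #     cfg = CheckpointConfig(save_checkpoint_steps=100, save_checkpoint_seconds=30)
    #     cfg.save_checkpoint_steps    # 100
    #     cfg.save_checkpoint_seconds  # None -- the steps strategy wins
    #     cfg = CheckpointConfig(keep_checkpoint_max=10, keep_checkpoint_per_n_minutes=5)
    #     cfg.keep_checkpoint_max            # 10
    #     cfg.keep_checkpoint_per_n_minutes  # None -- the max strategy wins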

    @property
    def save_checkpoint_steps(self):
        """
        Get the number of steps between checkpoints.

        Returns:
            int, steps to save checkpoint.
        """
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """
        Get the number of seconds between checkpoints.

        Returns:
            int, seconds to save the checkpoint file.
        """
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """
        Get the maximum number of checkpoint files that can be saved.

        Returns:
            int, maximum number of checkpoint files that can be saved.
        """
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """
        Get the interval, in minutes, at which checkpoint files are saved.

        Returns:
            int, save the checkpoint file every n minutes.
        """
        return self._keep_checkpoint_per_n_minutes

    @property
    def integrated_save(self):
        """
        Get whether to merge and save the split Tensor in the automatic parallel scenario.

        Returns:
            bool, whether to merge and save the split Tensor in the automatic parallel scenario.
        """
        return self._integrated_save

    @property
    def async_save(self):
        """
        Get whether the checkpoint is saved to a file asynchronously.

        Returns:
            bool, whether the checkpoint is saved to a file asynchronously.
        """
        return self._async_save

    @property
    def saved_network(self):
        """
        Get the network to be saved in the checkpoint file.

        Returns:
            Cell, network to be saved in the checkpoint file.
        """
        return self._saved_network

    @property
    def enc_key(self):
        """
        Get the byte-type key used for encryption.

        Returns:
            (None, bytes), byte-type key used for encryption.
        """
        return self._enc_key

    @property
    def enc_mode(self):
        """
        Get the encryption mode.

        Returns:
            str, encryption mode.
        """
        return self._enc_mode

    @property
    def crc_check(self):
        """
        Get whether crc check is enabled.

        Returns:
            bool, whether crc check is enabled.
        """
        return self._crc_check

    @property
    def append_dict(self):
        """
        Get the information dict saved to the checkpoint file.

        Returns:
            dict, the information saved to the checkpoint file.
        """
        return self._append_dict

    @property
    def map_param_inc(self):
        """
        Get whether to save the map Parameter incrementally.

        Returns:
            bool, whether to save the map Parameter incrementally.
        """
        return self._map_param_inc

    def get_checkpoint_policy(self):
        """
        Get the policy of checkpoint.

        Returns:
            dict, the information of the checkpoint policy.
        """
        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
                             'keep_checkpoint_max': self.keep_checkpoint_max,
                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
                             'saved_network': self.saved_network}

        return checkpoint_policy

    @staticmethod
    def _handle_append_info(append_info):
        """Handle ckpt append info."""
        if append_info is None or append_info == []:
            return None
        if not isinstance(append_info, list):
            raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' must be list, "
                            f"but got {str(type(append_info))}.")
        handle_append_info = {}
        if "epoch_num" in append_info:
            handle_append_info["epoch_num"] = 0
        if "step_num" in append_info:
            handle_append_info["step_num"] = 0
        if "random_op" in append_info:
            handle_append_info["random_op"] = 0
        dict_num = 0
        for element in append_info:
            if not isinstance(element, str) and not isinstance(element, dict):
                raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' element must be str or dict, "
                                f"but got {str(type(element))}.")
            if isinstance(element, str) and element not in _info_list:
                raise ValueError(f"For 'CheckpointConfig', the value of element in the argument 'append_info' "
                                 f"must be in {_info_list}, "
                                 f"but got {element}.")
            if isinstance(element, dict):
                dict_num += 1
                if dict_num > 1:
                    raise TypeError(f"For 'CheckpointConfig', 'append_info' can contain only one dict, "
                                    f"but got {dict_num}.")
                for key, value in element.items():
                    if isinstance(key, str) and isinstance(value,
                                                           (int, float, bool, str, Parameter, Tensor, Generator)):
                        handle_append_info[key] = value
                    else:
                        raise TypeError(f"For 'CheckpointConfig', the key type of the dict 'append_info' "
                                        f"must be string, and the value type must be one of int, float, bool, "
                                        f"str, Parameter, Tensor or Generator, "
                                        f"but got key type {type(key)}, value type {type(value)}.")

        return handle_append_info
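

# A small sketch of how `append_info` is normalized (hypothetical values):
#     CheckpointConfig(append_info=["epoch_num", {"lr": 0.01}]).append_dict
#     # -> {"epoch_num": 0, "lr": 0.01}; the "epoch_num"/"step_num" entries start
#     # at 0 and are refreshed with the real counters each time a checkpoint is saved.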


class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is used in combination with the training process to save the model and network parameters
    during and after training.

    Note:
        In the distributed training scenario, please specify a different directory for each training process
        to save the checkpoint file. Otherwise, the training may fail.
        If this callback is used in the `model` function, the checkpoint file will save
        the parameters of the optimizer by default.

    Args:
        prefix (Union[str, callable object]): The prefix name or callable object used to generate names of
            checkpoint files. Default: ``'CKP'`` .
        directory (Union[str, callable object]): The folder path where the checkpoint is stored, or the callable
            object used to generate the path. By default, the file is saved in the current directory.
            Default: ``None`` .
        config (CheckpointConfig): Checkpoint strategy configuration. Default: ``None`` .

    Raises:
        ValueError: If `prefix` is not a callable object and is not a str or contains the '/' character.
        ValueError: If `directory` is not str and is not a callable object.
        TypeError: If the config is not CheckpointConfig type.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn
        >>> from mindspore.train import Model, ModelCheckpoint
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> net = nn.Dense(10, 5)
        >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> ckpt_callback = ModelCheckpoint(prefix="myckpt")
        >>> model = Model(network=net, optimizer=opt, loss_fn=crit)
        >>> model.train(2, train_dataset, callbacks=[ckpt_callback])
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0
        # a callable for users to set a self-defined prefix.
        self._prefix_func = None
        # a callable for users to set a self-defined directory.
        self._directory_func = None

        if not callable(prefix) and (not isinstance(prefix, str) or prefix.find('/') >= 0):
            raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
                             "for checkpoint file name is invalid, it must be "
                             "callable or a string that does not contain '/', but got {}.".format(prefix))
        self._prefix = prefix
        self._exception_prefix = prefix

        if directory is not None:
            if callable(directory):
                self._directory_func = directory
            else:
                self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir

        if callable(prefix):
            self._prefix_func = prefix

        if _get_recovery_context("enable_recovery"):
            _set_recovery_context(ckpt_path=self._directory)

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
                                "'CheckpointConfig', "
                                "but got {}.".format(type(config)))
            self._config = config

        self._aiturbo_init_flag = os.getenv("AITURBO") == "1"
        # get existing checkpoint files
        if self._aiturbo_init_flag:
            import aiturbo
            self._manager = aiturbo.CheckpointShmManager()
        else:
            self._manager = CheckpointManager()
        if not callable(directory) and not callable(prefix):
            self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._append_dict = self._config.append_dict or {}
        self._append_epoch_num = self._append_dict.get("epoch_num", 0)
        self._append_step_num = self._append_dict.get("step_num", 0)
        self._graph_saved = False
        self._need_flush_from_cache = True
        self._map_param_inc = self._config.map_param_inc
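
    # The prefix/directory callables receive `cb_params`, so names can depend on training
    # state. A minimal sketch (hypothetical callables):
    #     ckpt_cb = ModelCheckpoint(
    #         prefix=lambda cb_params: "ckpt_epoch{}".format(cb_params.cur_epoch_num),
    #         directory=lambda cb_params: "./outputs/epoch_{}".format(cb_params.cur_epoch_num))
    # The prefix callable must return a string without '/', otherwise step_end raises ValueError.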

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if self._aiturbo_init_flag:
            import aiturbo
            ckpt_storage_path = self._directory
            rank_id = get_rank()
            stage_num = _get_auto_parallel_context("pipeline_stages")
            stage_rank_num = _get_device_num() // stage_num
            param_layout = cb_params.train_network.parameter_layout_dict
            if not param_layout:
                layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num, "stage_layout": None}
                aiturbo.init(ckpt_storage_path, rank_id, layout, None, False, None)
            else:
                device_num = _get_device_num()
                chunk_size = device_num // stage_num
                initial_rank = (rank_id // chunk_size) * chunk_size
                param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
                dp, _ = _get_dp_tp_from_layout(param_redundancy_dict)
                layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num,
                          "stage_layout": param_redundancy_dict}
                single_params = remove_param_redundancy(param_redundancy_dict)
                single_params = {device_id: list(params) for device_id, params in single_params.items()}
                aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, self._config.enable_redundance, dp)
            self._aiturbo_init_flag = False
        if self._prefix_func:
            self._prefix = self._prefix_func(cb_params)
            if not isinstance(self._prefix, str) or self._prefix.find('/') >= 0:
                raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
                                 "for checkpoint file name is callable, it must return a "
                                 "string that does not contain '/', but got {}.".format(self._prefix))
        if self._directory_func:
            self._directory = self._directory_func(cb_params)
        _collect_host_info("Callback", "ModelCheckpoint", "step_end", level=1)
        # In the disaster recovery scenario, the training process may be rolled back to the last step where
        # the ckpt was saved successfully, so the _last_triggered_step should be updated.
        if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
            self._last_triggered_step = cb_params.last_save_ckpt_step
            cb_params.last_save_ckpt_step = None

        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # wait for any in-flight asynchronous save before starting a new one
        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        _collect_host_info("Callback", "ModelCheckpoint", "end", level=1)
        _to_save_last_ckpt = True

        self._save_ckpt(cb_params, _to_save_last_ckpt)

        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()

        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether checkpoint files should be saved."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
                self._last_time = self._cur_time
                return True

        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        # if param is cache enabled, flush data from cache to host before saving the ckpt
        if self._need_flush_from_cache:
            self._flush_from_cache(cb_params)

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

        if save_ckpt:
            if self._prefix_func:
                cur_ckpoint_file = self._prefix + ".ckpt"
            else:
                cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                    + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep the number of checkpoint files within the max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global SAVE_DIR
            SAVE_DIR = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            # TODO(MS_DISABLE_REF_MODE): Delete when the MS_DISABLE_REF_MODE env is removed.
            if context.get_context("enable_ge") and os.getenv('MS_DISABLE_REF_MODE') \
                    and context.get_context("mode") == context.GRAPH_MODE:
                set_cur_net(cb_params.train_network)
                cb_params.train_network.add_flags(ge_sync_data=True)
                _cell_graph_executor(cb_params.train_network, phase='save')
            if "epoch_num" in self._append_dict:
                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
            if "step_num" in self._append_dict:
                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
            if os.getenv("AITURBO") == "1":
                save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                self._append_dict, self._config.enc_key, self._config.enc_mode,
                                crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                global_step_num=cb_params.cur_step_num)
            else:
                save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                self._append_dict, self._config.enc_key, self._config.enc_mode,
                                crc_check=self._config.crc_check, incremental=self._map_param_inc)

            self._latest_ckpt_file_name = cur_file
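
    # Naming sketch: with prefix "LeNet5", the checkpoint saved at epoch 3, step 14
    # within that epoch is written as "LeNet5-3_14.ckpt"; a callable prefix skips the
    # "-{epoch}_{step}" part and saves "{prefix}.ckpt" instead.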

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if the tensor is cache enabled."""
        has_cache_params = False
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                has_cache_params = True
                Tensor(param).flush_from_cache()
        if not has_cache_params:
            self._need_flush_from_cache = False

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name

    @property
    def _get_save_checkpoint_steps(self):
        """Return the save ckpt steps."""
        return self._config.save_checkpoint_steps

    @property
    def _get_last_trigger_step(self):
        """Return the last triggered step."""
        return self._last_triggered_step


class CheckpointManager:
    """Manage checkpoint files according to train_config of checkpoint."""

    def __init__(self):
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
        """Get all the related checkpoint files managed here."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
        """Get the number of the related checkpoint files managed here."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        """Update the checkpoint file list."""
        self._ckpoint_filelist = []
        files = os.listdir(directory)
        for filename in files:
            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
                mid_name = filename[len(prefix):-5]
                flag = not any(char.isalpha() for char in mid_name)
                if flag:
                    self._ckpoint_filelist.append(os.path.join(directory, filename))
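
    # Matching sketch: for prefix "CKP", "CKP-3_14.ckpt" is tracked (mid_name "-3_14"
    # has no letters), while "CKP-graph.meta" fails the ".ckpt" check and
    # "CKP-best_acc.ckpt" is skipped because its mid_name contains letters; this keeps
    # unrelated files out of the keep policy.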

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
            self._ckpoint_filelist.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
        self.remove_ckpoint_file(ckpoint_files[0])

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
        """Keep at most one ckpt file per `minutes` interval: among the files generated in
        [cur_time - minutes, cur_time], keep only the oldest one and remove the rest."""
        del_list = []
        oldest_file = ''
        oldest_time = cur_time
        for ck_file in self._ckpoint_filelist:
            modify_time = os.path.getmtime(ck_file)
            if cur_time - modify_time < 60 * minutes:
                del_list.append(ck_file)

                if modify_time < oldest_time:
                    oldest_time = modify_time
                    oldest_file = ck_file

        for mv_file in del_list:
            if mv_file == oldest_file:
                continue
            self.remove_ckpoint_file(mv_file)