# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""

import os
import stat
import time

import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore import nn
from mindspore._checkparam import Validator
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor

_cur_dir = os.getcwd()
_save_dir = _cur_dir
# The only string keys accepted in `append_info` (besides user dicts).
_info_list = ["epoch_num", "step_num"]


def _chg_ckpt_file_name_if_same_exist(directory, prefix):
    """Return a prefix that does not collide with existing ckpt files.

    Scans `directory` for ".ckpt" files that already start with `prefix`
    followed by a non-alphabetic character (i.e. "prefix-..." or
    "prefix_N-..."), and if any are found, returns "prefix_<max suffix + 1>"
    so newly generated files never overwrite old ones.

    Args:
        directory (str): Directory to scan for existing checkpoint files.
        prefix (str): Desired checkpoint file name prefix.

    Returns:
        str: The original prefix, or a "prefix_N" variant if a clash exists.
    """
    files = os.listdir(directory)
    suffix_num = 0
    pre_len = len(prefix)
    for filename in files:
        name_ext = os.path.splitext(filename)
        if name_ext[-1] != ".ckpt":
            continue
        # Only consider files whose name begins with the prefix and whose next
        # character is non-alphabetic ("-" or "_"), so "prefixfoo-1_2.ckpt"
        # does not count as a clash for "prefix".
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # Track max existing suffix; the new prefix uses max + 1.
            index = filename[pre_len:].find("-")
            if index == 0:
                # "prefix-..." form: plain prefix already used.
                suffix_num = max(suffix_num, 1)
            elif index != -1:
                # "prefix_N-..." form: skip the "_" separator and parse N.
                num = filename[pre_len + 1:pre_len + index]
                if num.isdigit():
                    suffix_num = max(suffix_num, int(num) + 1)

    if suffix_num != 0:
        prefix = prefix + "_" + str(suffix_num)

    return prefix


class CheckpointConfig:
    """
    The configuration of model checkpoint.

    Note:
        During the training process, if dataset is transmitted through the data channel,
        It is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
        Otherwise, the time to save the checkpoint may be biased.
        It is recommended to set only one save strategy and one keep strategy at the same time.
        If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
        `save_checkpoint_seconds` will be invalid.
        If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
        `keep_checkpoint_per_n_minutes` will be invalid.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
        save_checkpoint_seconds (int): Seconds to save checkpoint.
            Can't be used with save_checkpoint_steps at the same time. Default: 0.
        keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: 5.
        keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
            Can't be used with keep_checkpoint_max at the same time. Default: 0.
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
            Integrated save function is only supported in automatic parallel scene, not supported
            in manual parallel. Default: True.
        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
        saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
            with the network in training, the initial value of saved_network will be saved. Default: None.
        append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and dict.
            The key of dict must be str, the value of dict must be one of int float and bool. Default: None.
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
            is not required. Default: None.
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
            mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        ValueError: If input parameter is not the correct type.

    Examples:
        >>> from mindspore import Model, nn
        >>> from mindspore.common.initializer import Normal
        >>> from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
        >>>
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, num_channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        ...         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        ...         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        ...         self.relu = nn.ReLU()
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
        ...         x = self.flatten(x)
        ...         x = self.relu(self.fc1(x))
        ...         x = self.relu(self.fc2(x))
        ...         x = self.fc3(x)
        ...         return x
        >>>
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> data_path = './MNIST_Data'
        >>> dataset = create_dataset(data_path)
        >>> config = CheckpointConfig(saved_network=net)
        >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """

    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 async_save=False,
                 saved_network=None,
                 append_info=None,
                 enc_key=None,
                 enc_mode='AES-GCM'):

        if save_checkpoint_steps is not None:
            save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
        if save_checkpoint_seconds is not None:
            save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
        if keep_checkpoint_max is not None:
            keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes is not None:
            keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)

        if saved_network is not None and not isinstance(saved_network, nn.Cell):
            raise TypeError(f"For 'CheckpointConfig', the type of 'saved_network' must be None or Cell, "
                            f"but got {str(type(saved_network))}.")

        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("The input arguments 'save_checkpoint_steps', 'save_checkpoint_seconds', "
                             "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")

        # Step-based saving takes precedence over time-based saving.
        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        # Count-based keeping takes precedence over time-based keeping; if
        # neither is set, fall back to keeping a single checkpoint.
        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        else:
            if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                self._keep_checkpoint_max = 1

        self._integrated_save = Validator.check_bool(integrated_save)
        self._async_save = Validator.check_bool(async_save)
        self._saved_network = saved_network
        self._append_dict = self._handle_append_info(append_info)
        self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
        self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)

    @property
    def save_checkpoint_steps(self):
        """Get the value of _save_checkpoint_steps."""
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """Get the value of _save_checkpoint_seconds."""
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """Get the value of _keep_checkpoint_max."""
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """Get the value of _keep_checkpoint_per_n_minutes."""
        return self._keep_checkpoint_per_n_minutes

    @property
    def integrated_save(self):
        """Get the value of _integrated_save."""
        return self._integrated_save

    @property
    def async_save(self):
        """Get the value of _async_save."""
        return self._async_save

    @property
    def saved_network(self):
        """Get the value of _saved_network."""
        return self._saved_network

    @property
    def enc_key(self):
        """Get the value of _enc_key."""
        return self._enc_key

    @property
    def enc_mode(self):
        """Get the value of _enc_mode."""
        return self._enc_mode

    @property
    def append_dict(self):
        """Get the value of append_dict."""
        return self._append_dict

    def get_checkpoint_policy(self):
        """Return the save/keep policy of this configuration as a dict."""
        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
                             'keep_checkpoint_max': self.keep_checkpoint_max,
                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
                             'saved_network': self.saved_network}

        return checkpoint_policy

    @staticmethod
    def _handle_append_info(append_info):
        """Validate `append_info` and normalize it into a flat dict (or None).

        "epoch_num"/"step_num" entries become keys initialized to 0; at most
        one user dict is merged in, with str keys and int/float/bool values.

        Raises:
            TypeError: If `append_info` (or an element/dict entry) has a wrong type,
                or more than one dict element is present.
            ValueError: If a string element is not in `_info_list`.
        """
        if append_info is None or append_info == []:
            return None
        if not isinstance(append_info, list):
            raise TypeError(f"The type of 'append_info' must be list, but got {str(type(append_info))}.")
        handle_append_info = {}
        if "epoch_num" in append_info:
            handle_append_info["epoch_num"] = 0
        if "step_num" in append_info:
            handle_append_info["step_num"] = 0
        dict_num = 0
        for element in append_info:
            if not isinstance(element, str) and not isinstance(element, dict):
                raise TypeError(f"The type of 'append_info' element must be str or dict, "
                                f"but got {str(type(element))}.")
            if isinstance(element, str) and element not in _info_list:
                raise ValueError(f"The value of element in the argument 'append_info' must be in {_info_list}, "
                                 f"but got {element}.")
            if isinstance(element, dict):
                dict_num += 1
                if dict_num > 1:
                    raise TypeError("The element of 'append_info' must have only one dict.")
                for key, value in element.items():
                    if isinstance(key, str) and isinstance(value, (int, float, bool)):
                        handle_append_info[key] = value
                    else:
                        raise TypeError(f"The type of dict in 'append_info' must be key: string, value: int or float, "
                                        f"but got key: {type(key)}, value: {type(value)}")

        return handle_append_info


class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network parameters after training.

    Note:
        In the distributed training scenario, please specify different directories for each training process
        to save the checkpoint file. Otherwise, the training may fail.

    Args:
        prefix (str): The prefix name of checkpoint files. Default: "CKP".
        directory (str): The path of the folder which will be saved in the checkpoint file.
            By default, the file is saved in the current directory. Default: None.
        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        if not isinstance(prefix, str) or prefix.find('/') >= 0:
            raise ValueError("The argument 'prefix' for checkpoint file name is invalid, 'prefix' must be "
                             "string and does not contain '/', but got {}.".format(prefix))
        self._prefix = prefix

        if directory is not None:
            self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("The argument 'config' should be 'CheckpointConfig' type, "
                                "but got {}.".format(type(config)))
            self._config = config

        # get existing checkpoint files
        self._manager = CheckpointManager()
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._append_dict = self._config.append_dict or {}
        self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
        self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
        self._graph_saved = False
        self._need_flush_from_cache = True
        # Whether the parameter-server rank tag has already been added to the
        # prefix (it must be added at most once; see step_end).
        self._ps_prefix_appended = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        # BUGFIX: the rank tag used to be prepended on *every* step, so the
        # prefix grew as "PServer_0_PServer_0_..."; tag it only once.
        if _is_role_pserver() and not self._ps_prefix_appended:
            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
            self._ps_prefix_appended = True
        cb_params = run_context.original_args()
        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # Wait for a possibly still-running asynchronous save before starting
        # the next one.
        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        _to_save_last_ckpt = True

        self._save_ckpt(cb_params, _to_save_last_ckpt)

        # Make sure the final asynchronous save (if any) has completed.
        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()

        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Return True if a checkpoint should be saved now (step- or time-based)."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True

        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files, applying the configured keep policy first."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        # if param is cache enable, flush data from cache to host before save_ckpt
        if self._need_flush_from_cache:
            self._flush_from_cache(cb_params)

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()
            if "epoch_num" in self._append_dict:
                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
            if "step_num" in self._append_dict:
                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
            save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                            self._append_dict, self._config.enc_key, self._config.enc_mode)

            self._latest_ckpt_file_name = cur_file

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if tensor is cache enable."""
        has_cache_params = False
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                has_cache_params = True
                Tensor(param).flush_from_cache()
        if not has_cache_params:
            # No cached parameters exist; skip the scan on subsequent saves.
            self._need_flush_from_cache = False

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name


class CheckpointManager:
    """Manage checkpoint files according to train_config of checkpoint."""

    def __init__(self):
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
        """Get all the related checkpoint files managed here."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
        """Get the number of the related checkpoint files managed here."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        """Update the checkpoint file list."""
        self._ckpoint_filelist = []
        files = os.listdir(directory)
        for filename in files:
            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
                # Strip prefix and ".ckpt"; a middle part containing letters
                # means the file belongs to a different (longer) prefix.
                mid_name = filename[len(prefix):-5]
                flag = not (True in [char.isalpha() for char in mid_name])
                if flag:
                    self._ckpoint_filelist.append(os.path.join(directory, filename))

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
            self._ckpoint_filelist.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
        # Guard against an empty list so the public API never raises IndexError.
        if not self._ckpoint_filelist:
            return
        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
        self.remove_ckpoint_file(ckpoint_files[0])

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
        del_list = []
        oldest_file = ''
        oldest_time = cur_time
        for ck_file in self._ckpoint_filelist:
            modify_time = os.path.getmtime(ck_file)
            if cur_time - modify_time < 60 * minutes:
                del_list.append(ck_file)

                if modify_time < oldest_time:
                    oldest_time = modify_time
                    oldest_file = ck_file

        for mv_file in del_list:
            # Keep the oldest file of the window; delete the rest.
            if mv_file == oldest_file:
                continue
            self.remove_ckpoint_file(mv_file)