1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""backup and restore related classes and functions.""" 16from __future__ import absolute_import 17 18import os 19import stat 20 21from mindspore import log as logger 22from mindspore.train.serialization import load_checkpoint, save_checkpoint 23from mindspore.train.callback._callback import Callback 24from mindspore.train._utils import _make_directory 25from mindspore import _checkparam as Validator 26 27 28class BackupAndRestore(Callback): 29 """ 30 Callback to back up and restore the parameters during training. 31 32 Note: 33 This function can only use in training. 34 35 Args: 36 backup_dir (str): Path to store and load the checkpoint file. 37 save_freq (Union["epoch", int]): When set to ``"epoch"`` the callback saves the checkpoint at the end of 38 each epoch. When set to an integer, the callback saves the checkpoint 39 every `save_freq` epoch. Default: ``"epoch"`` . 40 delete_checkpoint (bool): If `delete_checkpoint=True`, the checkpoint will be deleted after 41 training is finished. Default: ``True`` . 42 43 Raises: 44 ValueError: If backup_dir is not str. 45 ValueError: If save_freq is not ``"epoch"`` or int. 46 ValueError: If delete_checkpoint is not bool. 47 48 Examples: 49 >>> from mindspore import nn 50 >>> from mindspore.train import Model, BackupAndRestore, RunContext 51 >>> 52 >>> # Define the network structure of LeNet5. Refer to 53 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 54 >>> net = LeNet5() 55 >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 56 >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) 57 >>> model = Model(net, loss_fn=loss, optimizer=optim) 58 >>> # Create the dataset taking MNIST as an example. Refer to 59 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py 60 >>> dataset = create_dataset() 61 >>> backup_ckpt = BackupAndRestore("backup") 62 >>> model.train(10, dataset, callbacks=backup_ckpt) 63 """ 64 def __init__(self, backup_dir, save_freq="epoch", delete_checkpoint=True): 65 super(BackupAndRestore, self).__init__() 66 ckpt_dir = _make_directory(backup_dir) 67 self.backup_file = os.path.join(ckpt_dir, 'backup.ckpt') 68 if save_freq != "epoch": 69 self.save_freq = Validator.check_positive_int(save_freq) 70 else: 71 self.save_freq = 1 72 self.delete_checkpoint = Validator.check_bool(delete_checkpoint) 73 74 def on_train_begin(self, run_context): 75 """ 76 Load the backup checkpoint file at the beginning of epoch. 77 78 Args: 79 run_context (RunContext): Context of the process running. For more details, 80 please refer to :class:`mindspore.train.RunContext`. 81 """ 82 if os.path.exists(self.backup_file): 83 cb_params = run_context.original_args() 84 train_net = cb_params.train_network 85 logger.info("Restore checkpoint file is {}, load checkpoint into train net.".format(self.backup_file)) 86 load_checkpoint(self.backup_file, net=train_net) 87 88 def on_train_epoch_end(self, run_context): 89 """ 90 Backup checkpoint file at the end of train epoch. 91 92 Args: 93 run_context (RunContext): Context of the process running. For more details, 94 please refer to :class:`mindspore.train.RunContext`. 95 """ 96 cb_params = run_context.original_args() 97 cur_epoch_num = cb_params.cur_epoch_num 98 if cur_epoch_num % self.save_freq == 0: 99 train_net = cb_params.train_network 100 logger.info("Train task end, backup checkpoint file: {}.".format(self.backup_file)) 101 save_checkpoint(train_net, self.backup_file) 102 103 def on_train_end(self, run_context): 104 """ 105 Deleted checkpoint file at the end of train. 106 107 Args: 108 run_context (RunContext): Context of the process running. For more details, 109 please refer to :class:`mindspore.train.RunContext`. 110 """ 111 run_context.original_args() 112 cb_params = run_context.original_args() 113 cur_epoch_num = cb_params.cur_epoch_num 114 if self.delete_checkpoint: 115 logger.info("Delete restore checkpoint file {} at {} epoch.".format(self.backup_file, cur_epoch_num)) 116 os.chmod(self.backup_file, stat.S_IWRITE) 117 os.remove(self.backup_file) 118