• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""backup and restore related classes and functions."""
16from __future__ import absolute_import
17
18import os
19import stat
20
21from mindspore import log as logger
22from mindspore.train.serialization import load_checkpoint, save_checkpoint
23from mindspore.train.callback._callback import Callback
24from mindspore.train._utils import _make_directory
25from mindspore import _checkparam as Validator
26
27
class BackupAndRestore(Callback):
    """
    Callback to back up and restore the parameters during training.

    Note:
        This callback can only be used in training.

    Args:
        backup_dir (str): Path to store and load the checkpoint file.
        save_freq (Union["epoch", int]): When set to ``"epoch"`` the callback saves the checkpoint at the end of
                                        each epoch. When set to an integer, the callback saves the checkpoint
                                        every `save_freq` epochs. Default: ``"epoch"`` .
        delete_checkpoint (bool): If `delete_checkpoint=True`, the checkpoint will be deleted after
                                        training is finished. Default: ``True`` .

    Raises:
        ValueError: If backup_dir is not str.
        ValueError: If save_freq is not ``"epoch"`` or int.
        ValueError: If delete_checkpoint is not bool.

    Examples:
        >>> from mindspore import nn
        >>> from mindspore.train import Model, BackupAndRestore
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> backup_ckpt = BackupAndRestore("backup")
        >>> model.train(10, dataset, callbacks=backup_ckpt)
    """
    def __init__(self, backup_dir, save_freq="epoch", delete_checkpoint=True):
        super(BackupAndRestore, self).__init__()
        ckpt_dir = _make_directory(backup_dir)
        # A single fixed file name: every backup overwrites the previous one.
        self.backup_file = os.path.join(ckpt_dir, 'backup.ckpt')
        # "epoch" is shorthand for saving after every epoch (frequency 1).
        if save_freq != "epoch":
            self.save_freq = Validator.check_positive_int(save_freq)
        else:
            self.save_freq = 1
        self.delete_checkpoint = Validator.check_bool(delete_checkpoint)

    def on_train_begin(self, run_context):
        """
        Load the backup checkpoint file, if one exists, into the train network at the beginning of training.

        Args:
            run_context (RunContext): Context of the process running. For more details,
                    please refer to :class:`mindspore.train.RunContext`.
        """
        if os.path.exists(self.backup_file):
            cb_params = run_context.original_args()
            train_net = cb_params.train_network
            logger.info("Restore checkpoint file is {}, load checkpoint into train net.".format(self.backup_file))
            load_checkpoint(self.backup_file, net=train_net)

    def on_train_epoch_end(self, run_context):
        """
        Back up the train network to the checkpoint file at the end of every `save_freq`-th epoch.

        Args:
           run_context (RunContext): Context of the process running. For more details,
                   please refer to :class:`mindspore.train.RunContext`.
        """
        cb_params = run_context.original_args()
        cur_epoch_num = cb_params.cur_epoch_num
        if cur_epoch_num % self.save_freq == 0:
            train_net = cb_params.train_network
            logger.info("Train task end, backup checkpoint file: {}.".format(self.backup_file))
            save_checkpoint(train_net, self.backup_file)

    def on_train_end(self, run_context):
        """
        Delete the backup checkpoint file at the end of training when `delete_checkpoint` is True.

        If no backup file exists, nothing is deleted. For example, training may have finished before
        the first save because it ran fewer epochs than `save_freq`.

        Args:
            run_context (RunContext): Context of the process running. For more details,
                    please refer to :class:`mindspore.train.RunContext`.
        """
        cb_params = run_context.original_args()
        cur_epoch_num = cb_params.cur_epoch_num
        if not self.delete_checkpoint:
            return
        # Guard: chmod/remove on a missing file would raise FileNotFoundError at the end of a
        # successful run when no backup was ever written.
        if not os.path.exists(self.backup_file):
            logger.info("No restore checkpoint file {} to delete at {} epoch.".format(self.backup_file,
                                                                                      cur_epoch_num))
            return
        logger.info("Delete restore checkpoint file {} at {} epoch.".format(self.backup_file, cur_epoch_num))
        # Grant the owner write permission before removing. This is presumably because the saved
        # checkpoint may be read-only, and removing read-only files fails on some platforms
        # (e.g. Windows). Confirm against save_checkpoint's behavior.
        os.chmod(self.backup_file, stat.S_IWRITE)
        os.remove(self.backup_file)