# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""
from __future__ import absolute_import

import os
import stat
import time

import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore import nn
from mindspore import _checkparam as Validator
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore.communication.management import get_rank
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
from mindspore.train.callback._callback import Callback, set_cur_net
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.generator import Generator
from mindspore.common.api import _cell_graph_executor
from mindspore._c_expression import _collect_host_info


_cur_dir = os.getcwd()
SAVE_DIR = _cur_dir
_info_list = ["epoch_num", "step_num"]


def _get_dp_tp_from_redundancy(redundancy_tuple):
    """Get dp and tp groups from a parameter redundancy tuple."""
    dp = []
    tp = []
    for dp_value in redundancy_tuple:
        dp.append(list(dp_value))
    for i in range(len(redundancy_tuple[0])):
        tp.append([v[i] for v in redundancy_tuple])
    return dp, tp
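# Illustrative sketch (assumed input shape): each inner tuple of the
# redundancy tuple is taken as one data-parallel (dp) group, and the i-th
# elements across inner tuples form the i-th tensor-parallel (tp) group:
#   dp, tp = _get_dp_tp_from_redundancy(((0, 1), (2, 3)))
#   # dp == [[0, 1], [2, 3]], tp == [[0, 2], [1, 3]]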


def _get_dp_tp_from_layout(parameter_redundancy_dict):
    """Get dp and tp groups from a parameter layout dict."""
    tp = []
    dp = []
    value_len = 0
    for _, value in parameter_redundancy_dict.items():
        if len(value) > value_len:
            value_len = len(value)
            dp, tp = _get_dp_tp_from_redundancy(value)
    return dp, tp


def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
    """Check if a file with the same name exists and, if so, add a numeric suffix to the prefix."""
    if callable(prefix) or callable(directory):
        return prefix
    files = os.listdir(directory)
    suffix_num = 0
    pre_len = len(prefix)
    for filename in files:
        name_ext = os.path.splitext(filename)
        if exception and filename[-16:] != "_breakpoint.ckpt":
            continue
        if not exception and (name_ext[-1] != ".ckpt" or filename[-16:] == "_breakpoint.ckpt"):
            continue
        # find files with the same prefix
        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
            # set the suffix to the current max suffix + 1
            index = filename[pre_len:].find("-")
            if index == 0:
                suffix_num = max(suffix_num, 1)
            elif index != -1:
                num = filename[pre_len+1:pre_len+index]
                if num.isdigit():
                    suffix_num = max(suffix_num, int(num)+1)

    if suffix_num != 0:
        prefix = f'{prefix}_{suffix_num}'

    return prefix
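# Illustrative sketch (hypothetical file names): if the directory already
# contains "CKP-1_2.ckpt", a new callback created with prefix "CKP" is renamed
# to "CKP_1" so its files do not collide; if "CKP_3-1_2.ckpt" also exists,
# the prefix becomes "CKP_4" (max existing suffix + 1).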


class CheckpointConfig:
    """
    The configuration of model checkpoint.

    Note:
        - During the training process, if the dataset is transmitted through the data channel,
          it is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
          Otherwise, the time to save the checkpoint may be biased.
          It is recommended to set only one save strategy and one keep strategy at the same time.
          If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
          `save_checkpoint_seconds` will be invalid.
          If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
          `keep_checkpoint_per_n_minutes` will be invalid.
        - The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: ``1`` .
        save_checkpoint_seconds (int): Seconds to save checkpoint.
            Can't be used with save_checkpoint_steps at the same time. Default: ``0`` .
        keep_checkpoint_max (int): Maximum number of checkpoint files that can be saved. Default: ``5`` .
        keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
            Can't be used with keep_checkpoint_max at the same time. Default: ``0`` .
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
            The integrated save function is only supported in the automatic parallel scene, not supported
            in manual parallel. Default: ``True`` .
        async_save (bool): Whether to save the checkpoint to a file asynchronously. Default: ``False`` .
        saved_network (Cell): Network to be saved in the checkpoint file. If the saved_network has no relation
            with the network in training, the initial value of saved_network will be saved. Default: ``None`` .
        append_info (list): The information to save to the checkpoint file. Supports "epoch_num", "step_num" and
            dict. The key of the dict must be str, and the value of the dict must be one of int, float, bool,
            Parameter or Tensor. Default: ``None`` .
        enc_key (Union[None, bytes]): Byte-type key used for encryption. If the value is None, encryption
                                      is not required. Default: ``None`` .
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
                        mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: ``'AES-GCM'`` .
        exception_save (bool): Whether to save the current checkpoint when an exception occurs. Default: ``False`` .
        crc_check (bool): Whether to perform a crc32 calculation when saving the checkpoint and save the result
                          at the end of the ckpt file. Default: ``False`` .
        kwargs (dict): Configuration options dictionary.

    Raises:
        ValueError: If an input parameter is not of the correct type.

    Examples:
        >>> from mindspore import nn
        >>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> config = CheckpointConfig(save_checkpoint_seconds=100, keep_checkpoint_per_n_minutes=5, saved_network=net)
        >>> config.save_checkpoint_steps
        1
        >>> config.save_checkpoint_seconds
        >>> config.keep_checkpoint_max
        5
        >>> config.keep_checkpoint_per_n_minutes
        >>> config.integrated_save
        True
        >>> config.async_save
        False
        >>> config.saved_network
        >>> config.enc_key
        >>> config.enc_mode
        'AES-GCM'
        >>> config.append_dict
        >>> config.get_checkpoint_policy
        >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """

    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 async_save=False,
                 saved_network=None,
                 append_info=None,
                 enc_key=None,
                 enc_mode='AES-GCM',
                 exception_save=False,
                 crc_check=False,
                 **kwargs):

        if save_checkpoint_steps is not None:
            save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
        if save_checkpoint_seconds is not None:
            save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
        if keep_checkpoint_max is not None:
            keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes is not None:
            keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)

        if saved_network is not None and not isinstance(saved_network, nn.Cell):
            raise TypeError(f"For 'CheckpointConfig', the type of 'saved_network' must be None or Cell, "
                            f"but got {str(type(saved_network))}.")

        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("For 'CheckpointConfig', the input arguments 'save_checkpoint_steps', "
                             "'save_checkpoint_seconds', "
                             "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")
        Validator.check_bool(exception_save)
        self.exception_save = exception_save

        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        else:
            if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                self._keep_checkpoint_max = 1

        self._integrated_save = Validator.check_bool(integrated_save)
        self._async_save = Validator.check_bool(async_save)
        self._saved_network = saved_network
        self._append_dict = self._handle_append_info(append_info)
        self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
        self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
        self._crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
        self._map_param_inc = kwargs.get('incremental', False)
        self.enable_redundance = kwargs.get('enable_redundance', False)
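    # Illustrative sketch of the precedence rules above (hypothetical values):
    # with the default save_checkpoint_steps=1, a non-zero save_checkpoint_seconds
    # is discarded, and with the default keep_checkpoint_max=5,
    # keep_checkpoint_per_n_minutes is discarded:
    #   cfg = CheckpointConfig(save_checkpoint_seconds=30, keep_checkpoint_per_n_minutes=5)
    #   # cfg.save_checkpoint_seconds is None, cfg.keep_checkpoint_per_n_minutes is None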

    @property
    def save_checkpoint_steps(self):
        """
        Get the value of steps to save checkpoint.

        Returns:
            int, steps to save checkpoint.
        """
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """
        Get the value of seconds to save checkpoint.

        Returns:
            int, seconds to save the checkpoint file.
        """
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """
        Get the value of the maximum number of checkpoint files that can be saved.

        Returns:
            int, the maximum number of checkpoint files that can be saved.
        """
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """
        Get the value of saving the checkpoint file every n minutes.

        Returns:
            int, save the checkpoint file every n minutes.
        """
        return self._keep_checkpoint_per_n_minutes

    @property
    def integrated_save(self):
        """
        Get the value of whether to merge and save the split Tensor in the automatic parallel scenario.

        Returns:
            bool, whether to merge and save the split Tensor in the automatic parallel scenario.
        """
        return self._integrated_save

    @property
    def async_save(self):
        """
        Get the value of whether to save the checkpoint to a file asynchronously.

        Returns:
            bool, whether to save the checkpoint to a file asynchronously.
        """
        return self._async_save

    @property
    def saved_network(self):
        """
        Get the value of the network to be saved in the checkpoint file.

        Returns:
            Cell, network to be saved in the checkpoint file.
        """
        return self._saved_network

    @property
    def enc_key(self):
        """
        Get the value of the byte-type key used for encryption.

        Returns:
            (None, bytes), byte-type key used for encryption.
        """
        return self._enc_key

    @property
    def enc_mode(self):
        """
        Get the value of the encryption mode.

        Returns:
            str, encryption mode.
        """
        return self._enc_mode

    @property
    def crc_check(self):
        """
        Get the value of whether to enable crc check.

        Returns:
            bool, whether to enable crc check.
        """
        return self._crc_check

    @property
    def append_dict(self):
        """
        Get the value of the information dict saved to the checkpoint file.

        Returns:
            dict, the information saved to the checkpoint file.
        """
        return self._append_dict

    @property
    def map_param_inc(self):
        """
        Get the value of whether to save map Parameter incrementally.

        Returns:
            bool, whether to save map Parameter incrementally.
        """
        return self._map_param_inc

    def get_checkpoint_policy(self):
        """
        Get the policy of checkpoint.

        Returns:
            dict, the information of checkpoint policy.
        """
        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
                             'keep_checkpoint_max': self.keep_checkpoint_max,
                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
                             'saved_network': self.saved_network}

        return checkpoint_policy

    @staticmethod
    def _handle_append_info(append_info):
        """Handle ckpt append info."""
        if append_info is None or append_info == []:
            return None
        if not isinstance(append_info, list):
            raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' must be list, "
                            f"but got {str(type(append_info))}.")
        handle_append_info = {}
        if "epoch_num" in append_info:
            handle_append_info["epoch_num"] = 0
        if "step_num" in append_info:
            handle_append_info["step_num"] = 0
        if "random_op" in append_info:
            handle_append_info["random_op"] = 0
        dict_num = 0
        for element in append_info:
            if not isinstance(element, str) and not isinstance(element, dict):
                raise TypeError(f"For 'CheckpointConfig', the type of each 'append_info' element must be str or "
                                f"dict, but got {str(type(element))}.")
            if isinstance(element, str) and element not in _info_list:
                raise ValueError(f"For 'CheckpointConfig', the value of a str element in the argument 'append_info' "
                                 f"must be in {_info_list}, "
                                 f"but got {element}.")
            if isinstance(element, dict):
                dict_num += 1
                if dict_num > 1:
                    raise TypeError(f"For 'CheckpointConfig', 'append_info' must contain at most one dict, "
                                    f"but got {dict_num}.")
                for key, value in element.items():
                    if isinstance(key, str) and isinstance(value,
                                                           (int, float, bool, str, Parameter, Tensor, Generator)):
                        handle_append_info[key] = value
                    else:
                        raise TypeError(f"For 'CheckpointConfig', the key type of the dict in 'append_info' "
                                        f"must be string, and the value type must be one of int, float, bool, str, "
                                        f"Parameter, Tensor or Generator, "
                                        f"but got key type {type(key)}, value type {type(value)}")

        return handle_append_info
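    # Illustrative sketch of the resulting dict (hypothetical values):
    #   CheckpointConfig._handle_append_info(["epoch_num", {"lr": 0.01}])
    #   # returns {"epoch_num": 0, "lr": 0.01}; the "epoch_num" and "step_num"
    #   # placeholders are filled in later by ModelCheckpoint._save_ckpt.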


class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with the train process and save the model and network parameters after training.

    Note:
        In the distributed training scenario, please specify different directories for each training process
        to save the checkpoint file. Otherwise, the training may fail.
        If this callback is used in the `model` function, the checkpoint file will save
        the parameters of the optimizer by default.

    Args:
        prefix (Union[str, callable object]): The prefix name or callable object to generate names of checkpoint
            files. Default: ``'CKP'`` .
        directory (Union[str, callable object]): The folder path where the checkpoint is stored, or the callable object
            used to generate the path. By default, the file is saved in the current directory.
            Default: ``None`` .
        config (CheckpointConfig): Checkpoint strategy configuration. Default: ``None`` .

    Raises:
        ValueError: If `prefix` is not str or contains the '/' character, and is not a callable object.
        ValueError: If `directory` is not str and is not a callable object.
        TypeError: If the config is not CheckpointConfig type.

    Examples:
        >>> import numpy as np
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn
        >>> from mindspore.train import Model, ModelCheckpoint
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> net = nn.Dense(10, 5)
        >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> ckpt_callback = ModelCheckpoint(prefix="myckpt")
        >>> model = Model(network=net, optimizer=opt, loss_fn=crit)
        >>> model.train(2, train_dataset, callbacks=[ckpt_callback])
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0
        # a callable for users to set a self-defined prefix.
        self._prefix_func = None
        # a callable for users to set a self-defined directory.
        self._directory_func = None

        if not callable(prefix) and (not isinstance(prefix, str) or prefix.find('/') >= 0):
            raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
                             "for checkpoint file name is invalid, it must be "
                             "callable or a string that does not contain '/', but got {}.".format(prefix))
        self._prefix = prefix
        self._exception_prefix = prefix

        if directory is not None:
            if callable(directory):
                self._directory_func = directory
            else:
                self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir

        if callable(prefix):
            self._prefix_func = prefix

        if _get_recovery_context("enable_recovery"):
            _set_recovery_context(ckpt_path=self._directory)

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
                                "'CheckpointConfig', "
                                "but got {}.".format(type(config)))
            self._config = config

        self._aiturbo_init_flag = os.getenv("AITURBO") == "1"
        # get existing checkpoint files
        if self._aiturbo_init_flag:
            import aiturbo
            self._manager = aiturbo.CheckpointShmManager()
        else:
            self._manager = CheckpointManager()
        if not callable(directory) and not callable(prefix):
            self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._append_dict = self._config.append_dict or {}
        self._append_epoch_num = self._append_dict.get("epoch_num", 0)
        self._append_step_num = self._append_dict.get("step_num", 0)
        self._graph_saved = False
        self._need_flush_from_cache = True
        self._map_param_inc = self._config.map_param_inc

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if self._aiturbo_init_flag:
            import aiturbo
            ckpt_storage_path = self._directory
            rank_id = get_rank()
            stage_num = _get_auto_parallel_context("pipeline_stages")
            stage_rank_num = _get_device_num() // stage_num
            param_layout = cb_params.train_network.parameter_layout_dict
            if not param_layout:
                layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num, "stage_layout": None}
                aiturbo.init(ckpt_storage_path, rank_id, layout, None, False, None)
            else:
                device_num = _get_device_num()
                chunk_size = device_num // stage_num
                initial_rank = (rank_id // chunk_size) * chunk_size
                param_redundancy_dict = get_parameter_redundancy(param_layout, initial_rank)
                dp, _ = _get_dp_tp_from_layout(param_redundancy_dict)
                layout = {"stage_num": stage_num, "stage_rank_num": stage_rank_num,
                          "stage_layout": param_redundancy_dict}
                single_params = remove_param_redundancy(param_redundancy_dict)
                single_params = {device_id: list(params) for device_id, params in single_params.items()}
                aiturbo.init(ckpt_storage_path, rank_id, layout, single_params, self._config.enable_redundance, dp)
            self._aiturbo_init_flag = False
        if self._prefix_func:
            self._prefix = self._prefix_func(cb_params)
            if not isinstance(self._prefix, str) or self._prefix.find('/') >= 0:
                raise ValueError("For 'ModelCheckpoint', when the argument 'prefix' "
                                 "for checkpoint file name is callable, it must return a "
                                 "string that does not contain '/', but got {}.".format(self._prefix))
        if self._directory_func:
            self._directory = self._directory_func(cb_params)
        _collect_host_info("Callback", "ModelCheckpoint", "step_end", level=1)
        # In the disaster recovery scenario, the training process may be rolled back to the last step where
        # the ckpt was successfully saved, so _last_triggered_step should be updated.
        if _get_recovery_context("enable_recovery") and cb_params.last_save_ckpt_step is not None:
            self._last_triggered_step = cb_params.last_save_ckpt_step
            cb_params.last_save_ckpt_step = None

        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # wait for any in-flight asynchronous save to finish before saving again
        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training is finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        _collect_host_info("Callback", "ModelCheckpoint", "end", level=1)
        _to_save_last_ckpt = True

        self._save_ckpt(cb_params, _to_save_last_ckpt)

        # wait for the asynchronous save thread so the last checkpoint is complete on disk
        thread_list = threading.enumerate()
        for thread in thread_list:
            if thread.name == "asyn_save_ckpt":
                thread.join()

        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether to save checkpoint files or not."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
                self._last_time = self._cur_time
                return True

        return False
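    # Illustrative sketch of the trigger rule (hypothetical numbers): with
    # save_checkpoint_steps=100 and _last_triggered_step=0, this returns True
    # once cur_step_num reaches 100; with the seconds-based strategy, it
    # returns True once more than save_checkpoint_seconds have elapsed since
    # the last time-based save. force_to_save overrides both checks.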

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        # if a param is cache enabled, flush data from cache to host before save_ckpt
        if self._need_flush_from_cache:
            self._flush_from_cache(cb_params)

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

        if save_ckpt:
            if self._prefix_func:
                cur_ckpoint_file = self._prefix + ".ckpt"
            else:
                cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                    + str(step_num_in_epoch) + ".ckpt"
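            # Illustrative file name (hypothetical values): with prefix "CKP",
            # cur_epoch_num=3 and step_num_in_epoch=20, this yields "CKP-3_20.ckpt".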
            # update the checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep the number of checkpoint files no greater than the max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global SAVE_DIR
            SAVE_DIR = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            # TODO(MS_DISABLE_REF_MODE): Delete when the MS_DISABLE_REF_MODE env is removed.
            if context.get_context("enable_ge") and os.getenv('MS_DISABLE_REF_MODE') \
                    and context.get_context("mode") == context.GRAPH_MODE:
                set_cur_net(cb_params.train_network)
                cb_params.train_network.add_flags(ge_sync_data=True)
                _cell_graph_executor(cb_params.train_network, phase='save')
            if "epoch_num" in self._append_dict:
                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
            if "step_num" in self._append_dict:
                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
            if os.getenv("AITURBO") == "1":
                save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                self._append_dict, self._config.enc_key, self._config.enc_mode,
                                crc_check=self._config.crc_check, incremental=self._map_param_inc,
                                global_step_num=cb_params.cur_step_num)
            else:
                save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                                self._append_dict, self._config.enc_key, self._config.enc_mode,
                                crc_check=self._config.crc_check, incremental=self._map_param_inc)

            self._latest_ckpt_file_name = cur_file

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if a tensor is cache enabled."""
        has_cache_params = False
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                has_cache_params = True
                Tensor(param).flush_from_cache()
        if not has_cache_params:
            self._need_flush_from_cache = False

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name

    @property
    def _get_save_checkpoint_steps(self):
        """Return save ckpt steps."""
        return self._config.save_checkpoint_steps

    @property
    def _get_last_trigger_step(self):
        """Return the last triggered step."""
        return self._last_triggered_step


class CheckpointManager:
    """Manage checkpoint files according to the train_config of checkpoint."""

    def __init__(self):
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
        """Get all the related checkpoint files managed here."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
        """Get the number of the related checkpoint files managed here."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        """Update the checkpoint file list."""
        self._ckpoint_filelist = []
        files = os.listdir(directory)
        for filename in files:
            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
                mid_name = filename[len(prefix):-5]
                # accept only "<prefix>-<epoch>_<step>.ckpt"-style names, i.e. no letters in the middle part
                if not any(char.isalpha() for char in mid_name):
                    self._ckpoint_filelist.append(os.path.join(directory, filename))
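    # Illustrative sketch (hypothetical file names) of which files are managed
    # for prefix "CKP": "CKP-3_20.ckpt" is collected, while "CKP-best.ckpt"
    # (letters in the middle part) and "CKP-graph.meta" (wrong extension) are not.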

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
            self._ckpoint_filelist.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
        self.remove_ckpoint_file(ckpoint_files[0])

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
        """Keep only one ckpt file per `minutes` window: of the files modified within
        (cur_time - 60 * minutes, cur_time], keep the oldest and remove the rest."""
        del_list = []
        oldest_file = ''
        oldest_time = cur_time
        for ck_file in self._ckpoint_filelist:
            modify_time = os.path.getmtime(ck_file)
            if cur_time - modify_time < 60 * minutes:
                del_list.append(ck_file)

                if modify_time < oldest_time:
                    oldest_time = modify_time
                    oldest_file = ck_file

        for mv_file in del_list:
            if mv_file == oldest_file:
                continue
            self.remove_ckpoint_file(mv_file)
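    # Illustrative sketch (hypothetical timestamps): with minutes=5 and files
    # modified 1, 3 and 10 minutes before cur_time, the 1- and 3-minute files
    # fall inside the 5-minute window; the 3-minute file (oldest in the window)
    # is kept and the 1-minute file is removed. The 10-minute file is untouched.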