# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""
16
import os
import stat
import threading
import time

import mindspore.context as context
from mindspore import log as logger
from mindspore import nn
from mindspore._checkparam import Validator
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph

from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor
32
33_cur_dir = os.getcwd()
34_save_dir = _cur_dir
35_info_list = ["epoch_num", "step_num"]
36
37
38def _chg_ckpt_file_name_if_same_exist(directory, prefix):
39    """Check if there is a file with the same name."""
40    files = os.listdir(directory)
41    suffix_num = 0
42    pre_len = len(prefix)
43    for filename in files:
44        name_ext = os.path.splitext(filename)
45        if name_ext[-1] != ".ckpt":
46            continue
47        # find same prefix file
48        if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
49            # add the max suffix + 1
50            index = filename[pre_len:].find("-")
51            if index == 0:
52                suffix_num = max(suffix_num, 1)
53            elif index != -1:
54                num = filename[pre_len+1:pre_len+index]
55                if num.isdigit():
56                    suffix_num = max(suffix_num, int(num)+1)
57
58    if suffix_num != 0:
59        prefix = prefix + "_" + str(suffix_num)
60
61    return prefix
62
63
class CheckpointConfig:
    """
    The configuration of model checkpoint.

    Note:
        During the training process, if dataset is transmitted through the data channel,
        It is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
        Otherwise, the time to save the checkpoint may be biased.
        It is recommended to set only one save strategy and one keep strategy at the same time.
        If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
        `save_checkpoint_seconds` will be invalid.
        If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
        `keep_checkpoint_per_n_minutes` will be invalid.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
        save_checkpoint_seconds (int): Seconds to save checkpoint.
            Can't be used with save_checkpoint_steps at the same time. Default: 0.
        keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: 5.
        keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
            Can't be used with keep_checkpoint_max at the same time. Default: 0.
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
            Integrated save function is only supported in automatic parallel scene, not supported
            in manual parallel. Default: True.
        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
        saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
            with the network in training, the initial value of saved_network will be saved. Default: None.
        append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and dict.
            The key of dict must be str, the value of dict must be one of int float and bool. Default: None.
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
                                      is not required. Default: None.
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        ValueError: If input parameter is not the correct type.

    Examples:
        >>> from mindspore import Model, nn
        >>> from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
        >>>
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, num_channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        ...         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        ...         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        ...         self.relu = nn.ReLU()
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
        ...         x = self.flatten(x)
        ...         x = self.relu(self.fc1(x))
        ...         x = self.relu(self.fc2(x))
        ...         x = self.fc3(x)
        ...         return x
        >>>
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> data_path = './MNIST_Data'
        >>> dataset = create_dataset(data_path)
        >>> config = CheckpointConfig(saved_network=net)
        >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """

    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 async_save=False,
                 saved_network=None,
                 append_info=None,
                 enc_key=None,
                 enc_mode='AES-GCM'):

        # Each interval/keep argument may be None (strategy disabled) or a
        # non-negative integer; anything else is rejected by the Validator.
        if save_checkpoint_steps is not None:
            save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
        if save_checkpoint_seconds is not None:
            save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
        if keep_checkpoint_max is not None:
            keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes is not None:
            keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)

        if saved_network is not None and not isinstance(saved_network, nn.Cell):
            raise TypeError(f"For 'CheckpointConfig', the type of 'saved_network' must be None or Cell, "
                            f"but got {str(type(saved_network))}.")

        # At least one save strategy and one keep strategy must be effective;
        # `not x` covers both None and 0 here.
        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("The input arguments 'save_checkpoint_steps', 'save_checkpoint_seconds', "
                             "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")

        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        # Step-based saving takes precedence: a positive step interval disables
        # the seconds-based strategy entirely.
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        # Likewise, keep-by-count takes precedence over keep-per-minutes; if
        # neither keep strategy is set, fall back to keeping a single file.
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        else:
            if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                self._keep_checkpoint_max = 1

        self._integrated_save = Validator.check_bool(integrated_save)
        self._async_save = Validator.check_bool(async_save)
        self._saved_network = saved_network
        self._append_dict = self._handle_append_info(append_info)
        self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
        self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)

    @property
    def save_checkpoint_steps(self):
        """Get the value of _save_checkpoint_steps."""
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """Get the value of _save_checkpoint_seconds."""
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """Get the value of _keep_checkpoint_max."""
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """Get the value of _keep_checkpoint_per_n_minutes."""
        return self._keep_checkpoint_per_n_minutes

    @property
    def integrated_save(self):
        """Get the value of _integrated_save."""
        return self._integrated_save

    @property
    def async_save(self):
        """Get the value of _async_save."""
        return self._async_save

    @property
    def saved_network(self):
        """Get the value of _saved_network"""
        return self._saved_network

    @property
    def enc_key(self):
        """Get the value of _enc_key"""
        return self._enc_key

    @property
    def enc_mode(self):
        """Get the value of _enc_mode"""
        return self._enc_mode

    @property
    def append_dict(self):
        """Get the value of append_dict."""
        return self._append_dict

    def get_checkpoint_policy(self):
        """Get the policy of checkpoint as a plain dict (save/keep strategies plus the saved network)."""
        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
                             'keep_checkpoint_max': self.keep_checkpoint_max,
                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
                             'saved_network': self.saved_network}

        return checkpoint_policy

    @staticmethod
    def _handle_append_info(append_info):
        """Handle ckpt append info.

        Normalizes `append_info` (a list of "epoch_num"/"step_num" strings and at
        most one user dict) into a single flat dict, or None if nothing was given.
        """
        if append_info is None or append_info == []:
            return None
        if not isinstance(append_info, list):
            raise TypeError(f"The type of 'append_info' must be list, but got {str(type(append_info))}.")
        handle_append_info = {}
        # The well-known counters start at 0 and are filled in at save time.
        if "epoch_num" in append_info:
            handle_append_info["epoch_num"] = 0
        if "step_num" in append_info:
            handle_append_info["step_num"] = 0
        dict_num = 0
        for element in append_info:
            if not isinstance(element, str) and not isinstance(element, dict):
                raise TypeError(f"The type of 'append_info' element must be str or dict, "
                                f"but got {str(type(element))}.")
            if isinstance(element, str) and element not in _info_list:
                raise ValueError(f"The value of element in the argument 'append_info' must be in {_info_list}, "
                                 f"but got {element}.")
            if isinstance(element, dict):
                # Only one user-supplied dict is allowed in the list.
                dict_num += 1
                if dict_num > 1:
                    raise TypeError(f"The element of 'append_info' must has only one dict.")
                for key, value in element.items():
                    if isinstance(key, str) and isinstance(value, (int, float, bool)):
                        handle_append_info[key] = value
                    else:
                        raise TypeError(f"The type of dict in 'append_info' must be key: string, value: int or float, "
                                        f"but got key: {type(key)}, value: {type(value)}")

        return handle_append_info
279
280
class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network parameters after training.

    Note:
        In the distributed training scenario, please specify different directories for each training process
        to save the checkpoint file. Otherwise, the training may fail.

    Args:
        prefix (str): The prefix name of checkpoint files. Default: "CKP".
        directory (str): The path of the folder which will be saved in the checkpoint file.
            By default, the file is saved in the current directory. Default: None.
        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        if not isinstance(prefix, str) or prefix.find('/') >= 0:
            raise ValueError("The argument 'prefix' for checkpoint file name is invalid, 'prefix' must be "
                             "string and does not contain '/', but got {}.".format(prefix))
        self._prefix = prefix

        if directory is not None:
            self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("The argument 'config' should be 'CheckpointConfig' type, "
                                "but got {}.".format(type(config)))
            self._config = config

        # get existing checkpoint files
        self._manager = CheckpointManager()
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._append_dict = self._config.append_dict or {}
        # Base values for the "epoch_num"/"step_num" counters requested via append_info.
        self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
        self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
        self._graph_saved = False
        self._need_flush_from_cache = True
        # Guard so the parameter-server rank tag is prepended to the prefix only once.
        self._ps_prefix_tagged = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        if _is_role_pserver() and not self._ps_prefix_tagged:
            # BUGFIX: tag the prefix with the parameter-server rank exactly once.
            # Previously this ran on every step, prepending "PServer_<rank>_"
            # repeatedly and growing the checkpoint file name without bound.
            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
            self._ps_prefix_tagged = True
        cb_params = run_context.original_args()
        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            # In graph mode, replace any stale graph file left by a previous run.
            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # Wait for any in-flight asynchronous save to finish before starting a new one.
        for thread in threading.enumerate():
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        _to_save_last_ckpt = True

        self._save_ckpt(cb_params, _to_save_last_ckpt)

        # Make sure the asynchronous writer (if any) has flushed before teardown.
        for thread in threading.enumerate():
            if thread.name == "asyn_save_ckpt":
                thread.join()

        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether save checkpoint files or not.

        Returns True when the configured step interval has elapsed, the configured
        time interval has elapsed, or `force_to_save` is True.
        """
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True

        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        # Nothing new has happened since the last save triggered at this step.
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        # if param is cache enable, flush data from cache to host before save_ckpt
        if self._need_flush_from_cache:
            self._flush_from_cache(cb_params)

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        # 1-based step index within the current epoch, derived from the global step.
        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            # Under GE, run the dedicated checkpoint graph to materialize parameters.
            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()
            # Refresh the requested counters with their current training values.
            if "epoch_num" in self._append_dict:
                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
            if "step_num" in self._append_dict:
                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
            save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                            self._append_dict, self._config.enc_key, self._config.enc_mode)

            self._latest_ckpt_file_name = cur_file

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if tensor is cache enable."""
        has_cache_params = False
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                has_cache_params = True
                Tensor(param).flush_from_cache()
        # No cache-enabled parameters at all: skip this scan on future saves.
        if not has_cache_params:
            self._need_flush_from_cache = False

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name
456
457
class CheckpointManager:
    """Manage checkpoint files according to train_config of checkpoint.

    Tracks the ".ckpt" files produced for one prefix and prunes them either by
    count or by modification-time window.
    """

    def __init__(self):
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
        """All checkpoint file paths currently tracked by this manager."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
        """Number of checkpoint files currently tracked by this manager."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        """Rescan `directory` and rebuild the list of '<prefix>-<epoch>_<step>.ckpt' files."""
        self._ckpoint_filelist = []
        marker = prefix + "-"
        for entry in os.listdir(directory):
            if os.path.splitext(entry)[-1] != ".ckpt" or not entry.startswith(marker):
                continue
            # The part between the prefix and ".ckpt" must be purely numeric
            # punctuation (e.g. "-3_12"); a letter means a different prefix.
            middle = entry[len(prefix):-5]
            if any(ch.isalpha() for ch in middle):
                continue
            self._ckpoint_filelist.append(os.path.join(directory, entry))

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            # Clear read-only attribute first so the unlink cannot be refused.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
            self._ckpoint_filelist.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
        """Remove the tracked checkpoint file with the earliest modification time."""
        oldest = min(self._ckpoint_filelist, key=os.path.getmtime)
        self.remove_ckpoint_file(oldest)

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
        keeper = ''
        keeper_time = cur_time
        in_window = []
        for ck_file in self._ckpoint_filelist:
            modify_time = os.path.getmtime(ck_file)
            if cur_time - modify_time >= 60 * minutes:
                continue
            in_window.append(ck_file)
            # Track the earliest file in the window; it is the one we keep.
            if modify_time < keeper_time:
                keeper_time = modify_time
                keeper = ck_file

        for victim in in_window:
            if victim != keeper:
                self.remove_ckpoint_file(victim)
519