# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""

import os
import sys
import copy
from mindspore.train.serialization import save_checkpoint, _convert_cell_param_and_names_to_dict, _get_merged_param_data
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore import _checkparam as Validator
from mindspore.train.callback._callback import Callback
from mindspore.common.tensor import Tensor
from mindspore import context
import mindspore as ms
from mindspore.communication import get_rank
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters

from mindspore.train._utils import get_parameter_redundancy
from mindspore import log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.common.api import _get_parameter_layout


def _get_dp_from_layout(parameter_layout_dict):
    """ Get the data parallel group that contains the current rank from the layout dict. """
    pp_num = _get_auto_parallel_context("pipeline_stages")
    dev_num = _get_device_num()
    global_rank = get_rank()
    pipe_size = dev_num // pp_num
    initial_rank = (global_rank // pipe_size) * pipe_size
    parameter_redundancy_dict = get_parameter_redundancy(
        parameter_layout_dict, initial_rank)
    value_len = sys.maxsize
    min_value = ()
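    # Pick the smallest redundancy group that contains the current rank; entries for
    # accumulated gradients and network inputs are skipped.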
    for key, value in parameter_redundancy_dict.items():
        if "accu_grads" in key or "inputs" in key:
            continue
        for item in value:
            if len(item) < value_len and global_rank in item:
                value_len = len(item)
                min_value = item
    return min_value


def _get_ckpt_dir(append_dict, ckpt_save_path, is_tmp_file):
    """ Common func to generate ckpt dir name."""
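    # e.g. "<ckpt_save_path>/ttp_saved_checkpoints-<cur_epoch_num>_<cur_step_num>_tmp" when is_tmp_file is True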
    tmp = "_tmp" if is_tmp_file else ""
    mid_dir = f"ttp_saved_checkpoints-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}{tmp}"
    return os.path.join(ckpt_save_path, mid_dir)


def _flush_from_cache(cb_params):
    """ Flush cached data to host if the tensor has cache enabled."""
    params = cb_params.train_network.get_parameters()
    for param in params:
        if param.cache_enable:
            Tensor(param).flush_from_cache()


def _save_checkpoint_on_failure(save_rank, step, rank_list, save_args):
    """ Callback used for TTP save ckpt function when errors occur."""
    logger.info("Enter _save_checkpoint_on_failure function")
    ckpt_save_path, save_params, append_dict = save_args
    ckpt_file = f"iteration-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}.ckpt"
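    # Each rank writes into its own sub-directory of the temporary checkpoint directory,
    # e.g. "<ckpt_save_path>/ttp_saved_checkpoints-<epoch>_<step>_tmp/rank_<save_rank>".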
    cur_ckpt_dir = _get_ckpt_dir(
        append_dict, ckpt_save_path, True) + "/rank_" + str(save_rank)
    os.makedirs(cur_ckpt_dir)
    cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
    save_checkpoint(save_params, cur_file,
                    integrated_save=False, append_dict=append_dict)
    logger.info("Finish _save_checkpoint_on_failure function")


def _convert_net_to_param_list(save_obj):
    """Convert nn.Cell to param_list."""
    sync_pipeline_shared_parameters(save_obj)
    param_list = []
    parameter_layout_dict = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
        parameter_layout_dict = _get_parameter_layout()
    if not _is_in_auto_parallel_mode():
        save_obj.init_parameters_data()
    param_dict = _convert_cell_param_and_names_to_dict(save_obj, None)
    for (key, value) in param_dict.items():
        each_param = {"name": key}
        param_data = Tensor(value.asnumpy())
        # in the automatic model parallel scenario, some parameters are split across devices
        # and should be merged before saving
        if key in parameter_layout_dict:
            param_data = _get_merged_param_data(
                save_obj, parameter_layout_dict, key, param_data, False)
        each_param["data"] = param_data
        param_list.append(each_param)
    return param_list


def _rename_save_result(rename_args):
    """ Callback used by TTP to rename the checkpoint directory after the save callback has finished successfully."""
    logger.info("Enter _rename_save_result function")
    ckpt_save_path, _, append_dict = rename_args

    tmp_dir = _get_ckpt_dir(append_dict, ckpt_save_path, True)
    fin_dir = _get_ckpt_dir(append_dict, ckpt_save_path, False)

    os.rename(tmp_dir, fin_dir)
    logger.info("Finish _rename_save_result function")


class MindIOTTPAdapter(Callback):
    """
    This callback is used to enable the feature
    `MindIO TTP <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindiottp/mindiottp001.html>`_.
    This callback executes TTP operations during the training process, such as TTP init, report and exception handling.

    Note:
        Required for Ascend GE LazyInline mode only, and the number of pipeline stages must be greater than 1.

    Args:
        controller_ip (str): TTP controller's ip address, used to init the TTP controller.
        controller_port (int): TTP controller's ip port, used to init the TTP controller and processor.
        ckpt_save_path (str): Checkpoint save directory when a failure occurs; checkpoint files will be saved to a
           sub-directory named ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num} under this directory.

    Raises:
        Exception: TTP init failed.
        ModuleNotFoundError: Mindio TTP whl package is not installed.

    Examples:
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore.train import Model, MindIOTTPAdapter
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
        ...                             "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
        ...         if shape is None:
        ...             shape = [28 * 28, 512]
        ...         weight_init = HeUniform(math.sqrt(5))
        ...         self.param = Parameter(initializer(weight_init, shape), name="param")
        ...         if param is not None:
        ...             self.param = param
        ...         self.print = ops.Print()
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.param)
        ...         self.print("out is:", out)
        ...         return out
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.layer1 = MatMulCell()
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(512, 512)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(512, 10)
        ...
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> net = Network()
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>>
        >>> def create_dataset(batch_size):
        ...     dataset_path = os.getenv("DATA_PATH")
        ...     dataset = ds.MnistDataset(dataset_path)
        ...     image_transforms = [
        ...         ds.vision.Rescale(1.0 / 255.0, 0),
        ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ...         ds.vision.HWC2CHW()
        ...     ]
        ...     label_transform = ds.transforms.TypeCast(ms.int32)
        ...     dataset = dataset.map(image_transforms, 'image')
        ...     dataset = dataset.map(label_transform, 'label')
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
        >>> data_set = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
        >>> model = Model(net_with_loss, optimizer=optimizer)
        >>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, data_set, callbacks=[ttp_cb, loss_cb])
    """

    def __init__(self, controller_ip, controller_port, ckpt_save_path):
        super(MindIOTTPAdapter, self).__init__()
        # let it raise an error if the mindio_ttp package is not installed
        from mindio_ttp import framework_ttp as ttp
        self.ttp = ttp
        Validator.check_non_negative_int(controller_port)
        self.has_init = False
        self.enable = False
        mode = context.get_context("mode")
        if context.get_context("device_target") != "Ascend" or mode != context.GRAPH_MODE:
            logger.warning(
                "MindIO adapter is only supported on Ascend devices in GRAPH mode.")
            return
        if os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT") != "true":
            logger.warning("MindIO adapter needs to be switched on via MS_ENABLE_MINDIO_GRACEFUL_EXIT.")
            return
        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
        if ttp_lib_path is None or not os.path.isfile(ttp_lib_path):
            logger.warning(
                "MindIO adapter is switched on, but the ttp library path is not correct.")
            return
        self.enable = True
        self._controller_ip = controller_ip
        self._controller_port = controller_port
        self._ckpt_save_path = ckpt_save_path

    def wrapper_ttp_persist(self, func):
        """
        This method wraps the TTP exception handler around the input func.

        Args:
            func (function): train method that needs to be wrapped.

        Returns:
            Function, the wrapped function if TTP is enabled,
            otherwise the original function is returned.
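
        Examples:
            >>> # A minimal usage sketch: `ttp_cb` is assumed to be a constructed MindIOTTPAdapter
            >>> # and `model` a mindspore.train.Model; the wrapped train method is then called as usual.
            >>> train_func = ttp_cb.wrapper_ttp_persist(model.train)
            >>> train_func(1, data_set, callbacks=[ttp_cb])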

        """
        if self.enable:
            return self.ttp.ttp_to_persist(func)
        return func

    def _init_ttp(self, run_context):
        """ Init MindIO TTP, used internally. """
        logger.info("Begin to init ttp.")
        dev_num = _get_device_num()

        cb_params = run_context.original_args()
        param_layout_dict = cb_params.train_network.parameter_layout_dict
        dp = _get_dp_from_layout(param_layout_dict)
        logger.info("Init TTP with dp: {}.".format(dp))

        self.ttp.ttp_register_save_ckpt_handler(_save_checkpoint_on_failure)
        self.ttp.ttp_register_rename_handler(_rename_save_result)

        world_size = dev_num
        cur_rank = get_rank()
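        # The TTP replica count below is derived from the size of the data-parallel group (dp).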
        is_odd = len(dp) % 2
        replica = 2 if is_odd else len(dp) // 2
        enable_local_copy = False
        if cur_rank == 0:
            logger.info("Begin to start ttp controller.")
            self.ttp.ttp_init_controller(
                cur_rank, world_size, replica, enable_local_copy)
            self.ttp.ttp_start_controller(
                self._controller_ip, self._controller_port)
            logger.info("Finished starting ttp controller.")

        logger.info("Begin to start ttp processor.")
        self.ttp.ttp_init_processor(cur_rank, dp, len(dp),
                                    world_size, replica, enable_local_copy)
        self.ttp.ttp_start_processor(
            self._controller_ip, self._controller_port)
        logger.info("Finished starting ttp processor.")

        logger.info("Finished ttp init.")

    def on_train_step_end(self, run_context):
        """
        Initialize the TTP controller once after the first step finishes,
        and report status to MindIO TTP after every step finishes.

        Args:
            run_context (RunContext): Context of the training process. Refer to
                                      :class:`mindspore.train.RunContext` for details.

        """

        if self.enable is False:
            return
        pp_num = _get_auto_parallel_context("pipeline_stages")
        if pp_num < 2:
            self.enable = False
            return
        cb_params = run_context.original_args()
        if cb_params.dataset_sink_mode is True and cb_params.sink_size > 1:
            self.enable = False
            return
        if self.has_init is False:
            self.has_init = True
            self._init_ttp(run_context)
        _flush_from_cache(cb_params)
        cur_rank = get_rank()
        append_dict = {}
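        # cur_step_num below is the step index within the current epoch, while global_step keeps
        # the accumulated step count across epochs (cb_params.cur_step_num).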
        append_dict["cur_epoch_num"] = cb_params.cur_epoch_num
        append_dict["cur_step_num"] = int(
            (cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
        append_dict["cur_rank"] = cur_rank
        append_dict["batch_num"] = cb_params.batch_num
        append_dict["global_step"] = cb_params.cur_step_num

        save_params = _convert_net_to_param_list(cb_params.train_network)
        save_params_copy = copy.deepcopy(save_params)

        logger.info("Set ckpt args to TTP.")
        self.ttp.ttp_set_ckpt_args(
            (self._ckpt_save_path, save_params_copy, append_dict))
        logger.info("Set optimizer finish step status to TTP.")
        self.ttp.ttp_end_updating_os(cb_params.cur_step_num)

    @staticmethod
    def load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net):
        """
        Load a checkpoint into the network, and use the strategy file to find a backup checkpoint file
        when the original checkpoint file cannot be loaded.

        Note:
           This API must be called after the communication is initialized because the cluster information
           needs to be obtained internally.

        Args:
            ckpt_file_path (str): the checkpoint file to be loaded.
            strategy_file_path (str): strategy file path for the current rank.
            net (Cell): network into which the checkpoint will be loaded.

        Returns:
            Dict, the loaded checkpoint weights.

        Raises:
            ValueError: Failed to load the checkpoint file.

        Examples:
            >>> import os
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.communication import init, get_rank, get_group_size
            >>> from mindspore.train import Model, MindIOTTPAdapter
            >>> from mindspore import dataset as ds
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
            >>> init()
            >>> ms.set_seed(1)
            >>> class Network(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.flatten = nn.Flatten()
            ...         self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, x):
            ...         x = self.flatten(x)
            ...         logits = self.relu(self.fc(x))
            ...         return logits
            >>>
            >>> net = Network()
            >>>
            >>> def create_dataset(batch_size):
            ...     dataset_path = os.getenv("DATA_PATH")
            ...     rank_id = get_rank()
            ...     rank_size = get_group_size()
            ...     dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
            ...     image_transforms = [
            ...         ds.vision.Rescale(1.0 / 255.0, 0),
            ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            ...         ds.vision.HWC2CHW()
            ...     ]
            ...     label_transform = ds.transforms.TypeCast(ms.int32)
            ...     dataset = dataset.map(image_transforms, 'image')
            ...     dataset = dataset.map(label_transform, 'label')
            ...     dataset = dataset.batch(batch_size)
            ...     return dataset
            >>> data_set = create_dataset(32)
            >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
            >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
            >>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, strategy_file, net)
            >>> data_set.set_init_step(param_dict["global_step"])
        """
        logger.info("Start loading checkpoint with strategy file.")
        try:
            param_dict = ms.load_checkpoint(ckpt_file_path)
        except ValueError as e:
            logger.warning(
                "Loading origin checkpoint file failed, the reason is: {}.".format(str(e)))
            param_dict = None
            dp = _get_dp_from_layout(strategy_file_path)
            rank = get_rank()
            logger.info(
                "Can't load origin checkpoint file, found dp:{}.".format(dp))
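            # Fall back to the copies of this checkpoint saved by the other ranks in the
            # same data-parallel group (dp), which hold redundant parameter data.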
            for i in dp:
                if i == rank:
                    continue
                new_ckpt = ckpt_file_path.replace(
                    f"/rank_{rank}/", f"/rank_{str(i)}/")
                if not os.path.isfile(new_ckpt):
                    continue
                try:
                    param_dict = ms.load_checkpoint(new_ckpt)
                    break
                except ValueError as e1:
                    logger.warning(
                        "Loading backup checkpoint file failed, the reason is: {}.".format(str(e1)))
                    param_dict = None
        if param_dict:
            logger.info("Found param dict, load it into network.")
            ms.load_param_into_net(net, param_dict)
        else:
            raise ValueError(
                "Load checkpoint file failed, please check that your config is set correctly.")
        logger.info("Finished loading checkpoint with strategy file.")
        return param_dict