# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Checkpoint related classes and functions."""

import os
import sys
import copy
from mindspore.train.serialization import save_checkpoint, _convert_cell_param_and_names_to_dict, _get_merged_param_data
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
from mindspore.parallel._utils import _get_device_num
from mindspore import _checkparam as Validator
from mindspore.train.callback._callback import Callback
from mindspore.common.tensor import Tensor
from mindspore import context
import mindspore as ms
from mindspore.communication import get_rank
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters

from mindspore.train._utils import get_parameter_redundancy
from mindspore import log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.common.api import _get_parameter_layout


def _get_dp_from_layout(parameter_layout_dict):
    """Get the dp group (the smallest redundancy group containing the current rank) from the layout dict."""
    pp_num = _get_auto_parallel_context("pipeline_stages")
    dev_num = _get_device_num()
    global_rank = get_rank()
    pipe_size = dev_num // pp_num
    initial_rank = (global_rank // pipe_size) * pipe_size
    parameter_redundancy_dict = get_parameter_redundancy(
        parameter_layout_dict, initial_rank)
    value_len = sys.maxsize
    min_value = ()
    for key, value in parameter_redundancy_dict.items():
        if "accu_grads" in key or "inputs" in key:
            continue
        for item in value:
            if len(item) < value_len and global_rank in item:
                value_len = len(item)
                min_value = item
    return min_value


def _get_ckpt_dir(append_dict, ckpt_save_path, is_tmp_file):
    """Common function to generate the checkpoint directory name."""
    tmp = "_tmp" if is_tmp_file else ""
    mid_dir = f"ttp_saved_checkpoints-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}{tmp}"
    return os.path.join(ckpt_save_path, mid_dir)


def _flush_from_cache(cb_params):
    """Flush cached data to host for parameters that have cache enabled."""
    params = cb_params.train_network.get_parameters()
    for param in params:
        if param.cache_enable:
            Tensor(param).flush_from_cache()


def _save_checkpoint_on_failure(save_rank, step, rank_list, save_args):
    """Callback used by MindIO TTP to save a checkpoint when an error occurs."""
    logger.info("Enter _save_checkpoint_on_failure function")
    ckpt_save_path, save_params, append_dict = save_args
    ckpt_file = f"iteration-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}.ckpt"
    cur_ckpt_dir = _get_ckpt_dir(
        append_dict, ckpt_save_path, True) + "/rank_" + str(save_rank)
    os.makedirs(cur_ckpt_dir)
    cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
    save_checkpoint(save_params, cur_file,
                    integrated_save=False, append_dict=append_dict)
    logger.info("Finish _save_checkpoint_on_failure function")
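

# A minimal sketch (values are illustrative only) of the directory layout produced by the
# helpers above: _save_checkpoint_on_failure writes into a "_tmp" directory, and
# _rename_save_result (defined below) strips the suffix after the save callback has
# finished successfully.
#
#   append_dict = {"cur_epoch_num": 1, "cur_step_num": 40}
#   _get_ckpt_dir(append_dict, "./ttp_checkpoint", True)
#   # -> "./ttp_checkpoint/ttp_saved_checkpoints-1_40_tmp"
#   # rank 5 then saves "./ttp_checkpoint/ttp_saved_checkpoints-1_40_tmp/rank_5/iteration-1_40.ckpt"
#   _get_ckpt_dir(append_dict, "./ttp_checkpoint", False)
#   # -> "./ttp_checkpoint/ttp_saved_checkpoints-1_40"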


def _convert_net_to_param_list(save_obj):
    """Convert nn.Cell to a param_list for save_checkpoint."""
    sync_pipeline_shared_parameters(save_obj)
    param_list = []
    parameter_layout_dict = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
        parameter_layout_dict = _get_parameter_layout()
    if not _is_in_auto_parallel_mode():
        save_obj.init_parameters_data()
    param_dict = _convert_cell_param_and_names_to_dict(save_obj, None)
    for (key, value) in param_dict.items():
        each_param = {"name": key}
        param_data = Tensor(value.asnumpy())
        # in the automatic model parallel scenario, some parameters are split across
        # devices and should be combined before saving
        if key in parameter_layout_dict:
            param_data = _get_merged_param_data(
                save_obj, parameter_layout_dict, key, param_data, False)
        each_param["data"] = param_data
        param_list.append(each_param)
    return param_list


def _rename_save_result(rename_args):
    """Callback used by MindIO TTP to rename the checkpoint directory after the save callback finishes successfully."""
    logger.info("Enter _rename_save_result function")
    ckpt_save_path, _, append_dict = rename_args

    tmp_dir = _get_ckpt_dir(append_dict, ckpt_save_path, True)
    fin_dir = _get_ckpt_dir(append_dict, ckpt_save_path, False)

    os.rename(tmp_dir, fin_dir)
    logger.info("Finish _rename_save_result function")
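

# For reference, a hedged sketch of the structure _convert_net_to_param_list builds for
# save_checkpoint (parameter names and shapes are illustrative only):
#
#   [{"name": "layer1.param", "data": Tensor(shape=[784, 512], dtype=Float32)},
#    {"name": "layer2.weight", "data": Tensor(shape=[512, 512], dtype=Float32)},
#    ...]
#
# Parameters that were split across devices in auto-parallel mode are merged back into
# full tensors by _get_merged_param_data before being appended to the list.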


class MindIOTTPAdapter(Callback):
    """
    This callback enables the feature
    `MindIO TTP <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindiottp/mindiottp001.html>`_.
    It executes TTP operations during the training process, such as TTP initialization,
    status reporting and exception handling.

    Note:
        Required for the Ascend GE LazyInline mode only, and the number of pipeline stages
        must be greater than 1.

    Args:
        controller_ip (str): IP address of the TTP controller, used to initialize the TTP controller.
        controller_port (int): IP port of the TTP controller, used to initialize the TTP controller and processor.
        ckpt_save_path (str): Checkpoint save directory used when a failure occurs. Checkpoint files will be saved
            to a subdirectory named ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num} under this directory.

    Raises:
        Exception: TTP init failed.
        ModuleNotFoundError: The MindIO TTP whl package is not installed.

    Examples:
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore.train import Model, MindIOTTPAdapter
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
        ...     "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
        ...         if shape is None:
        ...             shape = [28 * 28, 512]
        ...         weight_init = HeUniform(math.sqrt(5))
        ...         self.param = Parameter(initializer(weight_init, shape), name="param")
        ...         if param is not None:
        ...             self.param = param
        ...         self.print = ops.Print()
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.param)
        ...         self.print("out is:", out)
        ...         return out
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.layer1 = MatMulCell()
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(512, 512)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(512, 10)
        ...
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> net = Network()
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>>
        >>> def create_dataset(batch_size):
        ...     dataset_path = os.getenv("DATA_PATH")
        ...     dataset = ds.MnistDataset(dataset_path)
        ...     image_transforms = [
        ...         ds.vision.Rescale(1.0 / 255.0, 0),
        ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ...         ds.vision.HWC2CHW()
        ...     ]
        ...     label_transform = ds.transforms.TypeCast(ms.int32)
        ...     dataset = dataset.map(image_transforms, 'image')
        ...     dataset = dataset.map(label_transform, 'label')
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
        >>> data_set = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
        >>> model = Model(net_with_loss, optimizer=optimizer)
        >>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, data_set, callbacks=[ttp_cb, loss_cb])
    """

    def __init__(self, controller_ip, controller_port, ckpt_save_path):
        super(MindIOTTPAdapter, self).__init__()
        # let it raise an error if the mindio_ttp package is not installed
        from mindio_ttp import framework_ttp as ttp
        self.ttp = ttp
        Validator.check_non_negative_int(controller_port)
        self.has_init = False
        self.enable = False
        mode = context.get_context("mode")
        if context.get_context("device_target") != "Ascend" or mode != context.GRAPH_MODE:
            logger.warning(
                "MindIO adapter only supports Ascend devices in GRAPH mode.")
            return
        if os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT") != "true":
            logger.warning("MindIO adapter needs to be switched on explicitly.")
            return
        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
        if ttp_lib_path is None or os.path.isfile(ttp_lib_path) is False:
            logger.warning(
                "MindIO adapter is switched on, but the TTP library path is not correct.")
            return
        self.enable = True
        self._controller_ip = controller_ip
        self._controller_port = controller_port
        self._ckpt_save_path = ckpt_save_path

    def wrapper_ttp_persist(self, func):
        """
        Wrap the input function with the MindIO TTP exception handler.

        Args:
            func (function): train function that needs to be wrapped.

        Returns:
            Function, the wrapped function if TTP is enabled, otherwise the original function.

        """
        if self.enable:
            return self.ttp.ttp_to_persist(func)
        return func
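
    # A hedged usage sketch for wrapper_ttp_persist (the model and dataset names below are
    # illustrative, not part of this module): the train entry point is wrapped once so that
    # MindIO TTP can intercept exceptions raised during training and trigger the registered
    # save and rename handlers before the process exits.
    #
    #   ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
    #   train_func = ttp_cb.wrapper_ttp_persist(model.train)
    #   train_func(1, data_set, callbacks=[ttp_cb])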

    def _init_ttp(self, run_context):
        """Initialize MindIO TTP, for internal use."""
        logger.info("Begin to init ttp.")
        dev_num = _get_device_num()

        cb_params = run_context.original_args()
        param_layout_dict = cb_params.train_network.parameter_layout_dict
        dp = _get_dp_from_layout(param_layout_dict)
        logger.info("Init TTP with dp: {}.".format(dp))

        self.ttp.ttp_register_save_ckpt_handler(_save_checkpoint_on_failure)
        self.ttp.ttp_register_rename_handler(_rename_save_result)

        world_size = dev_num
        cur_rank = get_rank()
        is_odd = len(dp) % 2
        replica = 2 if is_odd else len(dp) // 2
        enable_local_copy = False
        if cur_rank == 0:
            logger.info("Begin to start ttp controller.")
            self.ttp.ttp_init_controller(
                cur_rank, world_size, replica, enable_local_copy)
            self.ttp.ttp_start_controller(
                self._controller_ip, self._controller_port)
            logger.info("Finish starting ttp controller.")

        logger.info("Begin to start ttp processor.")
        self.ttp.ttp_init_processor(cur_rank, dp, len(
            dp), world_size, replica, enable_local_copy)
        self.ttp.ttp_start_processor(
            self._controller_ip, self._controller_port)
        logger.info("Finish starting ttp processor.")

        logger.info("Finish init ttp.")

    def on_train_step_end(self, run_context):
        """
        Initialize the TTP controller once after the first step finishes,
        and report status to MindIO TTP after every step.

        Args:
            run_context (RunContext): Context of the train running. Refer to
                :class:`mindspore.train.RunContext` for detail.

        """

        if self.enable is False:
            return
        pp_num = _get_auto_parallel_context("pipeline_stages")
        if pp_num < 2:
            self.enable = False
            return
        cb_params = run_context.original_args()
        if cb_params.dataset_sink_mode is True and cb_params.sink_size > 1:
            self.enable = False
            return
        if self.has_init is False:
            self.has_init = True
            self._init_ttp(run_context)
        _flush_from_cache(cb_params)
        cur_rank = get_rank()
        append_dict = {}
        append_dict["cur_epoch_num"] = cb_params.cur_epoch_num
        append_dict["cur_step_num"] = int(
            (cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
        append_dict["cur_rank"] = cur_rank
        append_dict["batch_num"] = cb_params.batch_num
        append_dict["global_step"] = cb_params.cur_step_num

        save_params = _convert_net_to_param_list(cb_params.train_network)
        save_params_copy = copy.deepcopy(save_params)

        logger.info("Set ckpt args to TTP.")
        self.ttp.ttp_set_ckpt_args(
            (self._ckpt_save_path, save_params_copy, append_dict))
        logger.info("Set optimizer finish step status to TTP.")
        self.ttp.ttp_end_updating_os(cb_params.cur_step_num)
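
    # Worked example for the append_dict fields set above (numbers are illustrative): with
    # batch_num = 100 and a global cb_params.cur_step_num of 240, the reported
    # "cur_step_num" is (240 - 1) % 100 + 1 = 40, i.e. the step index within the current
    # epoch, while "global_step" keeps the raw value 240.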

    @staticmethod
    def load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net):
        """
        Load a checkpoint into the network, and use the strategy file to find a backup
        checkpoint file when the original checkpoint file cannot be loaded.

        Note:
            This API must be called after the communication is initialized because the cluster information
            needs to be obtained internally.

        Args:
            ckpt_file_path (str): path of the checkpoint file to be loaded.
            strategy_file_path (str): strategy file path for the current rank.
            net (Cell): the network that the checkpoint will be loaded into.

        Returns:
            Dict, the loaded checkpoint weights.

        Raises:
            ValueError: Failed to load the checkpoint file.

        Examples:
            >>> import os
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.communication import init, get_rank, get_group_size
            >>> from mindspore.train import Model, MindIOTTPAdapter
            >>> from mindspore import dataset as ds
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
            >>> init()
            >>> ms.set_seed(1)
            >>> class Network(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.flatten = nn.Flatten()
            ...         self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, x):
            ...         x = self.flatten(x)
            ...         logits = self.relu(self.fc(x))
            ...         return logits
            >>>
            >>> net = Network()
            >>>
            >>> def create_dataset(batch_size):
            ...     dataset_path = os.getenv("DATA_PATH")
            ...     rank_id = get_rank()
            ...     rank_size = get_group_size()
            ...     dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
            ...     image_transforms = [
            ...         ds.vision.Rescale(1.0 / 255.0, 0),
            ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
            ...         ds.vision.HWC2CHW()
            ...     ]
            ...     label_transform = ds.transforms.TypeCast(ms.int32)
            ...     dataset = dataset.map(image_transforms, 'image')
            ...     dataset = dataset.map(label_transform, 'label')
            ...     dataset = dataset.batch(batch_size)
            ...     return dataset
            >>> data_set = create_dataset(32)
            >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
            >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
            >>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, strategy_file, net)
            >>> data_set.set_init_step(param_dict["global_step"])
        """
        logger.info("Start load checkpoint with strategy file.")
        try:
            param_dict = ms.load_checkpoint(ckpt_file_path)
        except ValueError as e:
            logger.warning(
                "Loading the original checkpoint file failed, the reason is: {}.".format(str(e)))
            # fall back to backup checkpoints saved by other ranks in the same redundancy group
            param_dict = None
            dp = _get_dp_from_layout(strategy_file_path)
            rank = get_rank()
            logger.info(
                "Can't load the original checkpoint file, found dp: {}.".format(dp))
            for i in dp:
                if i == rank:
                    continue
                new_ckpt = ckpt_file_path.replace(
                    f"/rank_{rank}/", f"/rank_{str(i)}/")
                if not os.path.isfile(new_ckpt):
                    continue
                try:
                    param_dict = ms.load_checkpoint(new_ckpt)
                except ValueError as e1:
                    logger.warning(
                        "Loading the backup checkpoint file failed, the reason is: {}.".format(str(e1)))
                    param_dict = None
        if param_dict:
            logger.info("Found param dict, load it into network.")
            ms.load_param_into_net(net, param_dict)
        else:
            raise ValueError(
                "Failed to load the checkpoint file, please check that your configuration is set correctly.")
        logger.info("Finish load checkpoint with strategy file.")
        return param_dict
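

# A hedged illustration of the backup lookup in load_checkpoint_with_backup (paths are
# illustrative): if rank 5 cannot load "./ckpt/rank_5/iteration-1_40.ckpt" and its
# redundancy group is (1, 5), the candidate "./ckpt/rank_1/iteration-1_40.ckpt" is tried
# instead; whichever backup loads successfully is then loaded into the network, and a
# ValueError is raised if none can be loaded.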