# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model and parameters serialization."""
import os
import sys
import stat
import math
import shutil
import time
import copy
import json
import threading
from threading import Thread, Lock
from collections import defaultdict

import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.print_pb2 import Print
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.mind_ir_pb2 import GraphProto as graph_proto
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.api import _cell_graph_executor as _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.communication.management import get_rank, get_group_size
from mindspore.parallel._tensor import _reshape_param_data_with_weight
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _reshape_param_data
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file


tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
                     "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
                     "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
                     "Bool": mstype.bool_}

tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16,
                     "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

_ckpt_mutex = Lock()

# unit is KB: SLICE_SIZE is 512MB, TOTAL_SAVE is 1GB
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024


def _special_process_par(par, new_par):
    """
    Processes the special condition.

    Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
    """
    par_shape_len = len(par.data.shape)
    new_par_shape_len = len(new_par.data.shape)
    if new_par_shape_len <= par_shape_len:
        return False

    for i in range(new_par_shape_len - par_shape_len):
        if new_par.data.shape[par_shape_len + i] != 1:
            return False

    new_val = new_par.data.asnumpy()
    new_val = new_val.reshape(par.data.shape)
    par.set_data(Tensor(new_val, par.data.dtype))
    return True


def _update_param(param, new_param, strict_load):
    """Updates param's data from new_param's data."""
    if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
        if param.data.shape != new_param.data.shape:
            if not _special_process_par(param, new_param):
                logger.error("Failed to combine the net and the parameters for param %s.", param.name)
                msg = ("Net parameter {}'s shape ({}) is different from parameter_dict's ({})."
                       .format(param.name, param.data.shape, new_param.data.shape))
                raise RuntimeError(msg)

        if param.data.dtype != new_param.data.dtype:
            if _type_convert(param, new_param, strict_load):
                new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
                param.set_data(new_tensor)
                return

            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameter {}'s type ({}) is different from parameter_dict's ({})."
                   .format(param.name, param.data.dtype, new_param.data.dtype))
            raise RuntimeError(msg)

        param.set_data(new_param.data, param.sliced)
        return

    if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
        if param.data.shape != (1,) and param.data.shape != ():
            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameter {}'s shape ({}) is not (1,), inconsistent with parameter_dict's (scalar)."
                   .format(param.name, param.data.shape))
            raise RuntimeError(msg)
        param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))

    elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
        logger.error("Failed to combine the net and the parameters for param %s.", param.name)
        msg = ("Net parameter {}'s type ({}) is different from parameter_dict's ({})."
               .format(param.name, type(param.data), type(new_param.data)))
        raise RuntimeError(msg)

    else:
        param.set_data(type(param.data)(new_param.data))


def _type_convert(param, new_param, strict_load):
    """Whether to convert parameter's type during load checkpoint into network."""
    float_type = (mstype.float16, mstype.float32, mstype.float64)
    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
    if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
                            {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
        logger.warning("ckpt_dict parameter: {}'s type is {}, convert to {} in the network."
                       .format(new_param.name, new_param.data.dtype, param.data.dtype))
        return True
    return False


def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
    """Execute the process of saving checkpoint into file."""
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                os.remove(ckpt_file_name)
            with open(ckpt_file_name, "ab") as f:
                if enc_key is not None:
                    plain_data = bytes(0)
                    cipher_data = bytes(0)

                for name, value in data_list.items():
                    data_size = value[2].nbytes / 1024
                    # large tensors are written as several slices sharing the same tag
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(value[2], slice_count)
                    else:
                        param_slice_list = [value[2]]

                    for param_slice in param_slice_list:
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        param_tensor.tensor_content = param_slice.tobytes()

                        if enc_key is None:
                            f.write(checkpoint_list.SerializeToString())
                        else:
                            plain_data += checkpoint_list.SerializeToString()

                            max_block_size = SLICE_SIZE * 1024
                            while len(plain_data) >= max_block_size:
                                cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key,
                                                        len(enc_key), enc_mode)
                                plain_data = plain_data[max_block_size:]

                if enc_key is not None:
                    if plain_data:
                        cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode)
                    f.write(cipher_data)

        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise e
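

# A minimal sketch of the slice round trip implemented by _exec_save and
# load_checkpoint (illustrative values, not a real checkpoint): a tensor whose
# flattened data exceeds SLICE_SIZE KB is written as several Checkpoint protos
# sharing one tag, and the reader joins those slices back in order.
#
#     import numpy as np
#     data = np.arange(10, dtype=np.float32)
#     slices = np.array_split(data, 3)           # what _exec_save writes
#     restored = np.concatenate(slices, axis=0)  # what load_checkpoint does
#     assert (restored == data).all()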


def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
    """
    Save checkpoint to a specified file.

    Args:
        save_obj (Union[Cell, list]): The cell object or data list (each element is a dictionary, like
                                      [{"name": param_name, "data": param_data},...], the type of
                                      param_name would be string, and the type of param_data would
                                      be parameter or Tensor).
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to merge and save the split parameters in the automatic model
                                parallel scene. Default: True.
        async_save (bool): Whether to use an independent thread to save the checkpoint file. Default: False.
        append_dict (dict): Additional information that needs to be saved. The key of dict must be str,
                            the value of dict must be one of int, float and bool. Default: None.
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
                                      is not required. Default: None.
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: If the parameter save_obj is not `nn.Cell` or list type, or if the parameters
                   `integrated_save` and `async_save` are not bool type.

    Examples:
        >>> from mindspore import save_checkpoint
        >>>
        >>> net = Net()
        >>> save_checkpoint(net, "lenet.ckpt")
    """

    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
        raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
    integrated_save = Validator.check_bool(integrated_save)
    async_save = Validator.check_bool(async_save)
    append_dict = _check_append_dict(append_dict)
    enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
    enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)

    logger.info("Execute the process of saving checkpoint files.")

    if isinstance(save_obj, nn.Cell):
        save_obj.init_parameters_data()
        param_dict = {}
        for _, param in save_obj.parameters_and_names():
            param_dict[param.name] = param
        param_list = []
        for (key, value) in param_dict.items():
            each_param = {"name": key}
            param_data = Tensor(value.data)

            # in automatic model parallel scenario, some parameters were split to all the devices,
            # which should be combined before saving
            if key in save_obj.parameter_layout_dict:
                param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save)

            each_param["data"] = param_data
            param_list.append(each_param)
        save_obj = param_list

    if append_dict:
        append_info_list = []
        for k_name, value in append_dict.items():
            append_info_list.append({"name": k_name, "data": Tensor(value)})
        save_obj.extend(append_info_list)

    data_list = {}
    with _ckpt_mutex:
        for param in save_obj:
            key = param["name"]
            data_list[key] = []
            if isinstance(param["data"], Parameter):
                param["data"].init_data()
            dims = []
            if param['data'].shape == ():
                dims.append(0)
            else:
                for dim in param['data'].shape:
                    dims.append(dim)
            data_list[key].append(dims)
            tensor_type = str(param["data"].dtype)
            data_list[key].append(tensor_type)
            data = param["data"].asnumpy().reshape(-1)
            data_list[key].append(data)

    ckpt_file_name = os.path.realpath(ckpt_file_name)
    if async_save:
        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
        thr.start()
    else:
        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode)

    logger.info("Saving checkpoint process is finished.")
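

# A hedged usage sketch for append_dict ("Net" is a placeholder Cell, not part
# of this module): each extra scalar is stored as a tensor next to the network
# parameters and comes back through load_checkpoint under its own key.
#
#     net = Net()
#     save_checkpoint(net, "lenet.ckpt", append_dict={"epoch_num": 10, "lr": 0.01})
#     param_dict = load_checkpoint("lenet.ckpt")
#     # param_dict["epoch_num"] holds a Parameter wrapping the saved value 10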


def _check_param_prefix(filter_prefix, param_name):
    """Checks whether the prefix of parameter name matches the given filter_prefix."""
    for prefix in filter_prefix:
        if param_name.find(prefix) == 0 \
                and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
            return True
    return False


def _check_append_dict(append_dict):
    if append_dict is None:
        return append_dict
    if not isinstance(append_dict, dict):
        raise TypeError(f"The type of append_dict must be dict, but got {str(type(append_dict))}.")
    if not all(isinstance(ele, str) for ele in append_dict.keys()) or \
            not all(isinstance(ele, (int, float, bool)) for ele in append_dict.values()):
        raise TypeError(f"The type of element in append_dict must be key: str, value: int, float or bool.")
    return append_dict


def load(file_name, **kwargs):
    """
    Load MindIR.

    The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

    Args:
        file_name (str): MindIR file name.

        kwargs (dict): Configuration options dictionary.

            - dec_key (bytes): Byte type key used for decryption. The valid length is 16, 24, or 32.
            - dec_mode (str): Specifies the decryption mode, take effect when dec_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.

    Returns:
        Object, a compiled graph that can be executed by `GraphCell`.

    Raises:
        ValueError: MindIR file name is incorrect.
        RuntimeError: Failed to parse MindIR file.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input_tensor, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input_tensor)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]
    """
    if not isinstance(file_name, str):
        raise ValueError("The file name must be string.")
    if not file_name.endswith(".mindir"):
        raise ValueError("The MindIR file should end with .mindir, please input the correct file name.")
    if not os.path.exists(file_name):
        raise ValueError("The file does not exist.")
    file_name = os.path.realpath(file_name)

    logger.info("Execute the process of loading mindir.")
    if 'dec_key' in kwargs.keys():
        dec_key = Validator.check_isinstance('dec_key', kwargs['dec_key'], bytes)
        dec_mode = 'AES-GCM'
        if 'dec_mode' in kwargs.keys():
            dec_mode = Validator.check_isinstance('dec_mode', kwargs['dec_mode'], str)
        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
    else:
        graph = load_mindir(file_name)

    if graph is None:
        if _is_cipher_file(file_name):
            raise RuntimeError("Load MindIR failed. The file may be encrypted, please pass in the "
                               "correct dec_key and dec_mode.")
        raise RuntimeError("Load MindIR failed.")
    return graph
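

# A minimal sketch of the MindIR encryption round trip (the key below is an
# illustrative value): export() accepts enc_key/enc_mode only with
# file_format='MINDIR', and load() must receive the matching dec_key/dec_mode,
# otherwise it fails with the "file may be encrypted" error above.
#
#     key = b'0123456789ABCDEF'  # valid lengths are 16, 24 or 32 bytes
#     export(net, input_tensor, file_name="net", file_format="MINDIR",
#            enc_key=key, enc_mode="AES-GCM")
#     graph = load("net.mindir", dec_key=key, dec_mode="AES-GCM")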


def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
    """
    Load checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell): The network where the parameters will be loaded. Default: None.
        strict_load (bool): Whether to strictly load the parameters into net. If False, it will load a
                            parameter into net when the parameter name's suffix in the checkpoint file is
                            the same as the parameter in the network. When the types are inconsistent,
                            perform type conversion on parameters of the same numeric family, such as
                            float32 to float16. Default: False.
        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
                                                           will not be loaded. Default: None.
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
                                      is not required. Default: None.
        dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.

    Examples:
        >>> from mindspore import load_checkpoint
        >>>
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
        >>> print(param_dict["conv2.weight"])
        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
    """
    ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        if dec_key is None:
            with open(ckpt_file_name, "rb") as f:
                pb_content = f.read()
        else:
            pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
            if pb_content is None:
                raise ValueError
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        if _is_cipher_file(ckpt_file_name):
            logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
                         "correct dec_key.", ckpt_file_name)
        else:
            logger.error("Failed to read the checkpoint file `%s`, please check whether the file is correct.",
                         ckpt_file_name)
        raise ValueError(e.__str__())

    parameter_dict = {}
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # merge the slices of one parameter: the last slice is reached at the end
            # of the list or when the next element carries a different tag
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate(param_data_list, axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    param_dim = []
                    for dim in dims:
                        param_dim.append(dim)
                    param_value = param_data.reshape(param_dim)
                    parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if not parameter_dict:
        raise ValueError(f"The loaded parameter dict is empty after filtering, please check filter_prefix.")

    if net is not None:
        load_param_into_net(net, parameter_dict, strict_load)

    return parameter_dict
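

# A hedged sketch of the checkpoint encryption round trip (the key below is an
# assumption for illustration): the enc_key/enc_mode passed to save_checkpoint
# must be passed back as dec_key/dec_mode, otherwise parsing fails with the
# "file may be encrypted" error above.
#
#     key = b'0123456789ABCDEF'
#     save_checkpoint(net, "lenet_enc.ckpt", enc_key=key, enc_mode="AES-GCM")
#     param_dict = load_checkpoint("lenet_enc.ckpt", net=net,
#                                  dec_key=key, dec_mode="AES-GCM")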


def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
    """Checks the parameters of the function load_checkpoint."""
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")

    if not os.path.exists(ckpt_file_name):
        raise ValueError("The checkpoint file does not exist.")

    if ckpt_file_name[-5:] != ".ckpt":
        raise ValueError("Please input the correct checkpoint file name.")
    ckpt_file_name = os.path.realpath(ckpt_file_name)

    if filter_prefix is not None:
        if not isinstance(filter_prefix, (str, list, tuple)):
            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
        if isinstance(filter_prefix, str):
            filter_prefix = (filter_prefix,)
        if not filter_prefix:
            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
        for index, prefix in enumerate(filter_prefix):
            if not isinstance(prefix, str):
                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
                                f"but got {str(type(prefix))} at index {index}.")
    return ckpt_file_name, filter_prefix


def load_param_into_net(net, parameter_dict, strict_load=False):
    """
    Load parameters into network.

    Args:
        net (Cell): The network where the parameters will be loaded.
        parameter_dict (dict): The dictionary generated by loading a checkpoint file.
        strict_load (bool): Whether to strictly load the parameters into net. If False, it will load a
                            parameter into net when the parameter name's suffix in the checkpoint file is
                            the same as the parameter in the network. When the types are inconsistent,
                            perform type conversion on parameters of the same numeric family, such as
                            float32 to float16. Default: False.

    Returns:
        List, the parameter names that are not loaded into the network.

    Raises:
        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.

    Examples:
        >>> from mindspore import load_checkpoint, load_param_into_net
        >>>
        >>> net = Net()
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
        >>> param_not_load = load_param_into_net(net, param_dict)
        >>> print(param_not_load)
        ['conv1.weight']
    """
    if not isinstance(net, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(net)))
        raise TypeError(msg)

    if not isinstance(parameter_dict, dict):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
        raise TypeError(msg)

    strict_load = Validator.check_bool(strict_load)
    logger.info("Execute the process of loading parameters into net.")
    net.init_parameters_data()
    param_not_load = []
    for _, param in net.parameters_and_names():
        if param.name in parameter_dict:
            new_param = copy.deepcopy(parameter_dict[param.name])
            if not isinstance(new_param, Parameter):
                logger.error("Failed to combine the net and the parameters.")
                msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
                raise TypeError(msg)
            _update_param(param, new_param, strict_load)
        else:
            param_not_load.append(param.name)

    if param_not_load and not strict_load:
        _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)

    logger.debug("Params not matched(in net but not in parameter_dict):")
    for param_name in param_not_load:
        logger.debug("%s", param_name)

    logger.info("Loading parameters into net is finished.")
    if param_not_load:
        logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
        for param_name in param_not_load:
            logger.warning("{} is not loaded.".format(param_name))
    return param_not_load


def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
    """When some parameters of the net were not loaded, try to load them by removing a common prefix."""
    prefix_name = ""
    longest_name = param_not_load[0]
    while prefix_name != longest_name and param_not_load:
        logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
        prefix_name = longest_name
        for net_param_name in param_not_load:
            for dict_name in parameter_dict:
                if dict_name.endswith(net_param_name):
                    prefix_name = dict_name[:-len(net_param_name)]
                    break
            if prefix_name != longest_name:
                break

        if prefix_name != longest_name:
            logger.warning("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
            for _, param in net.parameters_and_names():
                new_param_name = prefix_name + param.name
                if param.name in param_not_load and new_param_name in parameter_dict:
                    new_param = parameter_dict[new_param_name]
                    _update_param(param, new_param, strict_load)
                    param_not_load.remove(param.name)


def _save_graph(network, file_name):
    """
    Saves the graph of network to a file.

    Args:
        network (Cell): Obtain a pipeline through network for saving graph.
        file_name (str): Graph file name into which the graph will be saved.
    """
    logger.info("Execute the process of saving graph.")

    file_name = os.path.realpath(file_name)
    graph_pb = network.get_func_graph_proto()
    if graph_pb:
        with open(file_name, "wb") as f:
            os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
            f.write(graph_pb)


def _get_merged_param_data(net, param_name, param_data, integrated_save):
    """
    Gets the merged data (tensor) from tensor slice, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
    Returns:
        Tensor, the combined tensor with the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    uniform_split = layout[4]
    opt_shard_group = layout[5]

    allgather_net = None
    mp_weight = False
    for dim in tensor_map:
        if dim != -1:
            mp_weight = True
            break
    if param_name in net.parallel_parameter_merge_net_dict:
        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
    else:
        logger.info("need to create allgather net for %s", param_name)
        if integrated_save:
            if context.get_auto_parallel_context("pipeline_stages") > 1:
                raise RuntimeError("Pipeline parallel doesn't support integrated save checkpoint now.")
            if uniform_split == 0:
                raise RuntimeError("Integrated save checkpoint only supports uniform split tensor now.")
            # while any dim is not equal to -1, means param is split and needs to be merged
            # pipeline parallel need to be supported here later
            if mp_weight:
                allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
            elif opt_shard_group:
                allgather_net = get_allgather_cell(opt_shard_group, False)
        elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
            allgather_net = get_allgather_cell(opt_shard_group, False)
        net.parallel_parameter_merge_net_dict[param_name] = allgather_net
    if allgather_net:
        param_data = allgather_net(param_data)
    if mp_weight and integrated_save:
        param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
    return param_data


def _fill_param_into_net(net, parameter_list):
    """
    Fills parameter_list into net.

    Args:
        net (Cell): train network.
        parameter_list (list): parameters list from ge callback.
    """
    parameter_dict = {}
    for each_param in parameter_list:
        param_name = each_param["name"]
        if isinstance(each_param["data"], Parameter):
            each_param["data"].init_data()
        np_val = each_param["data"].asnumpy()
        if np_val.shape == (1,):
            parameter_dict[param_name] = Parameter(np_val, name=param_name)
        elif np_val.shape == ():
            parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
                                                   name=param_name)
        else:
            parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)

    load_param_into_net(net, parameter_dict)


def export(net, *inputs, file_name, file_format='AIR', **kwargs):
    """
    Export the mindspore network into an offline model in the specified format.

    Note:
        1. When exporting to AIR or ONNX format, the size of a single tensor can not exceed 2GB.
        2. When file_name does not have a suffix, the system will automatically add one according to the file_format.

    Args:
        net (Cell): MindSpore network.
        inputs (Tensor): Inputs of the `net`; if the network has multiple inputs, pass them as a tuple of Tensor.
        file_name (str): File name of the model to be exported.
        file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.

            - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
              for MindSpore models.

        kwargs (dict): Configuration options dictionary.

            - quant_mode (str): If the network is a quantization aware training network, the quant_mode should
              be set to "QUANT", else the quant_mode should be set to "NONQUANT".
            - mean (float): The mean of input data after preprocessing, used for quantizing the first layer of
              network. Default: 127.5.
            - std_dev (float): The variance of input data after preprocessing,
              used for quantizing the first layer of network. Default: 127.5.
            - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, take effect when enc_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
            - dataset (Dataset): Specifies the preprocess methods of network.

    Examples:
        >>> import numpy as np
        >>> from mindspore import export, Tensor
        >>>
        >>> net = LeNet()
        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
        >>> export(net, Tensor(input_tensor), file_name='lenet', file_format='MINDIR')
    """
    logger.info("exporting model file:%s format:%s.", file_name, file_format)
    check_input_data(*inputs, data_class=Tensor)
    Validator.check_file_name_by_regular(file_name)
    file_name = os.path.realpath(file_name)
    net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
    if 'enc_key' in kwargs.keys():
        if file_format != 'MINDIR':
            raise ValueError(f"enc_key can be passed in only when file_format=='MINDIR', but got {file_format}")

        enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
        enc_mode = 'AES-GCM'
        if 'enc_mode' in kwargs.keys():
            enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
        dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None
        _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
    else:
        _export(net, file_name, file_format, *inputs, **kwargs)


def _export(net, file_name, file_format, *inputs, **kwargs):
    """
    It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
    """
    logger.info("exporting model file:%s format:%s.", file_name, file_format)
    check_input_data(*inputs, data_class=Tensor)
    if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
        check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)

    if file_format == 'GEIR':
        logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
        file_format = 'AIR'

    supported_formats = ['AIR', 'ONNX', 'MINDIR']
    if file_format not in supported_formats:
        raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
    # When dumping ONNX file, switch network mode to infer when it is training (NOTE: ONNX is only designed for
    # prediction)
    is_dump_onnx_in_training = net.training and file_format == 'ONNX'
    if is_dump_onnx_in_training:
        net.set_train(mode=False)

    if file_format == 'AIR':
        phase_name = 'export.air'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
        if not file_name.endswith('.air'):
            file_name += ".air"
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWUSR)
        if "/" in file_name:
            real_path = os.path.realpath(file_name[:file_name.rfind("/")])
            os.makedirs(real_path, exist_ok=True)
        _executor.export(file_name, graph_id)
        os.chmod(file_name, stat.S_IRUSR)
    elif file_format == 'ONNX':
        total_size = _calculation_net_size(net)
        if total_size > PROTO_LIMIT_SIZE:
            raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.'
                               .format(total_size / 1024 / 1024, PROTO_LIMIT_SIZE / 1024 / 1024))
        phase_name = 'export.onnx'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
        onnx_stream = _executor._get_func_graph_proto(net, graph_id)
        if not file_name.endswith('.onnx'):
            file_name += ".onnx"
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWUSR)
        with open(file_name, 'wb') as f:
            f.write(onnx_stream)
        os.chmod(file_name, stat.S_IRUSR)
    elif file_format == 'MINDIR':
        _save_mindir(net, file_name, *inputs, **kwargs)

    if is_dump_onnx_in_training:
        net.set_train(mode=True)


def _save_mindir(net, file_name, *inputs, **kwargs):
    """Save MindIR format file."""
    model = mindir_model()

    phase_name = "predict" if net._auto_parallel_mode else "export.mindir"

    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name,
                                    do_convert=False, auto_parallel_mode=net._auto_parallel_mode)
    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')

    net_dict = net.parameters_dict()
    model.ParseFromString(mindir_stream)

    if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
        dataset = kwargs['dataset']
        model.preprocessor = json.dumps(dataset.to_json(), indent=2)

    save_together = _save_together(net_dict, model)
    is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
    if save_together:
        _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
    else:
        logger.warning("The total size of parameters in the net exceeds 1G, save the MindIR model and "
                       "parameters separately.")
        # save parameter
        file_prefix = file_name.split("/")[-1]
        if file_prefix.endswith(".mindir"):
            file_prefix = file_prefix[:-7]
        current_path = os.path.abspath(file_name)
        dirname = os.path.dirname(current_path)
        data_path = os.path.join(dirname, file_prefix + "_variables")
        if os.path.exists(data_path):
            shutil.rmtree(data_path)
        os.makedirs(data_path, exist_ok=True)
        os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        index = 0
        graphproto = graph_proto()
        data_size = 0

        for name, param in net_dict.items():
            for param_proto in model.graph.parameter:
                if name == param_proto.name[param_proto.name.find(":") + 1:]:
                    parameter = graphproto.parameter.add()
                    parameter.name = param_proto.name
                    parameter.data_type = param_proto.data_type
                    for dim in param_proto.dims:
                        parameter.dims.append(dim)
                    byte_data = param.data.asnumpy().tobytes()
                    parameter.raw_data = byte_data
                    data_size += sys.getsizeof(byte_data) / 1024
                    break
            # flush the accumulated parameters to a new data file once they exceed 1G
            if data_size > TOTAL_SAVE:
                data_file_name = os.path.join(data_path, "data_" + str(index))
                if os.path.exists(data_file_name):
                    os.chmod(data_file_name, stat.S_IWUSR)
                with open(data_file_name, "ab") as f:
                    os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
                    graph_string = graphproto.SerializeToString()
                    if is_encrypt():
                        graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
                                                len(kwargs['enc_key']), kwargs['enc_mode'])
                    f.write(graph_string)
                    os.chmod(data_file_name, stat.S_IRUSR)
                index += 1
                data_size = 0
                del graphproto.parameter[:]

        if graphproto.parameter:
            data_file_name = os.path.join(data_path, "data_" + str(index))
            if os.path.exists(data_file_name):
                os.chmod(data_file_name, stat.S_IWUSR)
            with open(data_file_name, "ab") as f:
                os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
                graph_string = graphproto.SerializeToString()
                if is_encrypt():
                    graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                            kwargs['enc_mode'])
                f.write(graph_string)
                os.chmod(data_file_name, stat.S_IRUSR)

        # save graph
        del model.graph.parameter[:]
        graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
        if os.path.exists(graph_file_name):
            os.chmod(graph_file_name, stat.S_IWUSR)
        with open(graph_file_name, 'wb') as f:
            os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
            model_string = model.SerializeToString()
            if is_encrypt():
                model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                        kwargs['enc_mode'])
            f.write(model_string)
            os.chmod(graph_file_name, stat.S_IRUSR)


def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
    """Save graph and parameter together."""
    for param_proto in model.graph.parameter:
        param_name = param_proto.name[param_proto.name.find(":") + 1:]
        if param_name in net_dict.keys():
            param_data = net_dict[param_name].data.asnumpy().tobytes()
            param_proto.raw_data = param_data
        else:
            logger.error("The parameter %s in the graph is not in the network.", param_name)
            raise ValueError("The parameter in the graph must be in the network.")
    if not file_name.endswith('.mindir'):
        file_name += ".mindir"
    current_path = os.path.abspath(file_name)
    dirname = os.path.dirname(current_path)
    os.makedirs(dirname, exist_ok=True)
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWUSR)
    with open(file_name, 'wb') as f:
        os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
        model_string = model.SerializeToString()
        if is_encrypt():
            model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                    kwargs['enc_mode'])
        f.write(model_string)
        os.chmod(file_name, stat.S_IRUSR)


def _save_together(net_dict, model):
    """Whether to save the graph and parameters together when saving a MindIR model."""
    data_total = 0
    for param_proto in model.graph.parameter:
        name = param_proto.name[param_proto.name.find(":") + 1:]
        if name in net_dict.keys():
            data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
        else:
            raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name))
        if data_total > TOTAL_SAVE:
            return False
    return True


def quant_mode_manage(func):
    """
    Inherits the quant_mode from old versions.
    """
    def wrapper(network, *inputs, file_format, **kwargs):
        if 'quant_mode' not in kwargs:
            return network
        quant_mode = kwargs['quant_mode']
        if not isinstance(quant_mode, str):
            raise TypeError("The type of quant_mode should be str, but got {}.".format(type(quant_mode)))
        if quant_mode in ('AUTO', 'MANUAL'):
            kwargs['quant_mode'] = 'QUANT'
        return func(network, *inputs, file_format=file_format, **kwargs)
    return wrapper


@quant_mode_manage
def _quant_export(network, *inputs, file_format, **kwargs):
    """
    Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
    """
    supported_device = ["Ascend", "GPU"]
    supported_formats = ['AIR', 'MINDIR']
    quant_mode_formats = ['QUANT', 'NONQUANT']

    quant_mode = kwargs['quant_mode']
    if quant_mode not in quant_mode_formats:
        raise KeyError(f'Quant_mode input is wrong, please choose the right mode of the quant_mode.')
    if quant_mode == 'NONQUANT':
        return network
    quant_net = copy.deepcopy(network)
    quant_net._create_time = int(time.time() * 1e9)

    mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
    std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
    mean = Validator.check_value_type("mean", mean, (int, float))
    std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))

    if context.get_context('device_target') not in supported_device:
        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

    if file_format not in supported_formats:
        raise ValueError('Illegal file format {}.'.format(file_format))

    quant_net.set_train(False)
    if file_format == "MINDIR":
        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
    else:
        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
    deploy_net = exporter.run()
    return deploy_net


def parse_print(print_file_name):
    """
    Parse saved data generated by mindspore.ops.Print. Print is used to print data to screen in graph mode.
    It can also be turned off by setting the parameter `print_file_path` in `context`, and the data will be saved
    in a file specified by print_file_path. parse_print is used to parse the saved file. For more information
    please refer to :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print`.

    Args:
        print_file_name (str): The file name of saved print data.

    Returns:
        List, element of list is Tensor.

    Raises:
        ValueError: The print file may be empty, please make sure enter the correct file name.

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, context
        >>> context.set_context(mode=context.GRAPH_MODE, print_file_path='log.data')
        >>> class PrintInputTensor(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.print = ops.Print()
        ...
        ...     def construct(self, input_pra):
        ...         self.print('print:', input_pra)
        ...         return input_pra
        >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
        >>> input_pra = Tensor(x)
        >>> net = PrintInputTensor()
        >>> net(input_pra)
        >>> data = mindspore.parse_print('./log.data')
        >>> print(data)
        ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
        [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
        [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
    """

    print_file_path = os.path.realpath(print_file_name)

    if os.path.getsize(print_file_path) == 0:
        raise ValueError("The print file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load print process.")
    print_list = Print()

    try:
        with open(print_file_path, "rb") as f:
            pb_content = f.read()
        print_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the print file %s, please check whether the file is correct.", print_file_name)
        raise ValueError(e.__str__())

    tensor_list = []

    try:
        for print_ in print_list.value:
            # String type
            if print_.HasField("desc"):
                tensor_list.append(print_.desc)
            elif print_.HasField("tensor"):
                dims = print_.tensor.dims
                data_type = print_.tensor.tensor_type
                data = print_.tensor.tensor_content
                np_type = tensor_to_np_type[data_type]
                param_data = np.frombuffer(data, np_type)
                ms_type = tensor_to_ms_type[data_type]
                if dims and dims != [0]:
                    param_value = param_data.reshape(dims)
                    tensor_list.append(Tensor(param_value, ms_type))
                # Scalar type
                else:
                    data_type_ = data_type.lower()
                    if 'float' in data_type_:
                        param_data = float(param_data[0])
                    elif 'int' in data_type_:
                        param_data = int(param_data[0])
                    elif 'bool' in data_type_:
                        param_data = bool(param_data[0])
                    tensor_list.append(Tensor(param_data, ms_type))

    except BaseException as e:
        logger.error("Failed to load the print file %s.", print_file_path)
        raise RuntimeError(e.__str__())

    return tensor_list


def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices to one tensor with whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.

    Returns:
        Tensor, the merged Tensor which has the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                         f"The sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")

    if not param_split_shape:
        if not is_even:
            raise ValueError("The shape of every parameter in sliced_parameters should be the same "
                             "when slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
        else:
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)

    else:
        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(f"The param_split_shape length in strategy should be {slice_count}, "
                             f"but got {len(param_split_shape)}.")

        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(f"The slice {slice_index} is {param_split_shape[slice_index]} in 0 axis, "
                                 f"but got {tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

        # concatenate the slices dimension by dimension, from the last dimension to the first
        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
                tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor


def build_searched_strategy(strategy_filename):
    """
    Build strategy of every parameter in network. Used in the case of distributed inference.
    For details, please check:
    `Saving and Loading Models in Hybrid Parallel Mode
    <https://www.mindspore.cn/docs/programming_guide/en/r1.5/save_load_model_hybrid_parallel.html>`_.

    Args:
        strategy_filename (str): Name of strategy file.

    Returns:
        Dict, whose key is parameter name and value is slice strategy of this parameter.

    Raises:
        ValueError: Strategy file is incorrect.
        TypeError: strategy_filename is not str.

    Examples:
        >>> strategy = build_searched_strategy("./strategy_train.ckpt")
    """
    if not isinstance(strategy_filename, str):
        raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")

    if not os.path.isfile(strategy_filename):
        raise ValueError(f"No such strategy file: {strategy_filename}.")

    if os.path.getsize(strategy_filename) == 0:
        raise ValueError("The strategy file should not be empty.")

    parallel_strategy_map = ParallelStrategyMap()

    with open(strategy_filename, 'rb') as f:
        pb_content = f.read()
    parallel_strategy_map.ParseFromString(pb_content)

    layout_items = parallel_strategy_map.parallel_layout_item
    if not layout_items:
        raise ValueError("The strategy file has no sliced parameter.")

    strategy = {}
    for layout_item in layout_items:
        parameter_name = layout_item.param_name
        layout = layout_item.parallel_layouts
        strategy[parameter_name] = layout

    return strategy


def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.
    For details of merge_sliced_parameter, please check:
    `Saving and Loading Models in Hybrid Parallel Mode
    <https://www.mindspore.cn/docs/programming_guide/en/r1.5/save_load_model_hybrid_parallel.html>`_.

    Args:
        sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
        strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
            value is slice strategy of this parameter. If strategy is None, just merge
            parameter slices in 0 axis order. Default: None.

    Returns:
        Parameter, the merged parameter which has the whole data.

    Raises:
        ValueError: Failed to merge.
        TypeError: The sliced_parameters is incorrect or strategy is not dict.
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, merge_sliced_parameter, Parameter
        >>>
        >>> sliced_parameters = [
        ...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
        ...                                "network.embedding_table")]
        >>> merged_parameter = merge_sliced_parameter(sliced_parameters)
        >>> print(merged_parameter)
        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")

    if not sliced_parameters:
        raise ValueError("The sliced_parameters should not be empty.")

    if strategy and not isinstance(strategy, dict):
        raise TypeError(f"The strategy should be dict, but got {type(strategy)}.")

    try:
        parameter_name = sliced_parameters[0].name
        parameter_shape = sliced_parameters[0].data.shape
        parameter_shape_length = len(parameter_shape)
    except BaseException as e:
        raise TypeError(f"{e.__str__()}. The element in sliced_parameters should be Parameter.")

    is_even = True
    for index, parameter in enumerate(sliced_parameters):
        if not isinstance(parameter, Parameter):
            raise TypeError(f"The element in sliced_parameters should be Parameter, "
                            f"but got {type(parameter)} at index {index}.")

        if parameter.name != parameter_name \
                or len(parameter.data.shape) != parameter_shape_length \
                or parameter.data.shape[1:] != parameter_shape[1:]:
            raise ValueError("Please make sure that the elements in sliced_parameters have the same name, "
                             "dimension length and shape except 0 axis.")

        if parameter.data.shape != parameter_shape:
            is_even = False

    layerwise_parallel = sliced_parameters[0].layerwise_parallel
    requires_grad = sliced_parameters[0].requires_grad
    sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]

    if not strategy:
        merged_tensor = Tensor(np.concatenate(sliced_data))
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    else:
        if parameter_name not in strategy.keys():
            raise KeyError(f"The parameter name should be one key of strategy. "
                           f"The parameter name is {parameter_name}.")
        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    return merged_parameter


def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
    """
    Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
    For details of distributed inference, please check:
    `Distributed Inference <https://www.mindspore.cn/docs/programming_guide/en/r1.5/distributed_inference.html>`_.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of the prediction process, whose key is parameter name, and value is a list
            or a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
            it means that the prediction process just uses a single device. Default: None.
        train_strategy_filename (str): Train strategy proto file name. Default: None.
        strict_load (bool): Whether to strictly load the parameters into net. If False, it will load a
            parameter into net when the parameter name's suffix in the checkpoint file is the same as the
            parameter in the network. When the types are inconsistent, perform type conversion on parameters
            of the same numeric family, such as float32 to float16. Default: False.
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
            is not required. Default: None.
        dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
            mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    _check_checkpoint_file(checkpoint_filenames)
    _check_predict_strategy(predict_strategy)

    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)

    if train_strategy_filename is None:
        train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    train_dev_count = 1
    ckpt_file_len = len(checkpoint_filenames)
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != ckpt_file_len:
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {ckpt_file_len} but the device count is {train_dev_count}.")

    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_total_dict = defaultdict(dict)
    for file_index, file_name in enumerate(checkpoint_filenames):
        ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
        for param_name, param in ckpt_dict.items():
            param_total_dict[param_name][file_index] = param

    param_dict = {}
    param_not_in_strategy = []
    param_not_in_ckpt = []
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            param_not_in_strategy.append(param.name)
            continue
        if param.name not in param_total_dict:
            param_not_in_ckpt.append(param.name)
            continue

        param_rank = rank_list[param.name][0]
        skip_merge_split = rank_list[param.name][1]
        shard_stride = train_strategy[param.name][4]
        if train_strategy[param.name][5]:
            # integer division: shard_size is used below as a slice length
            shard_size = ckpt_file_len // shard_stride // train_strategy[param.name][5]
        else:
            shard_size = 0
        for rank in param_rank:
            param_total_list = list(range(0, ckpt_file_len))
            if shard_size > 0:
                shard_total_list = [param_total_list[i:i + shard_size] for i in
                                    range(0, ckpt_file_len, shard_size)]
                param_total_list = shard_total_list[rank // shard_size]
            if shard_stride > 0:
                param_stride = []
                # merge pre parameter
                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
                param_index = list(set(param_index))
                param_index.sort()
                for rank_num in param_index:
                    param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
            else:
                sliced_param = param_total_dict[param.name][rank]

            sliced_params.append(sliced_param)
        if skip_merge_split:
            split_param = sliced_params[0]
        else:
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
        if opt_shard_group:
            data = split_param.data.asnumpy()
            rank = get_rank(opt_shard_group)
            size = get_group_size(opt_shard_group)
            try:
                data_slice = np.split(data, size)[rank]
            except BaseException as e:
                logger.error("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                             " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                raise RuntimeError(e.__str__())
            split_param = Parameter(Tensor(data_slice), param.name,
                                    split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param

    if param_not_in_strategy:
        logger.warning("{} parameters in network are not in the slice strategy.".format(param_not_in_strategy))
    if param_not_in_ckpt:
        logger.warning("{} parameters in slice strategy but not in the checkpoint file.".format(param_not_in_ckpt))

    load_param_into_net(network, param_dict, strict_load=strict_load)
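

# A hedged usage sketch for distributed inference ("Net" and the file paths are
# placeholders): checkpoint files must be ordered by rank id, and if
# train_strategy_filename is None the path configured through the auto parallel
# context option "strategy_ckpt_load_file" is used instead.
#
#     net = Net()
#     ckpt_files = ["rank_0/net.ckpt", "rank_1/net.ckpt",
#                   "rank_2/net.ckpt", "rank_3/net.ckpt"]
#     load_distributed_checkpoint(net, ckpt_files,
#                                 train_strategy_filename="train_strategy.ckpt")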


def async_ckpt_thread_status():
    """
    Get the status of the asynchronous save checkpoint thread.

    When saving a checkpoint asynchronously, this function can be used to check whether
    the thread is still running, and thus whether the checkpoint file has been written
    completely.

    Returns:
        bool, True if the asynchronous save checkpoint thread is still running,
        False otherwise.
    """
    thr_list = threading.enumerate()
    return any(thr.name == "asyn_save_ckpt" for thr in thr_list)


def _check_predict_strategy(predict_strategy):
    """Check predict strategy."""
    def _check_int_list(arg):
        if not isinstance(arg, list):
            return False
        for item in arg:
            if not isinstance(item, int):
                return False
        return True

    if predict_strategy is None:
        return

    flag = True
    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
    for key in predict_strategy.keys():
        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                or len(predict_strategy[key]) < 4:
            flag = False
            # The entry cannot be unpacked below, so skip the remaining checks.
            continue
        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                not (_check_int_list(param_split_shape) or not param_split_shape) or \
                not (isinstance(field_size, int) and field_size == 0):
            flag = False

    if not flag:
        raise ValueError("Please make sure that the key of predict_strategy is str, "
                         "and the value is a list or a tuple whose first four elements are "
                         "dev_matrix (list[int]), tensor_map (list[int]), "
                         "param_split_shape (list[int]) and field_size (zero).")


def _check_checkpoint_file(checkpoint_filenames):
    """Check checkpoint file names."""
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")


def _convert_to_list(strategy):
    """Convert a dict of ParallelLayouts objects to a dict of plain layout lists."""
    train_map = {}
    for param_name in strategy.keys():
        try:
            layout = strategy.get(param_name)
            dev_mat = list(layout.dev_matrix[0].dim)
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            field_size = int(layout.field)
            shard_stride = int(layout.opt_weight_shard_step)
            shard_size = int(layout.opt_weight_shard_size)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
        except BaseException as e:
            raise ValueError(f"{str(e)}. Please make sure that the strategy matches the node_strategy.proto.")
    return train_map
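

def _example_layout_list():
    """A minimal sketch of the layout list produced by ``_convert_to_list``
    (illustrative only, not part of the public API): the six elements are
    [dev_matrix, tensor_map, param_split_shape, field_size, shard_stride,
    shard_size]. The concrete values below are made-up assumptions.
    """
    # A parameter placed on a 2x4 device matrix, with an empty split shape,
    # no field slicing, and no optimizer weight sharding.
    tensor_layout = [[2, 4], [1, -1], [], 0, 0, 0]
    # _convert_to_layout (defined below) consumes only the first four elements
    # and rebuilds a ParallelLayouts protobuf keyed by the parameter name.
    return _convert_to_layout("fc.weight", tensor_layout)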


def _convert_to_layout(param_name, tensor_layout):
    """Convert a layout list back to a ParallelLayouts object."""
    strategy = {}
    try:
        layout = ParallelLayouts()
        layout.field = tensor_layout[3]

        dev_matrix = layout.dev_matrix.add()
        for item in tensor_layout[0]:
            dev_matrix.dim.append(item)

        tensor_map = layout.tensor_map.add()
        for item in tensor_layout[1]:
            tensor_map.dim.append(item)

        param_split_shape = layout.param_split_shape.add()
        for item in tensor_layout[2]:
            param_split_shape.dim.append(item)
    except BaseException as e:
        raise ValueError("Convert failed. " + str(e))

    strategy[param_name] = layout
    return strategy


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge the sliced parameters and re-split the result according to the predict strategy."""
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged_param
    param_name = merged_param.name
    tensor_layout = predict_strategy[param_name]
    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
    requires_grad = merged_param.requires_grad
    layerwise_parallel = merged_param.layerwise_parallel
    split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
    return split_param


def _calculation_net_size(net):
    """Calculate the total size (in KB) of the parameters in the network."""
    data_total = 0
    net_dict = net.parameters_dict()
    for name in net_dict:
        data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024

    return data_total
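

def _example_wait_for_async_ckpt(poll_interval=0.1):
    """A minimal sketch of waiting for an asynchronous checkpoint write to
    finish by polling ``async_ckpt_thread_status`` (illustrative only, not part
    of the public API). The poll interval is an arbitrary choice.
    """
    # Block until the "asyn_save_ckpt" thread is no longer running, so the
    # checkpoint file is guaranteed to be completely written.
    while async_ckpt_thread_status():
        time.sleep(poll_interval)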