1# Copyright 2020-2024 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""Model and parameters serialization.""" 17from __future__ import absolute_import 18from __future__ import division 19 20import binascii 21import copy 22import json 23import os 24import shutil 25import stat 26import threading 27from threading import Thread, RLock 28from collections import defaultdict, OrderedDict 29from io import BytesIO 30 31import math 32import sys 33import time 34import google 35import numpy as np 36 37from mindspore.train.checkpoint_pb2 import Checkpoint 38from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model 39from mindspore.train.print_pb2 import Print 40 41import mindspore 42import mindspore.nn as nn 43from mindspore import context 44from mindspore import log as logger 45from mindspore._checkparam import check_input_data, check_input_dataset 46from mindspore import _checkparam as Validator 47from mindspore.common import dtype as mstype 48from mindspore.common.api import _cell_graph_executor as _executor 49from mindspore.common.api import _MindsporeFunctionExecutor 50from mindspore.common.api import _get_parameter_layout 51from mindspore.common.api import _generate_branch_control_input 52from mindspore.common.initializer import initializer, One 53from mindspore.common.parameter import Parameter, _offload_if_config 54from mindspore.common.tensor import Tensor 55from mindspore._c_expression import Tensor as Tensor_ 56from mindspore.common._utils import is_shape_unknown 57from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system 58from mindspore.communication.management import get_rank, get_group_size 59from mindspore.experimental import MapParameter 60from mindspore.ops import Cast 61from mindspore.parallel._cell_wrapper import get_allgather_cell 62from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index 63from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight 64from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode 65from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \ 66 _restore_group_info_list 67from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \ 68 _store_warm_up_ptr_by_tensor_list, _cache_enable 69from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters 70from mindspore.train._utils import read_proto 71from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \ 72 split_mindir, split_dynamic_mindir 73from mindspore.common.generator import Generator 74from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy 75from mindspore.parallel.parameter_broadcast import parameter_broadcast 76from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs 77 78tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16, 79 "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64, 80 "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64, 81 "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2} 82 83tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16, 84 "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64, 85 "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"} 86 87np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64} 88 89mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16, 90 5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16, 91 11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64} 92 93_ckpt_mutex = RLock() 94 95# unit is KB 96SLICE_SIZE = 512 * 1024 97PROTO_LIMIT_SIZE = 1024 * 1024 * 2 98TOTAL_SAVE = 1024 * 1024 99PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024 100ENCRYPT_BLOCK_SIZE = 64 * 1024 101INT_64_MAX = 9223372036854775807 102 103cpu_cast = Cast().set_device("CPU") 104 105_ckpt_fs = FileSystem() 106 107 108def init_ckpt_file_system(fs: FileSystem): 109 """Initialize checkpoint file system""" 110 if _register_mindio_file_system(fs): 111 return 112 _register_basic_file_system(fs) 113 114 115# Initialize checkpoint file system 116init_ckpt_file_system(_ckpt_fs) 117 118 119class ParamDictFuture: 120 def __init__(self, executor, param_dict_future): 121 self.executor = executor 122 self.param_dict_future = param_dict_future 123 124 def result(self): 125 param_dict = self.param_dict_future.result() 126 self.executor.shutdown() 127 return param_dict 128 129 130def _special_process_par(par, new_par): 131 """ 132 Processes the special condition. 133 134 Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor. 135 """ 136 par_shape_len = len(par.data.shape) 137 new_par_shape_len = len(new_par.data.shape) 138 if new_par_shape_len <= par_shape_len: 139 return False 140 141 for i in range(new_par_shape_len - par_shape_len): 142 if new_par.data.shape[par_shape_len + i] != 1: 143 return False 144 145 if new_par.data.dtype == mstype.bfloat16: 146 new_val = cpu_cast(new_par.data, mstype.float32).asnumpy() 147 else: 148 new_val = new_par.data.asnumpy() 149 150 new_val = new_val.reshape(par.data.shape) 151 par.set_data(Tensor(new_val, par.data.dtype)) 152 return True 153 154 155def _update_param(param, new_param, strict_load): 156 """Updates param's data from new_param's data.""" 157 if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor): 158 if param.data.shape != new_param.data.shape: 159 if not _special_process_par(param, new_param): 160 logger.critical("Failed to combine the net and the parameters for param %s.", param.name) 161 msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same shape " 162 f"as {param.name} in the argument 'parameter_dict'. But got its shape {param.data.shape} in" 163 f" the argument 'net' and shape {new_param.data.shape} in the argument 'parameter_dict'." 164 f"May you need to check whether the checkpoint you loaded is correct or the batch size and " 165 f"so on in the 'net' and 'parameter_dict' are same.") 166 raise RuntimeError(msg) 167 168 if param.data.dtype != new_param.data.dtype: 169 if _type_convert(param, new_param, strict_load): 170 if new_param.data.dtype == mstype.bfloat16: 171 new_tensor = cpu_cast(new_param.data, param.data.dtype) 172 else: 173 new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype) 174 param.set_data(new_tensor, param.sliced) 175 return 176 177 logger.critical("Failed to combine the net and the parameters for param %s.", param.name) 178 msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same type as " 179 f"{param.name} in the argument 'parameter_dict'. but got its type {param.data.dtype} in the " 180 f"argument 'net' and type {new_param.data.dtype} in the argument 'parameter_dict'." 181 f"May you need to check whether the checkpoint you loaded is correct.") 182 raise RuntimeError(msg) 183 184 param.set_data(new_param.data, param.sliced) 185 return 186 187 if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor): 188 if param.data.shape != (1,) and param.data.shape != (): 189 logger.critical("Failed to combine the net and the parameters for param %s.", param.name) 190 msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is " 191 f"scalar, then the shape of {param.name} in the argument 'net' should be " 192 f"(1,) or (), but got shape {param.data.shape}." 193 f"May you need to check whether the checkpoint you loaded is correct.") 194 raise RuntimeError(msg) 195 param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype)) 196 197 elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor): 198 logger.critical("Failed to combine the net and the parameters for param %s.", param.name) 199 msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is Tensor, " 200 f"then {param.name} in the argument 'net' also should be Tensor, but got {type(param.data)}." 201 f"May you need to check whether the checkpoint you loaded is correct.") 202 raise RuntimeError(msg) 203 204 else: 205 param.set_data(type(param.data)(new_param.data)) 206 207 208def _type_convert(param, new_param, strict_load): 209 """Whether to convert parameter's type during load checkpoint into network.""" 210 float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16) 211 int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64) 212 if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or 213 {param.data.dtype, new_param.data.dtype}.issubset(int_type)): 214 logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from " 215 f"the type of it in 'net':{param.data.dtype}, then the type convert from " 216 f"{new_param.data.dtype} to {param.data.dtype} in the network.") 217 return True 218 return False 219 220 221def _save_weight(checkpoint_dir, model_name, iteration, params): 222 """Save model weight into checkpoint.""" 223 logger.debug(f"Checkpoint dir is: '{checkpoint_dir}'") 224 exist_ckpt_file_list = [] 225 if os.path.exists(checkpoint_dir): 226 for exist_ckpt_name in os.listdir(checkpoint_dir): 227 file_prefix = os.path.join(model_name, "_iteration_") 228 if exist_ckpt_name.startswith(file_prefix): 229 exist_ckpt_file_list.append(exist_ckpt_name) 230 231 param_dict = OrderedDict() 232 for key in params.keys(): 233 value = params[key] 234 weight_type = value[0] 235 weight_shape = value[1] 236 weight_data = value[2] 237 weight_size = value[3] 238 weight_np = np.array(weight_data, dtype=weight_type.lower()) 239 logger.debug(f"weight_type: '{weight_type}', weight_shape: '{weight_shape}', weight_size: " 240 f"'{weight_size}', weight_np.nbytes: '{weight_np.nbytes}'") 241 242 param_dict[key] = [weight_shape, weight_type, weight_np] 243 ckpt_file_save_name = model_name + "_iteration_" + iteration + ".ckpt" 244 ckpt_file_save_path = os.path.join(checkpoint_dir, ckpt_file_save_name) 245 246 _exec_save(ckpt_file_save_path, param_dict) 247 248 for exist_ckpt_name in exist_ckpt_file_list: 249 os.remove(os.path.join(checkpoint_dir, exist_ckpt_name)) 250 logger.info(f"Save weight to checkpoint file path '{ckpt_file_save_path}' success.") 251 else: 252 logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.") 253 254 255def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False): 256 """Execute the process of saving checkpoint into file.""" 257 try: 258 with _ckpt_mutex: 259 if os.path.exists(ckpt_file_name): 260 os.chmod(ckpt_file_name, stat.S_IWUSR) 261 os.remove(ckpt_file_name) 262 with _ckpt_fs.create(ckpt_file_name, *_ckpt_fs.create_args) as f: 263 plain_data = None 264 if enc_key is not None: 265 plain_data = BytesIO() 266 267 crc_num = 0 268 for name, value in data_list.items(): 269 if name == "random_op": 270 _write_random_seed(name, value, f) 271 continue 272 if value[0] == "mapparameter": 273 _write_mapparameter(name, value, f, map_param_inc) 274 continue 275 if value[0] == "offload_parameter": 276 new_value = value[1:] 277 new_value[2] = value[3] 278 _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data) 279 _offload_if_config(value[3]) 280 continue 281 if value[1] == "str": 282 crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check) 283 continue 284 if isinstance(value[2], np.ndarray): 285 crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check) 286 continue 287 if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1: 288 _write_hugeparameter(name, value, f) 289 continue 290 291 crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check) 292 293 if enc_key is not None: 294 plain_data.seek(0) 295 max_block_size = ENCRYPT_BLOCK_SIZE * 1024 296 block_data = plain_data.read(max_block_size) 297 while block_data: 298 f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode)) 299 block_data = plain_data.read(max_block_size) 300 301 if crc_check: 302 f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big')) 303 304 os.chmod(ckpt_file_name, stat.S_IRUSR) 305 306 except BaseException as e: 307 logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, " 308 "or the disk space is insufficient and so on.", ckpt_file_name) 309 raise e 310 311 312def _write_random_seed(name, value, f): 313 """Write random op into protobuf file.""" 314 checkpoint_list = Checkpoint() 315 param_value = checkpoint_list.value.add() 316 param_value.tag = name 317 param_tensor = param_value.tensor 318 param_tensor.dims.extend(0) 319 param_tensor.tensor_type = "random_op" 320 param_tensor.tensor_content = value 321 f.write(checkpoint_list.SerializeToString()) 322 323 324def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False): 325 """Write parameter data into protobuf file.""" 326 data_size = value[2].nbytes / 1024 327 if data_size > SLICE_SIZE: 328 slice_count = math.ceil(data_size / SLICE_SIZE) 329 param_slice_list = np.array_split(value[2], slice_count) 330 else: 331 param_slice_list = [value[2]] 332 333 for param_slice in param_slice_list: 334 checkpoint_list = Checkpoint() 335 param_value = checkpoint_list.value.add() 336 param_value.tag = name 337 param_tensor = param_value.tensor 338 param_tensor.dims.extend(value[0]) 339 param_tensor.tensor_type = value[1] 340 param_tensor.tensor_content = param_slice.tobytes() 341 342 if enc_key is None: 343 output_data = checkpoint_list.SerializeToString() 344 if crc_check: 345 crc_num = binascii.crc32(output_data, crc_num) 346 f.write(output_data) 347 else: 348 plain_data.write(checkpoint_list.SerializeToString()) 349 350 return crc_num 351 352 353def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False): 354 """Write parameter bytes data into protobuf file.""" 355 bytes_value = value[2].get_bytes() 356 chunk_size = 1024 * SLICE_SIZE 357 358 for i in range(0, len(bytes_value), chunk_size): 359 checkpoint_list = Checkpoint() 360 param_value = checkpoint_list.value.add() 361 param_value.tag = name 362 param_tensor = param_value.tensor 363 param_tensor.dims.extend(value[0]) 364 param_tensor.tensor_type = value[1] 365 param_tensor.tensor_content = bytes_value[i:i + chunk_size] 366 367 if enc_key is None: 368 output_data = checkpoint_list.SerializeToString() 369 if crc_check: 370 crc_num = binascii.crc32(output_data, crc_num) 371 f.write(output_data) 372 else: 373 plain_data.write(checkpoint_list.SerializeToString()) 374 375 return crc_num 376 377 378def _write_mapparameter(name, value, f, map_param_inc=False): 379 """Write map parameter into protobuf file.""" 380 while True: 381 logger.info("Checkpoint save map_parameter.") 382 data_map_slice = value[1].export_slice_data(map_param_inc) 383 checkpoint_list = Checkpoint() 384 param_value = checkpoint_list.value.add() 385 param_value.tag = name 386 map_tensor = param_value.maptensor 387 for numpy_data in data_map_slice[:3]: 388 tensor_pro = map_tensor.tensor.add() 389 tensor_pro.dims.extend(numpy_data.shape) 390 tensor_pro.tensor_type = str(numpy_data.dtype) 391 tensor_pro.tensor_content = numpy_data.reshape(-1).tobytes() 392 f.write(checkpoint_list.SerializeToString()) 393 if data_map_slice[3]: 394 break 395 396 397def _write_hugeparameter(name, value, f): 398 """Write huge parameter into protobuf file.""" 399 slice_num = value[2].slice_num 400 offset = 0 401 max_size = value[0][0] 402 for param_slice in range(slice_num): 403 checkpoint_list = Checkpoint() 404 param_value = checkpoint_list.value.add() 405 param_value.tag = name 406 param_tensor = param_value.tensor 407 param_tensor.dims.extend(value[0]) 408 param_tensor.tensor_type = value[1] 409 param_key = value[3] 410 numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice) 411 if offset + numpy_data.shape[0] > max_size: 412 numpy_data = numpy_data[:max_size - offset] 413 param_tensor.tensor_content = numpy_data.tobytes() 414 f.write(checkpoint_list.SerializeToString()) 415 offset += numpy_data.shape[0] 416 417 418def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name): 419 """Check save_obj and ckpt_file_name for save_checkpoint.""" 420 if not isinstance(save_obj, (nn.Cell, list, dict)): 421 raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, " 422 "but got {}.".format(type(save_obj))) 423 if not isinstance(ckpt_file_name, str): 424 raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid," 425 "'ckpt_file_name' must be " 426 "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name))) 427 ckpt_file_name = os.path.abspath(ckpt_file_name) 428 if os.path.isdir(ckpt_file_name): 429 raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, " 430 "it must be a file name.".format(ckpt_file_name)) 431 if not ckpt_file_name.endswith('.ckpt'): 432 ckpt_file_name += ".ckpt" 433 return ckpt_file_name 434 435 436def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, 437 async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, 438 crc_check=False, **kwargs): 439 r""" 440 Save checkpoint to a specified file. 441 442 Note: 443 The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously. 444 445 Args: 446 save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`, 447 list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict 448 elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of 449 `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict, 450 it can be the returned value of `mindspore.load_checkpoint()`. 451 ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. 452 integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` . 453 async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` . 454 append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value 455 of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` . 456 enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption 457 is not required. Default: ``None`` . 458 enc_mode (str): This parameter is valid only when enc_key is not set to ``None`` . Specifies the encryption 459 mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` . 460 Default: ``"AES-GCM"`` . 461 choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is 462 a parameter name in string type, and the returned value is a bool. 463 If returns ``True`` , the Parameter that matching the custom condition will be saved. 464 If returns ``False`` , the Parameter that not matching the custom condition will not 465 be saved. Default: ``None`` . 466 crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation 467 result to the file. Default: ``False`` . 468 kwargs (dict): Configuration options dictionary. 469 470 Raises: 471 TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type. 472 TypeError: If the parameter `integrated_save` or `async_save` is not bool type. 473 TypeError: If the parameter `ckpt_file_name` is not string type. 474 475 Examples: 476 >>> import mindspore as ms 477 >>> 478 >>> # Define the network structure of LeNet5. Refer to 479 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 480 >>> net = LeNet5() 481 >>> ms.save_checkpoint(net, "./lenet.ckpt", 482 ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1")) 483 >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt") 484 >>> print(param_dict1) 485 {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)} 486 >>> params_list = net.trainable_params() 487 >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt", 488 ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2")) 489 >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt") 490 >>> print(param_dict2) 491 {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)} 492 >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt") 493 >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt") 494 >>> print(param_dict3) 495 {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)} 496 497 Tutorial Examples: 498 - `Saving and Loading the Model - Saving and Loading the Model Weight 499 <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_ 500 """ 501 ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name) 502 integrated_save = Validator.check_bool(integrated_save) 503 async_save = Validator.check_bool(async_save) 504 append_dict = _check_append_dict(append_dict) 505 enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) 506 enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) 507 crc_check = Validator.check_isinstance('crc_check', crc_check, bool) 508 map_param_inc = kwargs.get('incremental', False) 509 logger.info("Execute the process of saving checkpoint files.") 510 global_step_num = kwargs.get('global_step_num', None) 511 512 save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func) 513 514 if append_dict: 515 append_info_list = [] 516 for k_name, value in append_dict.items(): 517 if isinstance(value, Generator): 518 value = value.get_state() 519 elif not isinstance(value, str): 520 value = Tensor(value) 521 append_info_list.append({"name": k_name, "data": value}) 522 save_obj.extend(append_info_list) 523 524 data_list = OrderedDict() 525 data_list_np = OrderedDict() 526 with _ckpt_mutex: 527 for param in save_obj: 528 if param["name"] == "random_op": 529 if os.getenv("AITURBO") == "1": 530 data_list_np["random_op"] = param["data"] 531 else: 532 data_list["random_op"] = param["data"] 533 continue 534 key = param["name"] 535 data_list[key] = [] 536 if isinstance(param["data"], MapParameter): 537 data_list[param["name"]].append("mapparameter") 538 data_list[param["name"]].append(param["data"]) 539 continue 540 if isinstance(param["data"], list): 541 if param["data"][0] == "persistent_data": 542 _save_param_list_data(data_list, key, param) 543 elif param["data"][0] == "offload_parameter": 544 data_list[key].append("offload_parameter") 545 _save_param_list_data(data_list, key, param) 546 547 if isinstance(param["data"], str): 548 if os.getenv("AITURBO") == "1": 549 data_list_np[key] = np.array(param["data"]) 550 else: 551 data_list[key].append([0]) 552 data_list[key].append('str') 553 data = np.array(param["data"]) 554 data_list[key].append(data) 555 else: 556 if isinstance(param["data"], Parameter): 557 param["data"].init_data() 558 if os.getenv("AITURBO") == "1": 559 data_list_np[key] = param["data"].asnumpy() 560 else: 561 dims = [] 562 for dim in param['data'].shape: 563 dims.append(dim) 564 data_list[key].append(dims) 565 tensor_type = str(param["data"].dtype) 566 data_list[key].append(tensor_type) 567 data = param["data"] 568 data_list[key].append(data) 569 570 if os.getenv("AITURBO") == "1": 571 import aiturbo 572 ckpt_name = os.path.basename(ckpt_file_name) 573 aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np) 574 elif async_save: 575 data_copy = copy.deepcopy(data_list) 576 thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check), 577 name="asyn_save_ckpt") 578 thr.start() 579 else: 580 _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check) 581 582 logger.info("Saving checkpoint process is finished.") 583 584 585def _convert_list_to_param_list(save_obj, choice_func): 586 """Convert a list of Parameter to param_list.""" 587 param_list = [] 588 if not save_obj: 589 return param_list 590 if isinstance(save_obj[0], dict): 591 for param in save_obj: 592 if isinstance(param, dict) and "name" in param and "data" in param: 593 if not isinstance(param["name"], str): 594 raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict " 595 f"should be string, but got {type(param['name'])}.") 596 if not isinstance(param["data"], Tensor): 597 raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict " 598 f"should be parameter, but got {type(param['data'])}.") 599 if choice_func is not None and not choice_func(param["name"]): 600 continue 601 each_param = {"name": param["name"], "data": param["data"]} 602 param_list.append(each_param) 603 else: 604 raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should " 605 f"have key values 'name' and 'value', but got {type(param)} and {param}.") 606 else: 607 for param in save_obj: 608 if isinstance(param, Parameter): 609 if choice_func is not None and not choice_func(param.name): 610 continue 611 each_param = {"name": param.name, "data": param} 612 param_list.append(each_param) 613 else: 614 raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter," 615 f"the param should be parameter, but got {type(param)}") 616 return param_list 617 618 619def _convert_dict_to_param_dict(save_obj, choice_func): 620 """Convert a dict of Parameter to param_list.""" 621 param_list = [] 622 for (key, value) in save_obj.items(): 623 if isinstance(key, str) and isinstance(value, (Parameter, str)): 624 if choice_func is not None and not choice_func(key): 625 continue 626 each_param = {"name": key, "data": value} 627 param_list.append(each_param) 628 else: 629 raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and" 630 f"value should be Parameter, but got the type of key is {type(key)} and" 631 f"the type of value is {type(value)}") 632 return param_list 633 634 635def _convert_cell_param_and_names_to_dict(save_obj, choice_func): 636 """Convert cell.parameters_and_names to OrderedDict.""" 637 param_dict = OrderedDict() 638 for _, param in save_obj.parameters_and_names(): 639 not_sliced = not param.sliced 640 is_graph_mode = context.get_context('mode') == context.GRAPH_MODE 641 # All parameters are initialized immediately under PyNative mode, skip this judgement. 642 judgment = not_sliced or param.has_init 643 if is_graph_mode and _is_in_auto_parallel_mode() and judgment: 644 continue 645 if choice_func is not None and not choice_func(param.name): 646 continue 647 # Add suffix for cache_enabled parameter, and then parameter can carry key info. 648 # Notice that suffix needs be removed when loading into net. 649 if param.cache_enable: 650 param_dict[param.name + ".__param_key__" + str(param.key)] = param 651 else: 652 param_dict[param.name] = param 653 return param_dict 654 655 656def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func): 657 """Convert nn.Cell to param_list.""" 658 sync_pipeline_shared_parameters(save_obj) 659 param_list = [] 660 parameter_layout_dict = save_obj.parameter_layout_dict 661 if _is_in_auto_parallel_mode() and not parameter_layout_dict: 662 parameter_layout_dict = _get_parameter_layout() 663 if not _is_in_auto_parallel_mode(): 664 save_obj.init_parameters_data() 665 param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func) 666 if append_dict and "random_op" in append_dict: 667 phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key 668 if phase in save_obj.compile_cache and _executor.has_compiled(phase): 669 random_byte = _executor._graph_executor.get_random_status(phase) 670 param_list.append({"name": "random_op", "data": random_byte}) 671 append_dict.pop("random_op") 672 for (key, value) in param_dict.items(): 673 each_param = {"name": key} 674 if isinstance(value, MapParameter): 675 each_param["data"] = value 676 param_list.append(each_param) 677 continue 678 679 if value.data.is_persistent_data(): 680 # list save persistent_data: [Tensor, shape, type, param.key] 681 param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key] 682 elif value.data.offload_file_path() != "": 683 # list save offload data: [Param, shape, type, param.key] 684 param_data = ["offload_parameter"] 685 param_tensor = value.data 686 if key in parameter_layout_dict: 687 param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor, 688 integrated_save) 689 param_data.append(param_tensor) 690 param_data.append(param_tensor.shape) 691 param_data.append(str(param_tensor.dtype)) 692 param_data.append(value.key) 693 else: 694 param_data = value.data 695 696 # in automatic model parallel scenario, some parameters were split to all the devices, 697 # which should be combined before saving 698 if key in parameter_layout_dict: 699 param_data = Tensor(value.data) 700 param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data, 701 integrated_save) 702 703 each_param["data"] = param_data 704 param_list.append(each_param) 705 return param_list 706 707 708def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func): 709 """Convert a save_obj to param_list.""" 710 if isinstance(save_obj, list): 711 return _convert_list_to_param_list(save_obj, choice_func) 712 713 if isinstance(save_obj, dict): 714 return _convert_dict_to_param_dict(save_obj, choice_func) 715 716 return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func) 717 718 719def _save_param_list_data(data_list, key, param): 720 """Save persistent data into save_obj.""" 721 dims = [] 722 # persistent_data shape can not be () 723 for dim in param['data'][2]: 724 dims.append(dim) 725 data_list[key].append(dims) 726 data_list[key].append(param['data'][3]) 727 data_list[key].append(param['data'][1]) 728 data_list[key].append(param['data'][4]) 729 730 731def _check_append_dict(append_dict): 732 """Check the argument append_dict for save_checkpoint.""" 733 if append_dict is None: 734 return append_dict 735 if not isinstance(append_dict, dict): 736 raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got " 737 "{}.".format(type(append_dict))) 738 for key, value in append_dict.items(): 739 if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)): 740 raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, " 741 f"value: int, float, bool or Generator, but got key: {type(key)}, value: {type(value)}") 742 return append_dict 743 744 745def _check_load_obfuscate(**kwargs): 746 if 'obf_func' in kwargs.keys(): 747 customized_func = _check_customized_func(kwargs.get('obf_func')) 748 clean_funcs() 749 add_opaque_predicate(customized_func.__name__, customized_func) 750 return True 751 return False 752 753 754def load(file_name, **kwargs): 755 """ 756 Load MindIR. 757 758 The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details. 759 760 Args: 761 file_name (str): MindIR file name. 762 763 kwargs (dict): Configuration options dictionary. 764 765 - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32. 766 - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set. 767 768 - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``. 769 - For details of using the customized decryption, please check the `tutorial 770 <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_. 771 772 - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to 773 `obfuscate_model() 774 <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_. 775 776 Returns: 777 GraphCell, a compiled graph that can executed by `GraphCell`. 778 779 Raises: 780 ValueError: MindIR file does not exist or `file_name` is not a string. 781 RuntimeError: Failed to parse MindIR file. 782 783 Examples: 784 >>> import numpy as np 785 >>> import mindspore as ms 786 >>> import mindspore.nn as nn 787 >>> from mindspore import Tensor 788 >>> from mindspore import context 789 >>> context.set_context(mode=context.GRAPH_MODE) 790 >>> 791 >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones") 792 >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) 793 >>> ms.export(net, input_tensor, file_name="net", file_format="MINDIR") 794 >>> graph = ms.load("net.mindir") 795 >>> net = nn.GraphCell(graph) 796 >>> output = net(input_tensor) 797 >>> print(output) 798 [[[[4. 6. 4.] 799 [6. 9. 6.] 800 [4. 6. 4.]]]] 801 802 Tutorial Examples: 803 - `Saving and Loading the Model - Saving and Loading MindIR 804 <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_ 805 """ 806 if not isinstance(file_name, str): 807 raise ValueError("For 'load', the argument 'file_name' must be string, but " 808 "got {}.".format(type(file_name))) 809 if not file_name.endswith(".mindir"): 810 raise ValueError("For 'load', the argument 'file_name'(MindIR file) should end with '.mindir', " 811 "please input the correct 'file_name'.") 812 if not os.path.exists(file_name): 813 raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, " 814 "please check whether the 'file_name' is correct.") 815 file_name = os.path.abspath(file_name) 816 817 # set customized functions for dynamic obfuscation 818 obfuscated = _check_load_obfuscate(**kwargs) 819 820 logger.info("Execute the process of loading mindir.") 821 if 'dec_key' in kwargs.keys(): 822 dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes) 823 dec_mode = "AES-GCM" 824 dec_func = None 825 if 'dec_mode' in kwargs.keys(): 826 if callable(kwargs.get('dec_mode')): 827 dec_mode = "Customized" 828 dec_func = kwargs.get('dec_mode') 829 else: 830 dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str) 831 graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode, 832 decrypt=dec_func, obfuscated=obfuscated) 833 else: 834 graph = load_mindir(file_name, obfuscated=obfuscated) 835 836 if graph is None: 837 if _is_cipher_file(file_name): 838 raise RuntimeError("Load MindIR failed. The file may be encrypted and decrypt failed, you " 839 "can check whether the values of the arguments 'dec_key' and 'dec_mode'" 840 " are the same as when exported MindIR file, or check the file integrity.") 841 raise RuntimeError("Load MindIR failed.") 842 return graph 843 844 845def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True): 846 """ 847 Auto Split MindIR. 848 849 The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details. 850 851 Args: 852 file_name (str): MindIR file name. 853 device_num (int): device number. Default: '8'. 854 rank_id (int): rank id. Default: '0'. 855 dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'. 856 sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'. 857 858 Raises: 859 ValueError: MindIR file does not exist or `file_name` is not a string. 860 RuntimeError: Failed to split MindIR file. 861 862 Examples: 863 >>> import mindspore as ms 864 >>> context.set_context(mode=context.GRAPH_MODE) 865 >>> 866 >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0) 867 868 """ 869 if not isinstance(file_name, str): 870 raise ValueError("For 'Split MindIR', the argument 'file_name' must be string, but " 871 "got {}.".format(type(file_name))) 872 if not file_name.endswith(".mindir"): 873 raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) should end with '.mindir', " 874 "please input the correct 'file_name'.") 875 if not os.path.exists(file_name): 876 raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, " 877 "please check whether the 'file_name' is correct.") 878 file_name = os.path.abspath(file_name) 879 880 logger.info("Execute the process of export and split mindir.") 881 dynamic = True 882 if dynamic: 883 graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp) 884 else: 885 graph = split_mindir(file_name) 886 887 if graph is None: 888 if _is_cipher_file(file_name): 889 raise RuntimeError("Export and split MindIR failed. The file may be encrypted and decrypt failed, you " 890 "can check whether the values of the arguments 'dec_key' and 'dec_mode'" 891 " are the same as when exported MindIR file, or check the file integrity.") 892 raise RuntimeError("Export and split MindIR failed.") 893 return graph 894 895 896def _check_param_type(param_config, key, target_type, requested): 897 """check type of parameters""" 898 if key in param_config: 899 if not isinstance(param_config[key], target_type): 900 raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key]))) 901 if key == 'obf_random_seed': 902 if param_config[key] > INT_64_MAX or param_config[key] <= 0: 903 raise ValueError( 904 "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX, 905 param_config[key])) 906 return param_config[key] 907 if requested: 908 raise ValueError("The parameter {} is requested, but not got.".format(key)) 909 if key == "obf_random_seed": 910 return 0 911 return None 912 913 914def _check_customized_func(customized_func): 915 """ check customized function of dynamic obfuscation """ 916 if not callable(customized_func): 917 raise TypeError( 918 "'customized_func' must be a function, but not got {}.".format(type(customized_func))) 919 # test customized_func 920 try: 921 func_result = customized_func(1.0, 1.0) 922 except Exception as ex: 923 raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex)) 924 else: 925 if not isinstance(func_result, bool): 926 raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result))) 927 return customized_func 928 929 930def _check_obfuscate_params(obf_config): 931 """Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func""" 932 if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys(): 933 raise ValueError( 934 "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.") 935 obfuscate_type = _check_param_type(obf_config, "type", str, False) 936 if obfuscate_type not in (None, "dynamic"): 937 raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type)) 938 if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str): 939 if obf_config['obf_ratio'] not in ["small", "medium", "large"]: 940 raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format( 941 obf_config['obf_ratio'])) 942 ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6} 943 obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio']) 944 obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True) 945 if (obf_ratio <= 0) or (obf_ratio > 1): 946 raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio'])) 947 customized_funcs = [] 948 if 'customized_func' in obf_config.keys(): 949 device_target = context.get_context('device_target') 950 if device_target in ["GPU", "Ascend"]: 951 raise ValueError( 952 "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target)) 953 customized_funcs.append(_check_customized_func(obf_config['customized_func'])) 954 obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False) 955 return obf_ratio, customized_funcs, obf_random_seed 956 957 958def obfuscate_model(obf_config, **kwargs): 959 """ 960 Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its 961 predict correctness. The obfuscated model can prevent attackers from stealing the model. 962 963 Args: 964 obf_config (dict): obfuscation config. 965 966 - type (str): The type of obfuscation, only 'dynamic' is supported until now. 967 - original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original 968 model is encrypted, then enc_key and enc_mode should be provided. 969 - save_model_path (str): The path to save the obfuscated model. 970 - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which 971 is the same as using :func:`mindspore.export`. 972 - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio` 973 should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are 974 correspond to 0.1, 0.3, and 0.6 respectively. 975 - customized_func (function): A python function used for customized function mode, which used for control 976 the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const ( 977 Reference to 'my_func()' in 978 `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_). 979 This function needs to ensure that its result is constant for any input. Users can refer to opaque 980 predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface 981 when loading obfuscated model. 982 - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The 983 structure of obfuscated models corresponding to different random seeds is different. If 984 `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell` 985 interface when loading 986 obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should 987 be set, and the latter mode would be applied if both of them are set. 988 989 kwargs (dict): Configuration options dictionary. 990 991 - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32. 992 - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set. 993 Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``. 994 995 Raises: 996 TypeError: If `obf_config` is not a dict. 997 ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"]. 998 ValueError: If `original_model_path` is not provided in `obf_config`. 999 ValueError: If the model saved in `original_model_path` has been obfuscated. 1000 ValueError: If `save_model_path` is not provided in `obf_config`. 1001 ValueError: If `obf_ratio` is not provided in `obf_config`. 1002 ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`. 1003 ValueError: If `obf_random_seed` is not in (0, 9223372036854775807]. 1004 ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'. 1005 1006 Examples: 1007 >>> import mindspore as ms 1008 >>> import mindspore.nn as nn 1009 >>> import numpy as np 1010 >>> # Download ori_net.mindir 1011 >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir 1012 >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32)) 1013 >>> obf_config = {'original_model_path': "./net.mindir", 1014 ... 'save_model_path': "./obf_net", 1015 ... 'model_inputs': [input1, ], 1016 ... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423} 1017 >>> ms.obfuscate_model(obf_config) 1018 >>> obf_func = ms.load("obf_net.mindir") 1019 >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423) 1020 >>> print(obf_net(input1).asnumpy()) 1021 """ 1022 if not isinstance(obf_config, dict): 1023 raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config))) 1024 file_path = _check_param_type(obf_config, "original_model_path", str, True) 1025 if not file_path.endswith(".mindir"): 1026 raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', " 1027 "please input the correct 'file_path'.") 1028 if not os.path.exists(file_path): 1029 raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, " 1030 "please check whether the 'file_path' is correct.") 1031 saved_path = _check_param_type(obf_config, "save_model_path", str, True) 1032 model_inputs = _check_param_type(obf_config, "model_inputs", list, True) 1033 for item in model_inputs: 1034 if not isinstance(item, Tensor): 1035 raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item))) 1036 if -1 in item.shape: 1037 raise ValueError( 1038 "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape)) 1039 obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config) 1040 if customized_funcs and obf_random_seed > 0: 1041 logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be" 1042 " applied, remember to set 'obf_random_seed' when loading obfuscated model.") 1043 1044 if obf_random_seed == 0: # apply customized_func mode 1045 clean_funcs() 1046 for func in customized_funcs: 1047 add_opaque_predicate(func.__name__, func) 1048 branch_control_input = 0 1049 else: # apply password mode 1050 branch_control_input = _generate_branch_control_input(obf_random_seed) 1051 1052 if 'enc_key' in kwargs.keys(): 1053 enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes) 1054 enc_mode = "AES-GCM" 1055 if 'enc_mode' in kwargs.keys(): 1056 enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str) 1057 if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]: 1058 raise ValueError( 1059 "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for" 1060 "obfuscate_model(), but got {}.".format(enc_mode)) 1061 obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, 1062 branch_control_input=branch_control_input, dec_key=enc_key, 1063 key_len=len(enc_key), 1064 dec_mode=enc_mode) 1065 else: 1066 obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, 1067 branch_control_input=branch_control_input) 1068 1069 obf_net = nn.GraphCell(obf_graph) 1070 if obf_random_seed != 0: 1071 append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32)) 1072 model_inputs += [append_y_tensor] 1073 export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs) 1074 1075 1076def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key, 1077 dec_mode, crc_check): 1078 """load parameter into parameter_dict""" 1079 ckpt_file_name = _check_ckpt_file_name(ckpt_file_name) 1080 checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check) 1081 try: 1082 param_data_list = [] 1083 map_data_list = [[], [], []] 1084 map_shape_list = [0, 0, 0] 1085 if specify_prefix: 1086 logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, " 1087 "please use `choice_func` instead.") 1088 if filter_prefix: 1089 logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, " 1090 "please use `choice_func` instead.") 1091 for element_id, element in enumerate(checkpoint_list.value): 1092 if element.tag == "random_op": 1093 parameter_dict["random_op"] = element.tensor.tensor_content 1094 continue 1095 if not _whether_load_param(specify_prefix, filter_prefix, element.tag): 1096 continue 1097 if specify_prefix is None and filter_prefix is None and \ 1098 choice_func is not None and not choice_func(element.tag): 1099 continue 1100 if element.tensor.ByteSize() == 0: 1101 _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, 1102 parameter_dict) 1103 if element.tag in parameter_dict: 1104 map_data_list = [[], [], []] 1105 map_shape_list = [0, 0, 0] 1106 continue 1107 data = element.tensor.tensor_content 1108 data_type = element.tensor.tensor_type 1109 np_type = tensor_to_np_type.get(data_type) 1110 ms_type = tensor_to_ms_type[data_type] 1111 if data_type == 'str': 1112 str_length = int(len(data) / 4) 1113 np_type = np_type + str(str_length) 1114 param_data_list.append(data) 1115 if (element_id == len(checkpoint_list.value) - 1) or \ 1116 (element.tag != checkpoint_list.value[element_id + 1].tag): 1117 new_data = b"".join(param_data_list) 1118 param_data_list.clear() 1119 dims = element.tensor.dims 1120 if data_type == 'str': 1121 str_value = np.frombuffer(new_data, np_type) 1122 parameter_dict[element.tag] = str(str_value[0]) 1123 else: 1124 if dims == [0]: 1125 dims = [] 1126 param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type) 1127 parameter = Parameter(param_data, name=element.tag) 1128 parameter_dict[element.tag] = parameter 1129 _offload_if_config(parameter) 1130 1131 logger.info("Loading checkpoint files process is finished.") 1132 1133 except BaseException as e: 1134 logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name) 1135 raise ValueError(e.__str__() + "\nFor 'load_checkpoint', " 1136 "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e 1137 1138 1139def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, 1140 dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None, 1141 crc_check=False): 1142 """ 1143 Load checkpoint info from a specified file. 1144 1145 Note: 1146 - `specify_prefix` and `filter_prefix` do not affect each other. 1147 - If none of the parameters are loaded from checkpoint file, it will throw ValueError. 1148 - `specify_prefix` and `filter_prefix` are in the process of being deprecated, 1149 `choice_func` is recommended instead. 1150 And using either of those two args will override `choice_func` at the same time. 1151 1152 Args: 1153 ckpt_file_name (str): Checkpoint file name. 1154 net (Cell): The network where the parameters will be loaded. Default: ``None`` . 1155 strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter 1156 into net when parameter name's suffix in checkpoint file is the same as the 1157 parameter in the network. When the types are inconsistent perform type conversion 1158 on the parameters of the same type, such as float32 to float16. Default: ``False`` . 1159 filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the 1160 filter_prefix will not be loaded. Default: ``None`` . 1161 dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption 1162 is not required. Default: ``None`` . 1163 dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption 1164 mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` . 1165 Default: ``"AES-GCM"`` . 1166 specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the 1167 specify_prefix will be loaded. Default: ``None`` . 1168 choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string, 1169 and the return value is a bool. If returns ``True`` , the Parameter 1170 that matches the custom condition will be loaded. If returns ``False`` , the Parameter that 1171 matches the custom condition will be removed. Default: ``None`` . 1172 crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` . 1173 1174 Returns: 1175 Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of 1176 :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig` 1177 are used to save the checkpoint, `append_dict` and `append_info` are dict types, and their value are string, 1178 then the return value obtained by loading checkpoint is string, and in other cases the return value is 1179 Parameter. 1180 1181 Raises: 1182 ValueError: Checkpoint file's format is incorrect. 1183 ValueError: Parameter's dict is None after load checkpoint file. 1184 TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect. 1185 1186 Examples: 1187 >>> import mindspore as ms 1188 >>> 1189 >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" 1190 >>> param_dict = ms.load_checkpoint(ckpt_file_name, 1191 ... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1")) 1192 >>> print(param_dict["conv2.weight"]) 1193 Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True) 1194 >>> def func(param_name): 1195 ... whether_load = False 1196 ... if param_name.startswith("conv"): 1197 ... whether_load = True 1198 ... if param_name.startswith("conv1"): 1199 ... whether_load = False 1200 ... return whether_load 1201 >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func) 1202 >>> print(param_dict1["conv2.weight"]) 1203 Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True) 1204 >>> def func(param_name): 1205 ... whether_load = False 1206 ... if param_name.startswith("conv1"): 1207 ... whether_load = True 1208 ... return whether_load 1209 >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func) 1210 >>> print(param_dict2) 1211 {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)} 1212 1213 Tutorial Examples: 1214 - `Saving and Loading the Model - Saving and Loading the Model Weight 1215 <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_ 1216 """ 1217 specify_prefix = _check_prefix(specify_prefix) 1218 filter_prefix = _check_prefix(filter_prefix) 1219 dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) 1220 dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) 1221 crc_check = Validator.check_isinstance('crc_check', crc_check, bool) 1222 logger.info("Execute the process of loading checkpoint files.") 1223 1224 parameter_dict = {} 1225 1226 if os.getenv("AITURBO") == "1": 1227 rank_id = get_rank() 1228 import aiturbo 1229 ckpt_path = os.path.dirname(ckpt_file_name) 1230 ckpt_name = os.path.basename(ckpt_file_name) 1231 np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id) 1232 for key, value in np_dict.items(): 1233 if isinstance(value, str): 1234 parameter_dict[key] = value 1235 else: 1236 parameter_dict[key] = Parameter(Tensor(value), name=key) 1237 else: 1238 _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key, 1239 dec_mode, crc_check) 1240 1241 if not parameter_dict: 1242 raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether " 1243 f"'filter_prefix' or 'specify_prefix' are set correctly.") 1244 1245 if _warm_up_host_cache_enabled(parameter_dict): 1246 (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net) 1247 if net is not None: 1248 load_param_into_net(net, parameter_dict, strict_load) 1249 if _warm_up_host_cache_enabled(parameter_dict): 1250 _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict) 1251 1252 return parameter_dict 1253 1254 1255def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, 1256 dec_mode="AES-GCM", specify_prefix=None, choice_func=None): 1257 """ 1258 Load checkpoint info from a specified file asyncly. 1259 1260 .. warning:: 1261 This is an experimental API that is subject to change or deletion. 1262 1263 Note: 1264 - `specify_prefix` and `filter_prefix` do not affect each other. 1265 - If none of the parameters are loaded from checkpoint file, it will throw ValueError. 1266 - `specify_prefix` and `filter_prefix` are in the process of being deprecated, 1267 `choice_func` is recommended instead. 1268 And using either of those two args will override `choice_func` at the same time. 1269 1270 Args: 1271 ckpt_file_name (str): Checkpoint file name. 1272 net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` . 1273 strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load 1274 parameter into net when parameter name's suffix in checkpoint file is the 1275 same as the parameter in the network. When the types are inconsistent 1276 perform type conversion on the parameters of the same type, such as float32 1277 to float16. Default: ``False`` . 1278 filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters 1279 starting with the `filter_prefix` will not be loaded. Default: ``None`` . 1280 dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` , 1281 the decryption is not required. Default: ``None`` . 1282 dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies 1283 the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` 1284 and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` . 1285 specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters 1286 starting with the specify_prefix will be loaded. Default: ``None`` . 1287 choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type 1288 string, and the return value is a bool. If returns ``True`` , the Parameter 1289 that matches the custom condition will be loaded. If returns ``False`` , the Parameter that 1290 matches the custom condition will be removed. Default: ``None`` . 1291 1292 Returns: 1293 A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result. 1294 1295 Raises: 1296 ValueError: Checkpoint file's format is incorrect. 1297 ValueError: Parameter's dict is None after load checkpoint file. 1298 TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect. 1299 1300 Examples: 1301 >>> import mindspore 1302 >>> from mindspore import nn 1303 >>> from mindspore.train import Model 1304 >>> from mindspore.amp import FixedLossScaleManager 1305 >>> from mindspore import context 1306 >>> from mindspore import load_checkpoint_async 1307 >>> from mindspore import load_param_into_net 1308 >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") 1309 >>> # Create the dataset taking MNIST as an example. Refer to 1310 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py 1311 >>> dataset = create_dataset() 1312 >>> # Define the network structure of LeNet5. Refer to 1313 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 1314 >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt" 1315 >>> net = LeNet5() 1316 >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") 1317 >>> loss_scale_manager = FixedLossScaleManager() 1318 >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) 1319 >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, 1320 ... loss_scale_manager=loss_scale_manager) 1321 >>> pd_future = load_checkpoint_async(ckpt_file) 1322 >>> model.build(train_dataset=dataset, epoch=2) 1323 >>> param_dict = pd_future.result() 1324 >>> load_param_into_net(net, param_dict) 1325 >>> model.train(2, dataset) 1326 >>> print("param dict len: ", len(param_dict), flush=True) 1327 """ 1328 from concurrent.futures import ThreadPoolExecutor 1329 executor = ThreadPoolExecutor(max_workers=2) 1330 param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix, 1331 dec_key, dec_mode, specify_prefix, choice_func) 1332 return ParamDictFuture(executor, param_dict_future) 1333 1334 1335def _load_map_parameter(checkpoint_list, element, element_id, map_data_list, 1336 map_shape_list, parameter_dict): 1337 """load map parameter.""" 1338 logger.info("Checkpoint load map_parameter.") 1339 if (element_id != len(checkpoint_list.value) - 1) and \ 1340 element.tag == checkpoint_list.value[element_id + 1].tag: 1341 for index, tensor in enumerate(element.maptensor.tensor): 1342 data = tensor.tensor_content 1343 data_type = tensor.tensor_type 1344 np_type = np_type_convert.get(data_type) 1345 element_data = np.frombuffer(data, np_type) 1346 map_data_list[index].append(element_data) 1347 map_shape_list[index] += tensor.dims[0] 1348 else: 1349 map_array = [] 1350 for index, tensor in enumerate(element.maptensor.tensor): 1351 data = tensor.tensor_content 1352 data_type = tensor.tensor_type 1353 np_type = np_type_convert.get(data_type) 1354 element_data = np.frombuffer(data, np_type) 1355 map_data_list[index].append(element_data) 1356 new_data = b"".join(map_data_list[index]) 1357 param_data = np.frombuffer(new_data, np_type) 1358 dims = tensor.dims 1359 dims[0] += map_shape_list[index] 1360 param_data = param_data.reshape(list(dims)) 1361 map_array.append(param_data) 1362 parameter_dict[element.tag] = map_array 1363 1364 1365def _check_ckpt_file_name(ckpt_file_name): 1366 """Check function load_checkpoint's ckpt_file_name.""" 1367 if not isinstance(ckpt_file_name, str): 1368 raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, " 1369 "but got {}.".format(type(ckpt_file_name))) 1370 1371 if ckpt_file_name[-5:] != ".ckpt": 1372 raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please " 1373 "input the correct 'ckpt_file_name'.") 1374 1375 ckpt_file_name = os.path.abspath(ckpt_file_name) 1376 if not os.path.exists(ckpt_file_name): 1377 raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check " 1378 "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name)) 1379 1380 return ckpt_file_name 1381 1382 1383def _check_prefix(prefix): 1384 """Check the correctness of the parameters.""" 1385 if prefix is None: 1386 return prefix 1387 if not isinstance(prefix, (str, list, tuple)): 1388 raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, " 1389 "list[string] or tuple[string], but got {}.".format(str(type(prefix)))) 1390 if isinstance(prefix, str): 1391 prefix = (prefix,) 1392 if not prefix: 1393 raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when" 1394 " 'specify_prefix' or 'filter_prefix' is list or tuple.") 1395 for index, pre in enumerate(prefix): 1396 if not isinstance(pre, str): 1397 raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, " 1398 "the element in it must be string, but got " 1399 f"{str(type(pre))} at index {index}.") 1400 if pre == "": 1401 raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' " 1402 "can't include ''.") 1403 return prefix 1404 1405 1406def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check): 1407 """Parse checkpoint protobuf.""" 1408 checkpoint_list = Checkpoint() 1409 try: 1410 if dec_key is None: 1411 with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f: 1412 pb_content = f.read() 1413 else: 1414 pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) 1415 if pb_content is None: 1416 raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.") 1417 if crc_check and pb_content[-17:-10] == b"crc_num": 1418 logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.") 1419 if pb_content[-17:-10] == b"crc_num": 1420 crc_num_bytes = pb_content[-10:] 1421 pb_content = pb_content[:-17] 1422 if crc_check: 1423 crc_num = int.from_bytes(crc_num_bytes, byteorder='big') 1424 cal_crc_num = binascii.crc32(pb_content, 0) 1425 if cal_crc_num != crc_num: 1426 raise ValueError("For 'load_checkpoint', the crc check is failed, " 1427 "please check whether the ckpt file is damaged.") 1428 checkpoint_list.ParseFromString(pb_content) 1429 except BaseException as e: 1430 if _is_cipher_file(ckpt_file_name): 1431 err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tempered with, " \ 1432 "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name) 1433 else: 1434 err_info = "Failed to read the checkpoint file {}. May not have permission to read it, please check" \ 1435 " the correct of the file.".format(ckpt_file_name) 1436 logger.error(err_info) 1437 raise ValueError(err_info) from e 1438 return checkpoint_list 1439 1440 1441def _whether_load_param(specify_prefix, filter_prefix, param_name): 1442 """Checks whether the load the parameter after `specify_prefix` or `filter_prefix`.""" 1443 whether_load = True 1444 if specify_prefix: 1445 whether_load = False 1446 for prefix in specify_prefix: 1447 if param_name.startswith(prefix): 1448 whether_load = True 1449 break 1450 if filter_prefix: 1451 for prefix in filter_prefix: 1452 if param_name.startswith(prefix): 1453 whether_load = False 1454 break 1455 return whether_load 1456 1457 1458def _init_parameter_data_in_parallel_mode(net, parameter_dict): 1459 """In parallel mode, only init the paraemters in ckpt.""" 1460 is_train_phase = net.phase.startswith('train') 1461 for _, param in net.parameters_and_names(): 1462 if param.name in parameter_dict and param.from_ckpt and not is_train_phase: 1463 param.shape = tuple(parameter_dict[param.name].shape) 1464 continue 1465 if param.name in parameter_dict and param.has_init: 1466 logger.warning("{} is not init while load ckpt.".format(param.name)) 1467 new_tensor = param.init_data() 1468 param._update_tensor_data(new_tensor) 1469 1470 1471def _check_load_param_into_net(net, parameter_dict): 1472 """check load_param_into_net""" 1473 if not isinstance(net, nn.Cell): 1474 logger.critical("Failed to combine the net and the parameters.") 1475 msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net))) 1476 raise TypeError(msg) 1477 if not isinstance(parameter_dict, dict): 1478 logger.critical("Failed to combine the net and the parameters.") 1479 msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, " 1480 "but got {}.".format(type(parameter_dict))) 1481 raise TypeError(msg) 1482 if "random_op" in parameter_dict.keys(): 1483 net._add_attr("random_op_snapshot", parameter_dict["random_op"]) 1484 parameter_dict.pop("random_op") 1485 1486 1487def load_param_into_net(net, parameter_dict, strict_load=False): 1488 """ 1489 Load parameters into network, return parameter list that are not loaded in the network. 1490 1491 Args: 1492 net (Cell): The network where the parameters will be loaded. 1493 parameter_dict (dict): The dictionary generated by load checkpoint file, 1494 it is a dictionary consisting of key: parameters's name, value: parameter. 1495 strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter 1496 into net when parameter name's suffix in checkpoint file is the same as the 1497 parameter in the network. When the types are inconsistent perform type conversion 1498 on the parameters of the same type, such as float32 to float16. Default: ``False`` . 1499 1500 Returns: 1501 - param_not_load (List), the parameter name in model which are not loaded into the network. 1502 - ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network. 1503 1504 Raises: 1505 TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. 1506 1507 Examples: 1508 >>> import mindspore as ms 1509 >>> 1510 >>> # Define the network structure of LeNet5. Refer to 1511 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 1512 >>> net = LeNet5() 1513 >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" 1514 >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1") 1515 >>> param_not_load, _ = ms.load_param_into_net(net, param_dict) 1516 >>> print(param_not_load) 1517 ['conv1.weight'] 1518 1519 Tutorial Examples: 1520 - `Saving and Loading the Model - Saving and Loading the Model Weight 1521 <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_ 1522 """ 1523 _check_load_param_into_net(net, parameter_dict) 1524 for key, value in parameter_dict.items(): 1525 if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)): 1526 logger.critical("Load parameters into net failed.") 1527 msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a " 1528 "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value))) 1529 raise TypeError(msg) 1530 1531 strict_load = Validator.check_bool(strict_load) 1532 logger.info("Execute the process of loading parameters into net.") 1533 for _, param in net.parameters_and_names(): 1534 param.from_ckpt = True 1535 if not _is_in_auto_parallel_mode(): 1536 net.init_parameters_data() 1537 else: 1538 _init_parameter_data_in_parallel_mode(net, parameter_dict) 1539 param_not_load = [] 1540 ckpt_not_load = list(parameter_dict.keys()) 1541 for _, param in net.parameters_and_names(): 1542 if param.name in parameter_dict: 1543 if isinstance(param, MapParameter): 1544 param.import_data(parameter_dict[param.name]) 1545 continue 1546 # Add has attr protection when load server checkpoint file on worker. 1547 if not hasattr(parameter_dict[param.name], "data"): 1548 continue 1549 new_param = parameter_dict[param.name] 1550 _update_param(param, new_param, strict_load) 1551 ckpt_not_load.remove(param.name) 1552 else: 1553 param_not_load.append(param.name) 1554 1555 if param_not_load and not strict_load: 1556 _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load) 1557 1558 logger.info("Loading parameters into net is finished.") 1559 if param_not_load: 1560 logger.warning("For 'load_param_into_net', " 1561 "{} parameters in the 'net' are not loaded, because they are not in the " 1562 "'parameter_dict', please check whether the network structure is consistent " 1563 "when training and loading checkpoint.".format(len(param_not_load))) 1564 logger.warning("{} are not loaded.".format(param_not_load)) 1565 if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None: 1566 param_layout = net.parameter_layout_dict 1567 param_redundancy = get_parameter_redundancy(param_layout) 1568 remove_param_redundancy_dict = remove_param_redundancy(param_redundancy) 1569 target_parameter_name_set = set(parameter_dict.keys()) 1570 for rank_id, param_name_set in remove_param_redundancy_dict: 1571 if param_name_set == target_parameter_name_set: 1572 parameter_broadcast(net, param_layout, rank_id) 1573 return param_not_load, ckpt_not_load 1574 1575 1576def _warm_up_host_cache_enabled(parameter_dict): 1577 """Warm up host cache enabled.""" 1578 if _cache_enable(): 1579 return True 1580 for key in parameter_dict.keys(): 1581 if key.find(".__param_key__") != -1: 1582 return True 1583 return False 1584 1585 1586def _warm_up_host_cache(parameter_dict, net): 1587 """Warm up host cache.""" 1588 ms_role = os.getenv("MS_ROLE") 1589 is_worker = ms_role == "MS_WORKER" 1590 param_key_dict = {} 1591 # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict. 1592 if is_worker: 1593 net.init_parameters_data() 1594 net_dict = {} 1595 for name, value in net.parameters_and_names(): 1596 net_dict[name] = value 1597 for param_name, value in parameter_dict.items(): 1598 pos = param_name.find(".__param_key__") 1599 if pos != -1: 1600 net_param_name = param_name[:pos] 1601 param_key_dict[param_name] = net_param_name 1602 net_value = None 1603 if net_param_name not in net_dict: 1604 logger.warning("net param name : %s is not in net", net_param_name) 1605 else: 1606 net_value = net_dict.get(net_param_name, None) 1607 pos += len(".__param_key__") 1608 param_key = int(param_name[pos:]) 1609 value_is_map_parameter = isinstance(value, list) and len(value) == 3 1610 if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)): 1611 key_tensor = Tensor.from_numpy(value[0]) 1612 value_tensor = Tensor.from_numpy(value[1]) 1613 status_tensor = Tensor.from_numpy(value[2]) 1614 _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor) 1615 elif not isinstance(value, list) and isinstance(net_value, Parameter): 1616 _store_warm_up_ptr_by_tensor(param_key, value) 1617 else: 1618 logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value)) 1619 else: 1620 for param_name, value in parameter_dict.items(): 1621 pos = param_name.find(".__param_key__") 1622 if pos != -1: 1623 net_param_name = param_name[:pos] 1624 param_key_dict[param_name] = net_param_name 1625 # Split param key from parameter_dict since worker cannot load param key. 1626 warm_up_dict = {} 1627 for key, value in param_key_dict.items(): 1628 if is_worker: 1629 warm_up_dict[value] = parameter_dict.pop(key) 1630 else: 1631 parameter_dict[value] = parameter_dict.pop(key) 1632 return (is_worker, parameter_dict, warm_up_dict) 1633 1634 1635def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict): 1636 """Warm up host cache post process.""" 1637 if is_worker: 1638 net_dict.update(warm_up_dict) 1639 _set_checkpoint_load_status(True) 1640 1641 1642def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load): 1643 """When some net parameter did not load, try to continue loading.""" 1644 prefix_name = "" 1645 longest_name = param_not_load[0] 1646 while prefix_name != longest_name and param_not_load: 1647 logger.debug("Count: {} parameters has not been loaded, try to continue loading.".format(len(param_not_load))) 1648 prefix_name = longest_name 1649 for net_param_name in param_not_load: 1650 for dict_name in parameter_dict: 1651 if dict_name.endswith(net_param_name): 1652 prefix_name = dict_name[:-len(net_param_name)] 1653 break 1654 if prefix_name != longest_name: 1655 break 1656 1657 if prefix_name != longest_name: 1658 logger.warning(f"For 'load_param_into_net', remove parameter prefix name: {prefix_name}," 1659 f" continue to load.") 1660 for _, param in net.parameters_and_names(): 1661 new_param_name = prefix_name + param.name 1662 if param.name in param_not_load and new_param_name in parameter_dict: 1663 new_param = parameter_dict[new_param_name] 1664 _update_param(param, new_param, strict_load) 1665 param_not_load.remove(param.name) 1666 1667 1668def _save_graph(network, file_name): 1669 """ 1670 Saves the graph of network to a file. 1671 1672 Args: 1673 network (Cell): Obtain a pipeline through network for saving graph. 1674 file_name (str): Graph file name into which the graph will be saved. 1675 """ 1676 logger.info("Execute the process of saving graph.") 1677 1678 file_name = os.path.abspath(file_name) 1679 graph_pb = network.get_func_graph_proto() 1680 if graph_pb: 1681 with open(file_name, "wb") as f: 1682 os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR) 1683 f.write(graph_pb) 1684 1685 1686def _reshape_tensor(tensor, dst_shape): 1687 """reshape tensor to dst shape""" 1688 np_tensor = tensor.asnumpy() 1689 np_tensor = np_tensor.reshape(dst_shape) 1690 return Tensor(np_tensor, tensor.dtype) 1691 1692 1693def _check_param_for_integrate_save(pipeline_stages, uniform_split): 1694 """check whether current settings and parameters are supported in integrated save checkpoint mode""" 1695 if pipeline_stages > 1: 1696 raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.") 1697 if uniform_split == 0: 1698 raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set " 1699 "'integrated_save' to True, the checkpoint will be integrated save, it " 1700 "is only supports uniform split tensor now.") 1701 1702 1703def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save): 1704 """ 1705 Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map. 1706 1707 Args: 1708 net (Cell): MindSpore network. 1709 param_name (str): The parameter name, which to be combined. 1710 param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data. 1711 integrated_save (bool): Whether to integrated save in automatic model parallel scene. 1712 Returns: 1713 Tensor, the combined tensor which with the whole data value. 1714 """ 1715 layout = parameter_layout_dict[param_name] 1716 if len(layout) < 8: 1717 logger.info("The layout dict does not contain the key %s", param_name) 1718 return param_data 1719 1720 dev_mat = layout[0] 1721 tensor_map = layout[1] 1722 uniform_split = layout[4] 1723 opt_shard_group = layout[5] 1724 before_reshape_slice_shape = layout[2] 1725 before_reshape_full_shape = layout[6] 1726 after_reshape_slice_shape = layout[7] 1727 do_reshape = False 1728 if before_reshape_full_shape and after_reshape_slice_shape \ 1729 and after_reshape_slice_shape != before_reshape_slice_shape: 1730 do_reshape = True 1731 1732 allgather_net = None 1733 mp_weight = False 1734 for dim in tensor_map: 1735 if dim != -1: 1736 mp_weight = True 1737 break 1738 if param_name in net.parallel_parameter_merge_net_dict: 1739 allgather_net = net.parallel_parameter_merge_net_dict[param_name] 1740 else: 1741 logger.info("Need to create allgather net for %s", param_name) 1742 if integrated_save: 1743 _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split) 1744 # while any dim is not equal to -1, means param is split and needs to be merged 1745 # pipeline parallel need to be supported here later 1746 if mp_weight: 1747 allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape, 1748 tuple(after_reshape_slice_shape)) 1749 object.__setattr__(allgather_net, "keep_input_unchanged", True) 1750 elif opt_shard_group: 1751 allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape, 1752 tuple(after_reshape_slice_shape)) 1753 elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"): 1754 allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape, 1755 tuple(after_reshape_slice_shape)) 1756 net.parallel_parameter_merge_net_dict[param_name] = allgather_net 1757 if allgather_net: 1758 param_data = allgather_net(param_data) 1759 if mp_weight and integrated_save: 1760 param_data = _reshape_param_data(param_data, dev_mat, tensor_map) 1761 if do_reshape: 1762 param_data = _reshape_tensor(param_data, before_reshape_full_shape) 1763 return param_data 1764 1765 1766def export(net, *inputs, file_name, file_format, **kwargs): 1767 """ 1768 Export the MindSpore network into an offline model in the specified format. 1769 1770 Note: 1771 1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB. 1772 2. When `file_name` does not have a suffix, the system will automatically add one 1773 according to the `file_format`. 1774 3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported. 1775 4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve 1776 class properties in calculations. 1777 5. AIR format is deprecated, and will be removed in a future version, please use other format or use 1778 MindSpore Lite to do offline inference. 1779 1780 Args: 1781 net (Union[Cell, function]): MindSpore network. 1782 inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs 1783 of the `net`, if the network has multiple inputs, set them together. While its type is Dataset, 1784 it represents the preprocess behavior of the `net`, data preprocess operations will be serialized. 1785 In second situation, you should adjust batch size of dataset script manually which will impact on 1786 the batch size of 'net' input. Only supports parse "image" column from dataset currently. 1787 file_name (str): File name of the model to be exported. 1788 file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model. 1789 1790 - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model. 1791 - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. 1792 - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format 1793 for MindSpore models. 1794 1795 kwargs (dict): Configuration options dictionary. 1796 1797 - enc_key (byte): Byte-type key used for encryption. The valid length is 16, 24, or 32. 1798 - enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set. 1799 1800 - For 'AIR' and 'ONNX' models, only customized encryption is supported. 1801 - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' 1802 or Customized encryption. 1803 Default: ``'AES-GCM'``. 1804 - For details of using the customized encryption, please check the `tutorial 1805 <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_. 1806 1807 - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the 1808 preprocessing of the dataset into MindIR. 1809 1810 - obf_config (dict): obfuscation config. 1811 1812 - type (str): The type of obfuscation, only 'dynamic' is supported until now. 1813 - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio` 1814 should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are 1815 correspond to 0.1, 0.3, and 0.6 respectively. 1816 - customized_func (function): A python function used for customized function mode, which used for control 1817 the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const ( 1818 Reference to 'my_func()' in 1819 `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_). 1820 This function needs to ensure that its result is constant for any input. Users can refer to opaque 1821 predicates. If customized_func is set, then it should be passed to `load()` interface when loading 1822 obfuscated model. 1823 - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The 1824 structure of obfuscated models corresponding to different random seeds is different. If 1825 `obf_random_seed` is set, then it should be passed 1826 to :class:`mindspore.nn.GraphCell` interface when loading 1827 obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should 1828 be set, and the latter mode would be applied if both of them are set. 1829 1830 - incremental (bool): export MindIR incrementally. 1831 1832 - custom_func (function): Functions for custom defined export policies. This function will be used to 1833 customize the model during network export. Currently only support for files with mindir format. The 1834 function only accepts one input representing the proto object of the mindir file. When modifying a model, 1835 it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading 1836 failure or functional errors. Default: ``None`` . 1837 1838 Examples: 1839 >>> import mindspore as ms 1840 >>> import numpy as np 1841 >>> from mindspore import Tensor 1842 >>> 1843 >>> # Define the network structure of LeNet5. Refer to 1844 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 1845 >>> net = LeNet5() 1846 >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) 1847 >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR') 1848 >>> 1849 >>> # Export model in MindIR format and modified the model info using custom_func 1850 >>> # The custom_func only support one input representing the Proto object of the model 1851 >>> # And custom_func does not support return value 1852 >>> def _custom_func(mindir_model): 1853 ... mindir_model.producer_name = "test11111" 1854 ... mindir_model.producer_version = "11.0" 1855 ... mindir_model.user_info["version"] = "11.0" 1856 >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func) 1857 1858 1859 Tutorial Examples: 1860 - `Saving and Loading the Model - Saving and Loading MindIR 1861 <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_ 1862 """ 1863 old_ms_jit_value = context.get_context("jit_syntax_level") 1864 context.set_context(jit_syntax_level=mindspore.STRICT) 1865 1866 supported_formats = ['AIR', 'ONNX', 'MINDIR'] 1867 if file_format not in supported_formats: 1868 raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.") 1869 if file_format == 'AIR': 1870 logger.warning("AIR format is deprecated, and will be removed in a future version, please use other format or " 1871 "use MindSpore Lite to do offline inference") 1872 Validator.check_file_name_by_regular(file_name) 1873 logger.info("exporting model file:%s format:%s.", file_name, file_format) 1874 1875 if check_input_dataset(*inputs, dataset_type=mindspore.dataset.Dataset): 1876 if len(inputs) != 1: 1877 raise RuntimeError(f"You can only serialize one dataset into MindIR, got " + str(len(inputs)) + " datasets") 1878 shapes, types, columns = inputs[0].output_shapes(), inputs[0].output_types(), inputs[0].get_col_names() 1879 kwargs['dataset'] = inputs[0] 1880 only_support_col = "image" 1881 1882 inputs_col = list() 1883 for c, s, t in zip(columns, shapes, types): 1884 if only_support_col != c: 1885 continue 1886 inputs_col.append(Tensor(np.random.uniform(-1.0, 1.0, size=s).astype(t))) 1887 if not inputs_col: 1888 raise RuntimeError(f"Only supports parse \"image\" column from dataset now, given dataset has columns: " 1889 + str(columns)) 1890 inputs = tuple(inputs_col) 1891 1892 file_name = os.path.abspath(file_name) 1893 if 'enc_key' in kwargs.keys(): 1894 kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs) 1895 _export(net, file_name, file_format, *inputs, **kwargs) 1896 1897 context.set_context(jit_syntax_level=old_ms_jit_value) 1898 1899 1900def _get_funcgraph(net, *inputs): 1901 """ 1902 Compile the MindSpore network and get FuncGraph. 1903 1904 Arg: 1905 net (Union[Cell, function]): MindSpore network. 1906 inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs 1907 of the `net`, if the network has multiple inputs, set them together. While its type is Dataset, 1908 it represents the preprocess behavior of the `net`, data preprocess operations will be serialized. 1909 In second situation, you should adjust batch size of dataset script manually which will impact on 1910 the batch size of 'net' input. Only supports parse "image" column from dataset currently. 1911 1912 Returns: 1913 FuncGraph, a mindspore._c_expression.FuncGraph obj. 1914 1915 Raises: 1916 ValueError: input `net` is not a nn.Cell. 1917 1918 Examples: 1919 >>> import mindspore as ms 1920 >>> import numpy as np 1921 >>> from mindspore import Tensor 1922 >>> 1923 >>> # Define the network structure of LeNet5. Refer to 1924 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 1925 >>> net = LeNet5() 1926 >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) 1927 >>> ms.get_funcgraph(net, input_tensor) 1928 1929 """ 1930 if not isinstance(net, nn.Cell): 1931 raise ValueError(f"For get_funcgraph's parameter 'net', currently only support Cell right now.") 1932 phase_name = "lite_infer_predict" if _is_in_auto_parallel_mode() else "lite_infer_get_func_graph" 1933 graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) 1934 # pylint: disable=protected-access 1935 func_graph = _executor._get_func_graph(net, graph_id) 1936 return func_graph 1937 1938 1939def _export(net, file_name, file_format, *inputs, **kwargs): 1940 """ 1941 It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format. 1942 """ 1943 logger.info("exporting model file:%s format:%s.", file_name, file_format) 1944 if "obf_config" in kwargs and file_format != "MINDIR": 1945 raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.") 1946 if "custom_func" in kwargs and file_format != "MINDIR": 1947 raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.") 1948 if file_format == 'AIR': 1949 _save_air(net, file_name, *inputs, **kwargs) 1950 elif file_format == 'ONNX': 1951 _save_onnx(net, file_name, *inputs, **kwargs) 1952 elif file_format == 'MINDIR': 1953 _save_mindir(net, file_name, *inputs, **kwargs) 1954 1955 1956def _check_key_mode_type(file_format, **kwargs): 1957 """check enc_key and enc_mode are valid""" 1958 enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes) 1959 enc_mode = kwargs.get('enc_mode') 1960 1961 if callable(enc_mode): 1962 return enc_key, enc_mode 1963 1964 enc_mode = 'AES-GCM' 1965 if 'enc_mode' in kwargs.keys(): 1966 enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str) 1967 1968 if file_format in ('AIR', 'ONNX'): 1969 raise ValueError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.") 1970 1971 if enc_mode in ('AES-CBC', 'AES-GCM', 'SM4-CBC'): 1972 return enc_key, enc_mode 1973 raise ValueError(f"MindIR only support AES-GCM/AES-CBC/SM4-CBC encryption, but got {enc_mode}") 1974 1975 1976def _save_air(net, file_name, *inputs, **kwargs): 1977 """Save AIR format file.""" 1978 phase_name = 'export.air' 1979 graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) 1980 if not file_name.endswith('.air'): 1981 file_name += ".air" 1982 if os.path.exists(file_name): 1983 os.chmod(file_name, stat.S_IWUSR) 1984 if "/" in file_name: 1985 real_path = os.path.abspath(file_name[:file_name.rfind("/")]) 1986 os.makedirs(real_path, exist_ok=True) 1987 if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys(): 1988 _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode')) 1989 else: 1990 _executor.export(file_name, graph_id) 1991 os.chmod(file_name, stat.S_IRUSR) 1992 1993 1994def _save_onnx(net, file_name, *inputs, **kwargs): 1995 """Save ONNX format file.""" 1996 # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction) 1997 if not isinstance(net, nn.Cell): 1998 raise ValueError(f"Export ONNX format model only support nn.Cell object, but got {type(net)}.") 1999 _check_dynamic_input(inputs) 2000 cell_mode = net.training 2001 net.set_train(mode=False) 2002 total_size = _calculation_net_size(net) 2003 if total_size > PROTO_LIMIT_SIZE: 2004 raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.' 2005 .format(total_size / 1024 / 1024, PROTO_LIMIT_SIZE / 1024 / 1024)) 2006 phase_name = 'export.onnx' 2007 graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) 2008 onnx_stream = _executor._get_func_graph_proto(net, graph_id) 2009 if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys(): 2010 enc_mode = kwargs.get('enc_mode') 2011 onnx_stream = enc_mode(onnx_stream, kwargs.get('enc_key')) 2012 if not file_name.endswith('.onnx'): 2013 file_name += ".onnx" 2014 if os.path.exists(file_name): 2015 os.chmod(file_name, stat.S_IWUSR) 2016 with open(file_name, 'wb') as f: 2017 f.write(onnx_stream) 2018 os.chmod(file_name, stat.S_IRUSR) 2019 net.set_train(mode=cell_mode) 2020 2021 2022def _check_dynamic_input(inputs): 2023 for ele in inputs: 2024 if isinstance(ele, Tensor) and -1 in ele.shape: 2025 raise ValueError(f"Export ONNX format model not support dynamic shape mode.") 2026 2027 2028def _generate_front_info_for_param_data_file(is_encrypt, kwargs): 2029 front_info = bytes() 2030 check_code = sys.byteorder == "little" 2031 front_info += check_code.to_bytes(1, byteorder=sys.byteorder) 2032 front_info += bytes(63) 2033 if is_encrypt(): 2034 front_info = _encrypt(front_info, len(front_info), kwargs.get('enc_key'), 2035 len(kwargs.get('enc_key')), kwargs.get('enc_mode')) 2036 return front_info 2037 2038 2039def _change_file(f, dirname, external_local, is_encrypt, kwargs): 2040 """Change to another file to write parameter data.""" 2041 # The parameter has been not written in the file 2042 front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) 2043 f.seek(0, 0) 2044 f.write(front_info) 2045 f.close() 2046 ori_data_file_name = f.name 2047 os.chmod(ori_data_file_name, stat.S_IRUSR) 2048 if os.path.getsize(ori_data_file_name) == 64: 2049 raise RuntimeError("The parameter size is exceed 1T,cannot export to the file") 2050 data_file_name = os.path.join(dirname, external_local) 2051 return _get_data_file(is_encrypt, kwargs, data_file_name) 2052 2053 2054def _get_data_file(is_encrypt, kwargs, data_file_name): 2055 """Get Data File to write parameter data.""" 2056 # Reserves 64 bytes as spare information such as check data 2057 offset = 64 2058 if os.path.exists(data_file_name): 2059 os.chmod(data_file_name, stat.S_IWUSR) 2060 2061 place_holder_data = bytes(offset) 2062 if is_encrypt(): 2063 place_holder_data = _encrypt(place_holder_data, len(place_holder_data), kwargs["enc_key"], 2064 len(kwargs["enc_key"]), kwargs["enc_mode"]) 2065 parameter_size = (offset / 1024) 2066 try: 2067 f = open(data_file_name, "wb") 2068 f.write(place_holder_data) 2069 except IOError: 2070 f.close() 2071 2072 return f, parameter_size, offset 2073 2074 2075def _encrypt_data(is_encrypt, write_data, kwargs): 2076 """Encrypt parameter data.""" 2077 if is_encrypt(): 2078 if callable(kwargs.get('enc_mode')): 2079 enc_func = kwargs.get('enc_mode') 2080 write_data = enc_func(write_data, kwargs.get('enc_key')) 2081 else: 2082 write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'), 2083 len(kwargs.get('enc_key')), kwargs.get('enc_mode')) 2084 return write_data 2085 2086 2087def _split_save(net_dict, model, file_name, is_encrypt, **kwargs): 2088 """The function to save parameter data.""" 2089 logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") 2090 # save parameter 2091 if model.graph.map_parameter: 2092 raise ValueError("MapParameter not support save in split MindIR file now.") 2093 file_prefix = file_name.split("/")[-1] 2094 if file_prefix.endswith(".mindir"): 2095 file_prefix = file_prefix[:-7] 2096 current_path = os.path.abspath(file_name) 2097 dirname = os.path.dirname(current_path) 2098 data_path = os.path.join(dirname, file_prefix + "_variables") 2099 if os.path.exists(data_path): 2100 shutil.rmtree(data_path) 2101 os.makedirs(data_path, exist_ok=True) 2102 os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) 2103 index = 0 2104 external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) 2105 data_file_name = os.path.join(dirname, external_local) 2106 f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name) 2107 try: 2108 round = 0 2109 names = [] 2110 for param_proto in model.graph.parameter: 2111 name = param_proto.name[param_proto.name.find(":") + 1:] 2112 names.append((name, param_proto)) 2113 names.sort(key=lambda x: x[0]) 2114 for pairs in names: 2115 name = pairs[0] 2116 param_proto = pairs[1] 2117 param = net_dict[name] 2118 raw_data = param.data.get_bytes() 2119 data_length = len(raw_data) 2120 append_size = 0 2121 if data_length % 64 != 0: 2122 append_size = 64 - (data_length % 64) 2123 parameter_size += ((append_size + data_length) / 1024) 2124 if parameter_size > PARAMETER_SPLIT_SIZE: 2125 index += 1 2126 external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) 2127 f, parameter_size, offset = _change_file(f, dirname, external_local, is_encrypt, kwargs) 2128 parameter_size += ((append_size + data_length) / 1024) 2129 param_proto.external_data.location = external_local 2130 param_proto.external_data.length = data_length 2131 param_proto.external_data.offset = offset 2132 write_data = raw_data + bytes(append_size) 2133 offset += (data_length + append_size) 2134 write_data = _encrypt_data(is_encrypt, write_data, kwargs) 2135 f.write(write_data) 2136 round += 1 2137 logger.debug(f"writing {round}th split data, name:{name}") 2138 2139 graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir") 2140 if os.path.exists(graph_file_name): 2141 os.chmod(graph_file_name, stat.S_IWUSR) 2142 with open(graph_file_name, 'wb') as model_file: 2143 os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR) 2144 model_string = model.SerializeToString() 2145 if is_encrypt(): 2146 model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'), 2147 len(kwargs.get('enc_key')), kwargs.get('enc_mode')) 2148 model_file.write(model_string) 2149 os.chmod(graph_file_name, stat.S_IRUSR) 2150 2151 front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) 2152 f.seek(0, 0) 2153 f.write(front_info) 2154 finally: 2155 f.close() 2156 os.chmod(data_file_name, stat.S_IRUSR) 2157 2158 2159def _msfunc_info(net, *inputs): 2160 """Get mindir stream and parameter dict of ms_function""" 2161 # pylint: disable=protected-access 2162 net_dict = OrderedDict() 2163 _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9) 2164 graph_id = _ms_func_executor.compile(net.__name__, *inputs) 2165 mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') 2166 params = _ms_func_executor._graph_executor.get_params(graph_id) 2167 for name, value in params.items(): 2168 net_dict[name] = Parameter(value, name=name) 2169 return mindir_stream, net_dict 2170 2171 2172def _cell_info(net, incremental, *inputs): 2173 """Get mindir stream and net dict of cell""" 2174 phase_name = "export.mindir" 2175 graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) 2176 # pylint: disable=protected-access 2177 mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental) 2178 # clean obfuscation config to prevent the next call 2179 _executor.obfuscate_config = None 2180 2181 net_dict = net.parameters_dict() 2182 return mindir_stream, net_dict 2183 2184 2185def _set_obfuscate_config(**kwargs): 2186 """Set obfuscation config for executor.""" 2187 logger.warning("Obfuscate model.") 2188 if 'enc_mode' in kwargs.keys(): 2189 enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str) 2190 if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]: 2191 raise ValueError( 2192 "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for" 2193 "obfuscation, but got {}.".format(enc_mode)) 2194 obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config')) 2195 if customized_funcs and obf_random_seed > 0: 2196 logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be" 2197 " applied, remember to set 'obf_random_seed' when loading obfuscated model.") 2198 2199 if obf_random_seed == 0: # apply customized_func mode 2200 device_target = context.get_context('device_target') 2201 if device_target in ["GPU", "Ascend"]: 2202 raise ValueError( 2203 "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target)) 2204 clean_funcs() 2205 for func in customized_funcs: 2206 add_opaque_predicate(func.__name__, func) 2207 _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed} 2208 2209 2210def _save_mindir(net, file_name, *inputs, **kwargs): 2211 """Save MindIR format file.""" 2212 # set obfuscate configs 2213 if 'obf_config' in kwargs.keys(): 2214 _set_obfuscate_config(**kwargs) 2215 for item in inputs: 2216 if -1 in item.shape: 2217 raise ValueError( 2218 "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape)) 2219 2220 incremental = kwargs.get('incremental', False) 2221 2222 model = mindir_model() 2223 if not isinstance(net, nn.Cell): 2224 mindir_stream, net_dict = _msfunc_info(net, *inputs) 2225 else: 2226 mindir_stream, net_dict = _cell_info(net, incremental, *inputs) 2227 model.ParseFromString(mindir_stream) 2228 2229 if kwargs.get('dataset'): 2230 check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset) 2231 dataset = kwargs.get('dataset') 2232 _save_dataset_to_mindir(model, dataset) 2233 2234 custom_func = kwargs.get('custom_func', None) 2235 if custom_func is not None: 2236 custom_func(model) 2237 2238 save_together = _save_together(net_dict, model) 2239 is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys() 2240 if save_together: 2241 _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs) 2242 else: 2243 _split_save(net_dict, model, file_name, is_encrypt, **kwargs) 2244 2245 2246def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs): 2247 """Save graph and parameter together.""" 2248 for param_proto in model.graph.parameter: 2249 param_name = param_proto.name[param_proto.name.find(":") + 1:] 2250 if param_name in net_dict.keys(): 2251 param_data = net_dict[param_name].data.get_bytes() 2252 param_proto.raw_data = param_data 2253 else: 2254 raise ValueError("The parameter '{}' is not belongs to any cell," 2255 "the data of parameter cannot be exported.".format(param_proto.name)) 2256 incremental = kwargs.get('incremental', False) 2257 for map_param_proto in model.graph.map_parameter: 2258 map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:] 2259 if map_param_name in net_dict.keys(): 2260 map_parameter = net_dict[map_param_name] 2261 key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental) 2262 map_param_proto.key_tensor.raw_data = key_bytes 2263 map_param_proto.value_tensor.raw_data = value_bytes 2264 map_param_proto.status_tensor.raw_data = status_bytes 2265 else: 2266 raise ValueError("The map_parameter '{}' is not belongs to any cell," 2267 "the data of parameter cannot be exported.".format(map_param_proto.name)) 2268 if not file_name.endswith('.mindir'): 2269 file_name += ".mindir" 2270 current_path = os.path.abspath(file_name) 2271 dirname = os.path.dirname(current_path) 2272 os.makedirs(dirname, exist_ok=True) 2273 if os.path.exists(file_name): 2274 os.chmod(file_name, stat.S_IWUSR) 2275 with open(file_name, 'wb') as f: 2276 os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR) 2277 model_string = model.SerializeToString() 2278 if is_encrypt(): 2279 if callable(kwargs.get('enc_mode')): 2280 enc_func = kwargs.get('enc_mode') 2281 model_string = enc_func(model_string, kwargs.get('enc_key')) 2282 else: 2283 model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'), 2284 len(kwargs.get('enc_key')), kwargs.get('enc_mode')) 2285 f.write(model_string) 2286 os.chmod(file_name, stat.S_IRUSR) 2287 2288 2289def _save_together(net_dict, model): 2290 """Whether graph and parameter save together during save mindir model.""" 2291 data_total = 0 2292 for param_proto in model.graph.parameter: 2293 name = param_proto.name[param_proto.name.find(":") + 1:] 2294 if name in net_dict.keys(): 2295 data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024 2296 else: 2297 raise ValueError("The parameter '{}' is not belongs to any cell," 2298 "the data of parameter cannot be exported.".format(param_proto.name)) 2299 if data_total > TOTAL_SAVE: 2300 return False 2301 return True 2302 2303 2304def _save_dataset_to_mindir(model, dataset): 2305 """Save dataset preprocess operations into mindir model.""" 2306 dataset_json = dataset.to_json() 2307 reverse_dataset = [] 2308 while dataset_json: 2309 reverse_dataset = [dataset_json] + reverse_dataset 2310 if len(dataset_json['children']) > 1: 2311 logger.warning("Need to support dataset_node with more than one child, using child 0 as default.") 2312 dataset_json = dataset_json['children'][0] if dataset_json['children'] else [] 2313 2314 for op in reverse_dataset: 2315 if op['op_type'] == 'Map': 2316 model.preprocessor.op.add() 2317 model.preprocessor.op[-1].input_columns = json.dumps(op['input_columns']) 2318 model.preprocessor.op[-1].output_columns = json.dumps(op['output_columns']) 2319 model.preprocessor.op[-1].op_type = json.dumps(op['op_type']) 2320 model.preprocessor.op[-1].operations = json.dumps(op['operations']) 2321 model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False 2322 2323 2324def check_checkpoint(ckpt_file_name): 2325 """ 2326 Check whether the checkpoint is valid. 2327 2328 Args: 2329 ckpt_file_name (str): Checkpoint file name. 2330 2331 Returns: 2332 bool, whether the checkpoint is valid. 2333 2334 Examples: 2335 >>> import mindspore as ms 2336 >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" 2337 >>> check_result = ms.check_checkpoint(ckpt_file_name) 2338 >>> print(check_result) 2339 True 2340 """ 2341 if not ckpt_file_name.endswith('.ckpt'): 2342 return False 2343 checkpoint_list = Checkpoint() 2344 with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f: 2345 pb_content = f.read() 2346 if pb_content[-17:-10] == b"crc_num": 2347 crc_num_bytes = pb_content[-10:] 2348 pb_content = pb_content[:-17] 2349 crc_num = int.from_bytes(crc_num_bytes, byteorder='big') 2350 cal_crc_num = binascii.crc32(pb_content, 0) 2351 if cal_crc_num != crc_num: 2352 logger.warning("For 'check_checkpoint', the ckpt crc check is failed.") 2353 return False 2354 try: 2355 checkpoint_list.ParseFromString(pb_content) 2356 except google.protobuf.message.DecodeError as e: 2357 logger.warning("For 'check_checkpoint', the ckpt parse is failed.") 2358 logger.warning(e) 2359 return False 2360 return True 2361 2362 2363def parse_print(print_file_name): 2364 """ 2365 Parse data file generated by :class:`mindspore.ops.Print`. 2366 2367 Args: 2368 print_file_name (str): The file name needs to be parsed. 2369 2370 Returns: 2371 List, element of list is Tensor. 2372 2373 Raises: 2374 ValueError: The print file does not exist or is empty. 2375 RuntimeError: Failed to parse the file. 2376 2377 Examples: 2378 >>> import numpy as np 2379 >>> import mindspore as ms 2380 >>> from mindspore import nn, Tensor, ops 2381 >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data') 2382 >>> class PrintInputTensor(nn.Cell): 2383 ... def __init__(self): 2384 ... super().__init__() 2385 ... self.print = ops.Print() 2386 ... 2387 ... def construct(self, input_pra): 2388 ... self.print('print:', input_pra) 2389 ... return input_pra 2390 >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32) 2391 >>> input_pra = Tensor(x) 2392 >>> net = PrintInputTensor() 2393 >>> net(input_pra) 2394 >>> 2395 >>> data = ms.parse_print('./log.data') 2396 >>> print(data) 2397 ['print:', Tensor(shape=[2, 4], dtype=Float32, value= 2398 [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00], 2399 [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])] 2400 """ 2401 print_file_path = os.path.abspath(print_file_name) 2402 2403 if os.path.getsize(print_file_path) == 0: 2404 raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct " 2405 "'print_file_name'.") 2406 2407 logger.info("Execute load print process.") 2408 print_list = Print() 2409 2410 try: 2411 with open(print_file_path, "rb") as f: 2412 pb_content = f.read() 2413 print_list.ParseFromString(pb_content) 2414 except BaseException as e: 2415 logger.critical("Failed to read the print file %s, please check whether the file is " 2416 "correct.", print_file_name) 2417 raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether " 2418 "the file is correct.".format(print_file_name)) from e 2419 2420 tensor_list = [] 2421 2422 try: 2423 for print_ in print_list.value: 2424 # String type 2425 if print_.HasField("desc"): 2426 tensor_list.append(print_.desc) 2427 elif print_.HasField("tensor"): 2428 dims = print_.tensor.dims 2429 data_type = print_.tensor.tensor_type 2430 data = print_.tensor.tensor_content 2431 np_type = tensor_to_np_type.get(data_type) 2432 param_data = np.fromstring(data, np_type) 2433 ms_type = tensor_to_ms_type.get(data_type) 2434 if dims and dims != [0]: 2435 param_value = param_data.reshape(dims) 2436 tensor_list.append(Tensor(param_value, ms_type)) 2437 # Scalar type 2438 else: 2439 data_type_ = data_type.lower() 2440 if 'float' in data_type_: 2441 param_data = float(param_data[0]) 2442 elif 'int' in data_type_: 2443 param_data = int(param_data[0]) 2444 elif 'bool' in data_type_: 2445 param_data = bool(param_data[0]) 2446 tensor_list.append(Tensor(param_data, ms_type)) 2447 2448 except BaseException as e: 2449 logger.critical("Failed to load the print file %s.", print_list) 2450 raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_list)) from e 2451 2452 return tensor_list 2453 2454 2455def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): 2456 """ 2457 Merge data slices to one tensor with whole data when strategy is not None. 2458 2459 Args: 2460 sliced_data (list[numpy.ndarray]): Data slices in order of rank_id. 2461 parameter_name (str): Name of parameter. 2462 strategy (dict): Parameter slice strategy. 2463 is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly. 2464 2465 Returns: 2466 Tensor, the merged Tensor which has the whole data. 2467 2468 Raises: 2469 ValueError: Failed to merge. 2470 """ 2471 layout = strategy.get(parameter_name) 2472 try: 2473 dev_mat = list(layout.dev_matrix[0].dim) 2474 tensor_map = list(layout.tensor_map[0].dim) 2475 param_split_shape = list(layout.param_split_shape[0].dim) 2476 field_size = int(layout.field) 2477 except BaseException as e: 2478 raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'" 2479 f", please make sure that 'strategy' is correct.") from e 2480 2481 device_count = 1 2482 for dim in dev_mat: 2483 device_count *= dim 2484 2485 if len(sliced_data) != device_count: 2486 raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to " 2487 f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but " 2488 f"device_count is {device_count}.") 2489 2490 if not param_split_shape: 2491 if not is_even: 2492 raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' " 2493 "should be the same when slice manner is even.") 2494 2495 all_gather_tensor = Tensor(np.concatenate(sliced_data)) 2496 2497 if field_size > 0: 2498 merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size) 2499 else: 2500 merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map) 2501 2502 else: 2503 tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) 2504 2505 slice_count = 1 2506 for dim in tensor_strategy: 2507 slice_count *= dim 2508 2509 if len(param_split_shape) != slice_count: 2510 raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be " 2511 f"{slice_count}, but got {len(param_split_shape)}.") 2512 2513 tensor_slices_new = list(range(slice_count)) 2514 tensor_slices = sliced_data 2515 for i in range(device_count): 2516 slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)) 2517 if tensor_slices[i].shape[0] != param_split_shape[slice_index]: 2518 raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be " 2519 f"{param_split_shape[slice_index]} in 0 axis, but got " 2520 f"{tensor_slices[i].shape[0]}.") 2521 tensor_slices_new[slice_index] = np.array(tensor_slices[i]) 2522 2523 dim_len = len(tensor_strategy) 2524 for i in range(dim_len): 2525 ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i]) 2526 tensor_slices_new_inner = [] 2527 for j in range(ele_count): 2528 new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]] 2529 for k in range(j * tensor_strategy[dim_len - 1 - i] + 1, 2530 (j + 1) * tensor_strategy[dim_len - 1 - i]): 2531 new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i) 2532 tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor)) 2533 tensor_slices_new = tensor_slices_new_inner 2534 merged_tensor = Tensor(tensor_slices_new[0]) 2535 2536 return merged_tensor 2537 2538 2539def restore_group_info_list(group_info_file_name): 2540 """ 2541 Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank 2542 who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL 2543 environment variables like "export GROUP_INFO_FILE=/data/group_info.pb". 2544 2545 Args: 2546 group_info_file_name (str): Name of group information file. 2547 2548 Returns: 2549 List, the rank list. 2550 2551 Raises: 2552 ValueError: group information file is incorrect. 2553 TypeError: `group_info_file_name` is not str. 2554 2555 Examples: 2556 >>> import mindspore as ms 2557 >>> ms.restore_list = restore_group_info_list("./group_info.pb") 2558 """ 2559 if not isinstance(group_info_file_name, str): 2560 raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, " 2561 f"but got {type(group_info_file_name)}.") 2562 2563 if not os.path.isfile(group_info_file_name): 2564 raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.") 2565 2566 if os.path.getsize(group_info_file_name) == 0: 2567 raise ValueError("For 'restore_group_info_list', the group information file should not be empty.") 2568 2569 return _restore_group_info_list(group_info_file_name) 2570 2571 2572def build_searched_strategy(strategy_filename): 2573 """ 2574 Build strategy of every parameter in network. Used in the case of distributed inference. 2575 2576 Args: 2577 strategy_filename (str): Name of strategy file. 2578 2579 Returns: 2580 Dict, whose key is parameter name and value is slice strategy of this parameter. 2581 2582 Raises: 2583 ValueError: Strategy file is incorrect. 2584 TypeError: `strategy_filename` is not a string. 2585 2586 Examples: 2587 >>> import mindspore as ms 2588 >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt") 2589 """ 2590 return _build_searched_strategy(strategy_filename) 2591 2592 2593def merge_sliced_parameter(sliced_parameters, strategy=None): 2594 """ 2595 Merge parameter slices into one parameter. Used in the case of distributed inference. 2596 2597 Args: 2598 sliced_parameters (list[Parameter]): Parameter slices in order of rank id. 2599 strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and 2600 value is slice strategy of this parameter. If strategy is None, just merge 2601 parameter slices in 0 axis order. Default: ``None``. 2602 2603 Returns: 2604 Parameter, the merged parameter which has the whole data. 2605 2606 Raises: 2607 ValueError: Failed to merge. 2608 TypeError: The sliced_parameters is incorrect or strategy is not dict. 2609 KeyError: The parameter name is not in keys of strategy. 2610 2611 Examples: 2612 >>> import numpy as np 2613 >>> import mindspore as ms 2614 >>> from mindspore import Tensor, Parameter 2615 >>> 2616 >>> sliced_parameters = [ 2617 ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), 2618 ... "network.embedding_table"), 2619 ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), 2620 ... "network.embedding_table"), 2621 ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), 2622 ... "network.embedding_table"), 2623 ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), 2624 ... "network.embedding_table")] 2625 >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters) 2626 >>> print(merged_parameter) 2627 Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True) 2628 """ 2629 if not isinstance(sliced_parameters, list): 2630 raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, " 2631 f"but got {type(sliced_parameters)}.") 2632 2633 if not sliced_parameters: 2634 raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.") 2635 2636 if strategy and not isinstance(strategy, dict): 2637 raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, " 2638 f"but got {type(strategy)}.") 2639 2640 try: 2641 parameter_name = sliced_parameters[0].name 2642 parameter_shape = sliced_parameters[0].data.shape 2643 parameter_shape_length = len(parameter_shape) 2644 except BaseException as e: 2645 raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be " 2646 f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e 2647 2648 is_even = True 2649 for index, parameter in enumerate(sliced_parameters): 2650 if not isinstance(parameter, Parameter): 2651 raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', " 2652 f"but got {type(parameter)} at index {index}.") 2653 2654 if parameter.name != parameter_name \ 2655 or len(parameter.data.shape) != parameter_shape_length \ 2656 or parameter.data.shape[1:] != parameter_shape[1:]: 2657 raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'" 2658 f" have the same name, dimension length and shape except 0 axis. The name, dimension " 2659 f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, " 2660 f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: " 2661 f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} " 2662 f"at index {index}.") 2663 2664 if parameter.data.shape != parameter_shape: 2665 is_even = False 2666 2667 layerwise_parallel = sliced_parameters[0].layerwise_parallel 2668 requires_grad = sliced_parameters[0].requires_grad 2669 sliced_data = [] 2670 for parameter in sliced_parameters: 2671 if parameter.data.dtype == mstype.bfloat16: 2672 sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy()) 2673 else: 2674 sliced_data.append(parameter.data.asnumpy()) 2675 2676 if not strategy: 2677 merged_tensor = Tensor(np.concatenate(sliced_data)) 2678 merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) 2679 2680 else: 2681 if parameter_name not in strategy.keys(): 2682 raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in " 2683 f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.") 2684 merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even) 2685 merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) 2686 2687 return merged_parameter 2688 2689 2690def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, 2691 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'): 2692 """ 2693 Load checkpoint into net for distributed predication. Used in the case of distributed inference. 2694 2695 Args: 2696 network (Cell): Network for distributed predication. 2697 checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. 2698 predict_strategy (dict): Strategy of predication process. It means that using one device to predict 2699 when setting predict_strategy as None. Default: ``None`` . 2700 train_strategy_filename (str): The filename of training strategy protocol buffer file. 2701 When train_strategy_filename is None, the training strategy file will be 2702 obtained from context.get_auto_parallel_context("strategy_ckpt_load_file"). 2703 Therefore, the training strategy file needs to be specified 2704 in at least one of them. Default: ``None`` . 2705 strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter 2706 into net when parameter name's suffix in checkpoint file is the same as the 2707 parameter in the network. When the types are inconsistent, perform type conversion 2708 on the parameters of the same type, such as float32 to float16. Default: ``False`` . 2709 dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption 2710 is not required. Default: ``None`` . 2711 dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption 2712 mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` . 2713 Default: ``'AES-GCM'`` . 2714 2715 Raises: 2716 TypeError: The type of inputs do not match the requirements. 2717 ValueError: Failed to load checkpoint into net. 2718 2719 Supported Platforms: 2720 ``Ascend`` ``GPU`` 2721 2722 Examples: 2723 .. note:: 2724 Before running the following examples, you need to configure the communication environment variables. 2725 2726 For the Ascend devices, users need to prepare the rank table, set rank_id and device_id. 2727 Please see the `rank table startup 2728 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_ 2729 for more details. 2730 2731 For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup 2732 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ . 2733 2734 For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster 2735 Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ . 2736 2737 >>> import os 2738 >>> import numpy as np 2739 >>> import mindspore as ms 2740 >>> import mindspore.dataset as ds 2741 >>> from mindspore import nn, ops, train 2742 >>> from mindspore.communication import init 2743 >>> 2744 >>> step_per_epoch = 4 2745 >>> device_num = 8 2746 >>> 2747 >>> # Define the network structure. 2748 >>> class Net(nn.Cell): 2749 ... def __init__(self, matmul_size, strategy=None): 2750 ... super().__init__() 2751 ... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32) 2752 ... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np)) 2753 ... self.matmul = ops.MatMul() 2754 ... self.neg = ops.Neg() 2755 ... if strategy is not None: 2756 ... self.matmul.shard(strategy) 2757 ... 2758 ... def construct(self, inputs): 2759 ... x = self.matmul(inputs, self.matmul_weight) 2760 ... x = self.neg(x) 2761 ... return x 2762 >>> 2763 >>> # Create dataset. 2764 >>> def get_dataset(*inputs): 2765 ... def generate(): 2766 ... for _ in range(step_per_epoch): 2767 ... yield inputs 2768 ... return generate 2769 >>> 2770 >>> # Train network and save distributed checkpoint. 2771 >>> def train_net(): 2772 ... ms.set_context(mode=ms.GRAPH_MODE) 2773 ... init() 2774 ... np.random.seed(1) 2775 ... input_data = np.random.rand(16, 96).astype(np.float32) 2776 ... label_data = np.random.rand(16, 16).astype(np.float32) 2777 ... fake_dataset = get_dataset(input_data, label_data) 2778 ... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"]) 2779 ... 2780 ... # Set parallel strategy. 2781 ... strategy = ((1, 4), (4, 1)) 2782 ... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num, 2783 ... strategy_ckpt_save_file="./train_strategy.ckpt") 2784 ... network = Net(matmul_size=(96, 16), strategy=strategy) 2785 ... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) 2786 ... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") 2787 ... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt) 2788 ... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False) 2789 ... global_rank_id = int(os.getenv("RANK_ID")) 2790 ... ckpt_path = "./rank_{}_ckpt".format(global_rank_id) 2791 ... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config) 2792 ... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False) 2793 ... ms.reset_auto_parallel_context() 2794 >>> 2795 >>> # Load distributed checkpoint and test. 2796 >>> def load_model(): 2797 ... ms.set_context(mode=ms.GRAPH_MODE) 2798 ... init() 2799 ... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel", 2800 ... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num) 2801 ... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32)) 2802 ... network = Net(matmul_size=(96, 16)) 2803 ... model = ms.Model(network) 2804 ... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data)) 2805 ... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)] 2806 ... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout) 2807 ... predict_result = model.predict(predict_data) 2808 ... print(predict_result) 2809 >>> 2810 >>> train_net() 2811 >>> load_model() 2812 [[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ] 2813 [ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887] 2814 ... 2815 [ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]] 2816 """ 2817 network = Validator.check_isinstance("network", network, nn.Cell) 2818 _check_checkpoint_file(checkpoint_filenames) 2819 _check_predict_strategy(predict_strategy) 2820 2821 dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) 2822 dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) 2823 2824 if train_strategy_filename is None: 2825 train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") 2826 _train_strategy = build_searched_strategy(train_strategy_filename) 2827 train_strategy = _convert_to_list(_train_strategy) 2828 2829 train_dev_count = 1 2830 ckpt_file_len = len(checkpoint_filenames) 2831 for dim in train_strategy[list(train_strategy.keys())[0]][0]: 2832 train_dev_count *= dim 2833 if train_dev_count != ckpt_file_len: 2834 raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be " 2835 f"equal to the device count of training process. " 2836 f"But got the length of 'checkpoint_filenames'" 2837 f" is {ckpt_file_len} and the device count is {train_dev_count}.") 2838 rank_list = _infer_rank_list(train_strategy, predict_strategy) 2839 2840 param_total_dict = defaultdict(dict) 2841 for file_index, file_name in enumerate(checkpoint_filenames): 2842 ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode) 2843 for param_name, param in ckpt_dict.items(): 2844 param_total_dict[param_name][file_index] = param 2845 2846 param_dict = {} 2847 param_not_in_strategy = [] 2848 param_not_in_ckpt = [] 2849 for _, param in network.parameters_and_names(): 2850 sliced_params = [] 2851 if param.name not in rank_list.keys(): 2852 param_not_in_strategy.append(param.name) 2853 continue 2854 if param.name not in param_total_dict: 2855 param_not_in_ckpt.append(param.name) 2856 continue 2857 2858 param_rank = rank_list.get(param.name)[0] 2859 skip_merge_split = rank_list.get(param.name)[1] 2860 shard_stride = train_strategy.get(param.name)[4] 2861 if train_strategy.get(param.name)[5]: 2862 shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] 2863 else: 2864 shard_size = 0 2865 for rank in param_rank: 2866 param_total_list = list(range(0, ckpt_file_len)) 2867 if shard_size > 0: 2868 shard_total_list = [] 2869 for i in range(0, ckpt_file_len, shard_size): 2870 shard_total_list.append(param_total_list[i:i + shard_size]) 2871 param_total_list = shard_total_list[rank // shard_size] 2872 if shard_stride > 0: 2873 param_stride = [] 2874 # merge pre parameter 2875 param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride] 2876 param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride]) 2877 param_index = list(set(param_index)) 2878 param_index.sort() 2879 for rank_num in param_index: 2880 if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16: 2881 param_stride.append( 2882 cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy()) 2883 else: 2884 param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy()) 2885 2886 sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name) 2887 else: 2888 sliced_param = param_total_dict[param.name][rank] 2889 2890 sliced_params.append(sliced_param) 2891 if skip_merge_split: 2892 split_param = sliced_params[0] 2893 else: 2894 param_unique_strategy = _remove_repeated_slices(train_strategy[param.name]) 2895 _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy) 2896 split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy) 2897 opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None 2898 if opt_shard_group: 2899 if split_param.data.dtype == mstype.bfloat16: 2900 data = cpu_cast(split_param.data, mstype.float32).asnumpy() 2901 else: 2902 data = split_param.data.asnumpy() 2903 rank = get_rank(opt_shard_group) 2904 size = get_group_size(opt_shard_group) 2905 try: 2906 data_slice = np.split(data, size)[rank] 2907 except BaseException as e: 2908 logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}" 2909 " and group is {}".format(param.name, split_param.data.shape, opt_shard_group)) 2910 raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice" 2911 f" in load distributed checkpoint for {param.name}. Data shape is " 2912 f"{split_param.data.shape} and group is {opt_shard_group}.") from e 2913 split_param = Parameter(Tensor(data_slice), param.name, 2914 split_param.requires_grad, split_param.layerwise_parallel) 2915 param_dict[param.name] = split_param 2916 2917 if param_not_in_strategy: 2918 logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, " 2919 "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct." 2920 .format(param_not_in_strategy)) 2921 if param_not_in_ckpt: 2922 logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in " 2923 "the checkpoint file, please check whether 'checkpoint_filenames' is correct." 2924 .format(param_not_in_ckpt)) 2925 2926 load_param_into_net(network, param_dict, strict_load=strict_load) 2927 2928 2929def async_ckpt_thread_status(): 2930 """ 2931 Get the status of asynchronous save checkpoint thread. 2932 2933 When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed. 2934 2935 Returns: 2936 bool, True, Asynchronous save checkpoint thread is running. 2937 False, Asynchronous save checkpoint thread is not executing. 2938 2939 Examples: 2940 >>> import mindspore as ms 2941 >>> ms.async_ckpt_thread_status() 2942 False 2943 """ 2944 thr_list = threading.enumerate() 2945 return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list] 2946 2947 2948def _check_predict_strategy(predict_strategy): 2949 """Check predict strategy.""" 2950 2951 def _check_int_list(arg): 2952 if not isinstance(arg, list): 2953 return False 2954 for item in arg: 2955 if not isinstance(item, int): 2956 return False 2957 return True 2958 2959 if predict_strategy is None: 2960 return 2961 2962 flag = True 2963 predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) 2964 for key in predict_strategy.keys(): 2965 if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ 2966 or len(predict_strategy[key]) < 4: 2967 flag = False 2968 dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] 2969 if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ 2970 not (_check_int_list(param_split_shape) or not param_split_shape) or \ 2971 not (isinstance(field_size, int) and field_size == 0): 2972 flag = False 2973 2974 if not flag: 2975 raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, " 2976 f"the key of it must be string, and the value of it must be list or tuple that " 2977 f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), " 2978 f"param_split_shape (list[int]) and field_size (int, which value is 0)." 2979 f"Please check whether 'predict_strategy' is correct.") 2980 2981 2982def _check_checkpoint_file(checkpoint_filenames): 2983 """Check checkpoint file name.""" 2984 for index, filename in enumerate(checkpoint_filenames): 2985 if not isinstance(filename, str) or not os.path.exists(filename) \ 2986 or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: 2987 raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and " 2988 f"make sure the {filename} at index {index} is a valid checkpoint file, it must " 2989 f"be a string ending with '.ckpt', and the checkpoint file it represents must " 2990 f"be exist and not empty.") 2991 2992 2993def _merge_and_split(sliced_params, train_strategy, predict_strategy): 2994 """Merge sliced parameter and split it according to the predict strategy.""" 2995 merged_param = merge_sliced_parameter(sliced_params, train_strategy) 2996 if predict_strategy is None: 2997 return merged_param 2998 param_name = merged_param.name 2999 tensor_layout = predict_strategy[param_name] 3000 rank = get_rank() 3001 split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank) 3002 requires_grad = merged_param.requires_grad 3003 layerwise_parallel = merged_param.layerwise_parallel 3004 if merged_param.data.dtype == mstype.bfloat16: 3005 split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel) 3006 else: 3007 split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) 3008 return split_param 3009 3010 3011def _calculation_net_size(net): 3012 """Calculate the size of parameters in the network.""" 3013 data_total = 0 3014 net_dict = net.parameters_dict() 3015 for name in net_dict: 3016 data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024 3017 3018 return data_total 3019 3020 3021def _get_mindir_inputs(file_name): 3022 """ 3023 Get MindIR file's inputs. 3024 3025 Note: 3026 1. Parsing encrypted MindIR file is not supported. 3027 2. Parsing dynamic shape MindIR file is not supported. 3028 3029 Args: 3030 file_name (str): MindIR file name. 3031 3032 Returns: 3033 Tensor, list(Tensor), the input of MindIR file. 3034 3035 Raises: 3036 TypeError: If the parameter file_name is not `str`. 3037 RuntimeError: MindIR's input is not tensor type or has no dims. 3038 3039 Examples: 3040 >>> input_tensor = get_mindir_inputs("lenet.mindir") 3041 """ 3042 Validator.check_file_name_by_regular(file_name) 3043 file_name = os.path.abspath(file_name) 3044 model = read_proto(file_name) 3045 input_tensor = [] 3046 3047 for ele_input in model.graph.input: 3048 input_shape = [] 3049 if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"): 3050 raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.") 3051 3052 for ele_shape in ele_input.tensor[0].dims: 3053 input_shape.append(ele_shape) 3054 if is_shape_unknown(input_shape): 3055 raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.") 3056 3057 mindir_type = ele_input.tensor[0].data_type 3058 if mindir_type not in mindir_to_tensor_type: 3059 raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.") 3060 3061 input_type = mindir_to_tensor_type.get(mindir_type) 3062 input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One())) 3063 3064 if not input_tensor: 3065 logger.warning("The MindIR model has no input, return None.") 3066 return None 3067 return input_tensor[0] if len(input_tensor) == 1 else input_tensor 3068 3069 3070def convert_model(mindir_file, convert_file, file_format): 3071 """ 3072 Convert mindir model to other format model. The current version only supports conversion to ONNX models. 3073 3074 .. warning:: 3075 This is an experimental API that is subject to change or deletion. 3076 3077 Args: 3078 mindir_file (str): MindIR file name. 3079 convert_file (str): Convert model file name. 3080 file_format (str): Convert model's format, current version only supports "ONNX". 3081 3082 Raises: 3083 TypeError: If the parameter `mindir_file` is not `str`. 3084 TypeError: If the parameter `convert_file` is not `str`. 3085 ValueError: If the parameter `file_format` is not "ONNX". 3086 3087 Examples: 3088 >>> import mindspore as ms 3089 >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX") 3090 """ 3091 Validator.check_file_name_by_regular(mindir_file) 3092 Validator.check_file_name_by_regular(convert_file) 3093 if file_format != "ONNX": 3094 raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.") 3095 net_input = _get_mindir_inputs(mindir_file) 3096 graph = load(mindir_file) 3097 net = nn.GraphCell(graph) 3098 if isinstance(net_input, Tensor): 3099 export(net, net_input, file_name=convert_file, file_format=file_format) 3100 else: 3101 export(net, *net_input, file_name=convert_file, file_format=file_format) 3102