1# Copyright 2020-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Train utility.""" 16from __future__ import absolute_import 17 18import os 19import json 20from collections.abc import Iterable 21 22import numpy as np 23 24from mindspore.common.tensor import Tensor 25from mindspore._c_expression import Tensor as Tensor_ 26from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype 27from mindspore.common import dtype as mstype 28from mindspore import log as logger 29from mindspore import _checkparam as Validator 30from mindspore.common.api import _cell_graph_executor 31from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model 32from mindspore.train.checkpoint_pb2 import Checkpoint 33from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy 34from mindspore.train.lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo 35from mindspore.parallel._parallel_serialization import _make_dir 36from mindspore.ops.operations import debug_ops 37 38 39def _convert_type(types): 40 """ 41 Convert from numpy type to tensor type. 42 43 Args: 44 types (list): Numpy type list of element in dataset. 45 46 Returns: 47 list, list of element in dataset. 48 """ 49 ms_types = [] 50 for np_type in types: 51 ms_type = pytype_to_dtype(np_type) 52 ms_types.append(ms_type) 53 return ms_types 54 55 56def _get_types_and_shapes(dataset): 57 """Get dataset types and shapes.""" 58 dataset_types = _convert_type(dataset.output_types()) 59 dataset_shapes = dataset.output_shapes() 60 return dataset_types, dataset_shapes 61 62 63def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False): 64 """Initialize and execute the dataset graph.""" 65 batch_size = exec_dataset.get_batch_size() 66 input_indexs = exec_dataset.input_indexs 67 68 # transform data format 69 dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) 70 send_epoch_end = bool(dataset_size == -1) 71 queue_name = _cell_graph_executor.get_queue_name(phase) 72 if queue_name is None: 73 queue_name = str("") 74 exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, 75 create_data_info_queue=create_data_info_queue, queue_name=queue_name) 76 _cell_graph_executor.init_dataset(exec_dataset.queue_name, 77 dataset_size, 78 batch_size, 79 dataset_types, 80 dataset_shapes, 81 input_indexs, 82 phase=phase) 83 return exec_dataset 84 85 86def _make_directory(path, arg_name='path'): 87 """Make directory.""" 88 return _make_dir(path, arg_name) 89 90 91def _construct_tensor_list(types, shapes, batch_expand_num=1): 92 """ 93 Construct list of tensors with types and shapes, used to initialize the network. 94 95 Args: 96 types: List or Tuple. The output types of element in dataset. 97 shapes: List or Tuple. The output shapes of element in dataset. 98 batch_expand_num (int): Batch expand number. 99 100 Returns: 101 List, list of Tensors. 102 """ 103 if len(types) != len(shapes): 104 raise ValueError("The length of dataset types must be equal to dataset shapes, " 105 "but got dataset types={} and dataset shapes={}".format(types, shapes)) 106 tensor_list = [] 107 for type_, shape in zip(types, shapes): 108 new_shape = () 109 for i, item in enumerate(shape): 110 if i == 0: 111 new_shape += (item * batch_expand_num,) 112 else: 113 new_shape += (item,) 114 tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)), dtype=type_) 115 tensor.virtual_flag = True 116 tensor_list.append(tensor) 117 return tensor_list 118 119 120def _to_tensor(elem, scaling_sens=None): 121 """Convert numpy to tensor, adapt to feed the data from host solution.""" 122 lst = [] 123 if not isinstance(elem, (tuple, list)): 124 elem = [elem] 125 for data in elem: 126 if not isinstance(data, np.ndarray): 127 if scaling_sens: 128 elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),) 129 else: 130 elem_tuple = tuple(elem) 131 return elem_tuple 132 lst.append(Tensor(data)) 133 if scaling_sens: 134 lst.append(Tensor(scaling_sens, mstype.float32)) 135 136 return lst[0] if len(lst) == 1 else tuple(lst) 137 138 139def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1): 140 """Construct tensor list to initialize the network which implemented in dataset sink.""" 141 tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1) 142 tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number) 143 return tensor_list_run, tensor_list_compile 144 145 146def _check_to_numpy(plugin, tensor, prim=None): 147 """Check the tensor and return a numpy.ndarray.""" 148 np_value = tensor.asnumpy() 149 np_value = np_value.copy() 150 summary_name = plugin.capitalize() + "Summary" if prim else "SummaryRecord" 151 if plugin == 'scalar': 152 if np_value.size == 1: 153 return np_value 154 raise ValueError( 155 f'For "{summary_name}", the v rank must be less than or equal to 1, but got {len(np_value)}.') 156 if plugin == 'image': 157 if np_value.ndim == 4: 158 return np_value 159 raise ValueError(f'For "{summary_name}", the tensor seems not to hold a valid image.') 160 if plugin in ('tensor', 'histogram'): 161 if np_value.ndim > 0: 162 return np_value 163 raise ValueError(f'For "{summary_name}", the value should not be empty.') 164 return np_value 165 166 167def check_summary_param(summary_name, tag, tensor): 168 """Checks the tag is valid for summary.""" 169 plugin = summary_name.split('Summary')[0].lower() 170 try: 171 if not isinstance(tag, str) or not tag: 172 raise TypeError(f'For "{summary_name}", the name must be valid string, but got "{tag}".') 173 if not isinstance(tensor, (Tensor, Tensor_)): 174 raise TypeError(f'For "{summary_name}", the parameter "value" expect to be Tensor, ' 175 f'but got {type(tensor).__name__}') 176 _check_to_numpy(plugin, tensor, prim=True) 177 except TypeError as err: 178 raise TypeError(err) from err 179 except ValueError as err: 180 raise ValueError(err) from err 181 finally: 182 debug_ops.SUMMARY_TENSOR_CACHE = [] 183 184 185def _check_lineage_value(plugin, value): 186 """Check the lineage value.""" 187 188 def raises(plugin, prototype): 189 raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.') 190 191 if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph): 192 raises(plugin, DatasetGraph) 193 194 if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage): 195 raises(plugin, EvaluationLineage) 196 197 if plugin == 'train_lineage' and not isinstance(value, TrainLineage): 198 raises(plugin, TrainLineage) 199 200 if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo): 201 raises(plugin, UserDefinedInfo) 202 203 204def check_value_type(arg_name, arg_value, valid_types): 205 """Checks whether a value is instance of some types.""" 206 valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,) 207 is_valid = True 208 209 # bool is subclass of int, so for a bool value, we need to extra check 210 if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types: 211 is_valid = False 212 213 if not isinstance(arg_value, valid_types): 214 is_valid = False 215 216 if not is_valid: 217 raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, ' 218 f'but got {type(arg_value).__name__}.') 219 220 221def read_proto(file_name, proto_format="MINDIR", display_data=False): 222 """ 223 Read protobuf file. 224 225 Args: 226 file_name (str): File name. 227 proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR. 228 display_data (bool): Whether display data. Default: ``False``. 229 230 Returns: 231 Object, proto object. 232 """ 233 Validator.check_file_name_by_regular(file_name) 234 file_name = os.path.realpath(file_name) 235 if proto_format == "MINDIR": 236 model = mindir_model() 237 elif proto_format == "CKPT": 238 model = Checkpoint() 239 elif proto_format == "CKPT_STRATEGY": 240 model = ckpt_strategy() 241 else: 242 raise ValueError("Unsupported proto format.") 243 244 try: 245 with open(file_name, "rb") as f: 246 pb_content = f.read() 247 model.ParseFromString(pb_content) 248 except BaseException as e: 249 logger.critical(f"Failed to phase the file: {file_name} as format: {proto_format}," 250 f" please check the correct file and format.") 251 raise ValueError(e.__str__()) from e 252 finally: 253 pass 254 255 if proto_format == "MINDIR" and not display_data: 256 for param_proto in model.graph.parameter: 257 param_proto.raw_data = b'\0' 258 259 if proto_format == "CKPT" and not display_data: 260 for element in model.value: 261 if element.tensor.ByteSize() != 0: 262 element.tensor.tensor_content = b'\0' 263 else: 264 for ele in element.maptensor.tensor: 265 ele.tensor_content = b'\0' 266 267 return model 268 269 270def parse_strategy_ckpt(file_name): 271 """ 272 Parses a strategy ckpt layout file and returns the rank location dict. 273 274 Args: 275 file_name (str):Strategy ckpt file name. 276 277 Returns: 278 Dict, layout dict. Key is parameter name, value is (dev_matrix, tensor_map). 279 280 Examples: 281 >>> from mindspore.train.utils import parse_strategy_ckpt 282 >>> layout_dict = parse_strategy_ckpt("/path/to/strategy.ckpt") 283 {"param1": [[4, 4], [0, -1]], "param2": [[4, 4], [-1, 0]],,,,} 284 """ 285 model = ckpt_strategy() 286 with open(file_name, "rb") as f: 287 pb_content = f.read() 288 model.ParseFromString(pb_content) 289 layout_dict = {} 290 for param in model.parallel_layout_item: 291 dev_matrix = [] 292 tensor_map = [] 293 for ele in param.parallel_layouts.dev_matrix[0].ListFields()[0][1]: 294 dev_matrix.append(ele) 295 296 for ele in param.parallel_layouts.tensor_map[0].ListFields()[0][1]: 297 tensor_map.append(ele) 298 layout_dict[param.param_name] = [dev_matrix, tensor_map] 299 return layout_dict 300 301 302def get_parameter_redundancy(layout_obj, initial_rank=0): 303 """ 304 Get parameter redundancy map. 305 306 Args: 307 layout_obj (Union[str, layout): File name of `strategy.ckpt` or net.parameter_layout_dict. 308 initial_rank (int): Start rank id for each pipeline. Default: 0. 309 310 Returns: 311 Dict, dict of parameter redundancy info. 312 313 Examples: 314 >>> from mindspore.train.utils import get_parameter_redundancy 315 >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt") 316 {'param1': ((0, 1, 2, 3, 4, 5, 6, 7),), 317 'param2': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)), 318 'param3': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)), 319 'param4': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15))} 320 """ 321 if isinstance(layout_obj, str): 322 parameter_layout = parse_strategy_ckpt(layout_obj) 323 else: 324 parameter_layout = {} 325 for k, v in layout_obj.items(): 326 parameter_layout[k] = v[:2] 327 328 param_redundancy_dict = {} 329 for key, (slices, deploy_loc, *_) in parameter_layout.items(): 330 redundancy_matrix = np.zeros(shape=slices + [len(slices)], dtype=np.int8) 331 for i in deploy_loc: 332 internal_slice = tuple(slice(None) for _ in range(i)) 333 for j in range(slices[-i - 1]): 334 if i == -1: 335 continue 336 else: 337 redundancy_matrix[(..., j) + internal_slice + (i,)] = j 338 locate_list = redundancy_matrix.reshape((-1, len(slices))).tolist() 339 redundancy_dict = {} 340 for index, locate in enumerate(locate_list): 341 redundancy_dict.setdefault(tuple(locate), []).append(index+initial_rank) 342 redundancy_list = [] 343 for _, indices in sorted(redundancy_dict.items()): 344 redundancy_list.append(tuple(indices)) 345 346 param_redundancy_dict[key] = tuple(redundancy_list) 347 return param_redundancy_dict 348 349 350def _collect_settings_by_rank(redundancy_map): 351 """ 352 Collect parameter redundancy map by rank id. 353 354 {"param1":((1,3,5,7),(2,4,6,8)),"param2":((1,3,5,7),(2,4,6,8))} 355 ->{(1,3,5,7):{"param1", "param2"},(2,4,6,8):{"param1", "param2"}} 356 """ 357 redundancy_map_reversed = {} 358 for key, redundancy in redundancy_map.items(): 359 for index, item in enumerate(redundancy): 360 redundancy_map_reversed.setdefault(item, []).append( 361 (key, index)) 362 return redundancy_map_reversed 363 364 365def _restructure(input_dict): 366 """ 367 Flatten and reorganize the nested dictionary structure.""" 368 if all(not isinstance(item, tuple) for item in input_dict): 369 return input_dict 370 res_dict = {} 371 for key, values in input_dict.items(): 372 for index, value in enumerate(values): 373 res_dict.setdefault(key[index % len(key)], []).append(value) 374 return _restructure(res_dict) 375 376 377def _rotate_list_elements(i, input_list): 378 """Rotate element list.""" 379 rotated_list = [input_list[(i + j) % len(input_list)] for j in 380 range(len(input_list))] 381 return rotated_list 382 383 384def remove_param_redundancy(param_redundancy_dict, keep_redundancy=1): 385 """ 386 Remove parameter redundancy, get the single parameter for each rank id. 387 Args: 388 param_redundancy_dict (Dict): Parameter redundancy dict. 389 keep_redundancy (Int): Keep redundancy number. 390 391 Returns: 392 Dict, single parameter for each rank id. Key is rank_id, value is set(params). 393 394 Examples: 395 >>> from mindspore.train.utils import get_parameter_redundancy, remove_param_redundancy 396 >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt") 397 >>> single_parameter = remove_param_redundancy(param_redundancy_dict) 398 {0: {param1, param3}, 1: {param2, param4},,,}} 399 """ 400 redundancy_dict_reversed = _collect_settings_by_rank(param_redundancy_dict) 401 sorted_layouts = {} 402 for device_layout, layer_names_list in redundancy_dict_reversed.items(): 403 sorted_layer_names = [item[0] for item in layer_names_list] 404 sorted_layouts[device_layout] = sorted_layer_names 405 result = {} 406 for i in range(keep_redundancy): 407 rotated_layouts = {tuple(_rotate_list_elements(i, key)): value for 408 key, value in sorted_layouts.items()} 409 restructured_layouts = _restructure(rotated_layouts) 410 for key, value in restructured_layouts.items(): 411 result.setdefault(key, set()).update(set(value)) 412 return result 413 414 415def parse_hccl_file(hccl_file_path): 416 """ 417 Parses an HCCL configuration JSON file, return a dict key is rank_id, value is device_ip. 418 419 Args: 420 hccl_file_path (str): The path to the HCCL configuration JSON file. 421 422 Returns: 423 Dict: A Dict, key is rank_id, value is device_ip. 424 425 Examples: 426 >>> from mindspore.train.utils import parse_hccl_file 427 >>> rankid_dict = parse_hccl_file("/path/to/hccl.json") 428 {0: "10.11.10.163", 1: "10.11.10.164", 2: "10.11.10.165", 3: "10.11.10.166",,,,} 429 """ 430 with open(hccl_file_path) as f: 431 hccl_dict = json.load(f) 432 server_list = hccl_dict["server_list"] 433 rankid_dict = {} 434 for server in server_list: 435 device_list = server["device"] 436 for device in device_list: 437 rankid_dict[int(device["rank_id"])] = device["device_ip"] 438 439 return rankid_dict 440