1# Copyright 2023-2024 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Utils of auto parallel""" 16from importlib import import_module 17import numpy as np 18import mindspore as ms 19from mindspore import context, log as logger 20from mindspore._c_expression import reset_op_id, reset_op_id_with_offset 21from mindspore.common.tensor import Tensor 22from mindspore.common.dtype import dtype_to_nptype 23from mindspore.common import dtype as mstype 24from mindspore.communication.management import get_group_size, get_rank 25from mindspore.parallel._auto_parallel_context import auto_parallel_context 26from mindspore.common.seed import get_seed 27from mindspore._c_expression import GraphExecutor_ 28from mindspore.parallel._tensor import _load_tensor_by_layout 29 30SUPPORTED_TUPLE_IN_TUPLE_STRATEGY = ["GroupedMatmul", "FusedInferAttentionScore"] 31 32 33def _get_parallel_mode(): 34 """Get parallel mode.""" 35 return auto_parallel_context().get_parallel_mode() 36 37 38def _is_sharding_propagation(): 39 """Is sharding propagation.""" 40 return (auto_parallel_context().get_strategy_search_mode() == "sharding_propagation") or ( 41 auto_parallel_context().get_sharding_propagation()) 42 43 44def _is_in_auto_parallel_mode(): 45 return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL] 46 47 48def _is_in_data_parallel_mode(): 49 return _get_parallel_mode() == ms.ParallelMode.DATA_PARALLEL 50 51 52def _is_in_hybrid_parallel_mode(): 53 return _get_parallel_mode() == ms.ParallelMode.HYBRID_PARALLEL 54 55 56def _is_pynative_parallel(): 57 parallel_mode = context.get_auto_parallel_context('parallel_mode') 58 return context.get_context('mode') == context.PYNATIVE_MODE and parallel_mode in ( 59 context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL) 60 61 62def _get_full_batch(): 63 """Get whether to use full_batch.""" 64 return auto_parallel_context().get_full_batch() 65 66 67def _get_pipeline_stages(): 68 """Get pipeline stages""" 69 return auto_parallel_context().get_pipeline_stages() 70 71 72def _check_full_batch(): 73 """ 74 full_batch could only be used under semi_auto_parallel or auto_parallel, check it. 75 76 Raises: 77 RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel. 78 """ 79 parallel_mode = _get_parallel_mode() 80 full_batch = _get_full_batch() 81 if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch): 82 raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.") 83 84 85def _need_to_full(): 86 """Check whether to convert input to full shape or tensor.""" 87 if _get_parallel_mode() not in ("semi_auto_parallel", "auto_parallel"): 88 return False 89 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 90 if dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"): 91 return True 92 return not _get_full_batch() 93 94 95def _slice_parameter(parameter, phase, layout): 96 """Slice python parameter obj according to the layout.""" 97 is_train_phase = phase.startswith('train') 98 is_prefill_phase = phase.startswith('prefill') 99 if layout is not None and parameter.from_ckpt and not is_train_phase: 100 is_opt_shard_group = layout[5] 101 if not parameter.sliced and is_prefill_phase and is_opt_shard_group: 102 rank = get_rank() 103 new_tensor = _load_tensor_by_layout(parameter, layout, rank) 104 parameter.set_data(new_tensor, True) 105 return 106 layout_shape = layout[2] 107 parameter.shape = tuple(layout_shape) 108 return 109 graph_executor = GraphExecutor_.get_instance() 110 new_param = parameter.init_data(layout, set_sliced=True) 111 parameter = new_param 112 graph_executor.updata_param_node_default_input(phase, {parameter.name: parameter}) 113 if layout is None: 114 parameter.sliced = True 115 return 116 if not parameter.sliced: 117 rank = get_rank() 118 new_tensor = _load_tensor_by_layout(parameter, layout, rank) 119 parameter.set_data(new_tensor, True) 120 121 122def _slice_tensor(tensor, layout, rank_id): 123 """Slice python tensor obj according to the layout.""" 124 new_tensor = _load_tensor_by_layout(tensor, layout, rank_id) 125 return new_tensor 126 127 128def _init_optimizer_state(parameter, phase): 129 """init optimizer state""" 130 if not parameter.has_init: 131 return 132 graph_executor = GraphExecutor_.get_instance() 133 new_param = parameter.init_data() 134 parameter = new_param 135 graph_executor.updata_param_node_default_input(phase, {parameter.name: parameter}) 136 137 138def _to_full_shapes(shapes, device_num): 139 """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution.""" 140 new_shapes = [] 141 dataset_strategy = () 142 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 143 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 144 if dataset_strategy: 145 if len(shapes) != len(dataset_strategy): 146 raise ValueError("The input shapes size {} is not equal to " 147 "dataset strategy size {}".format(len(shapes), len(dataset_strategy))) 148 for index, shape in enumerate(shapes): 149 if len(shape) != len(dataset_strategy[index]): 150 raise ValueError("The input shapes item size {} is not equal to " 151 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index]))) 152 new_shape = [] 153 for i, item in enumerate(shape): 154 if item > 0: 155 new_shape += (item * dataset_strategy[index][i],) # static shape 156 else: 157 new_shape += (item,) # dynamic shape 158 new_shapes.append(new_shape) 159 return new_shapes 160 for shape in shapes: 161 shape_v = [] 162 for i, item in enumerate(shape): 163 if i == 0 and item > 0: 164 shape_v += (item * device_num,) # only for static shape 165 else: 166 shape_v += (item,) 167 new_shapes.append(shape_v) 168 return new_shapes 169 170 171def _origin_shapes(shapes): 172 """resume origin shape after full shape.""" 173 if _need_to_full(): 174 device_num = _get_device_num() // _get_pipeline_stages() 175 else: 176 return shapes 177 new_shapes = [] 178 dataset_strategy = () 179 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 180 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 181 if dataset_strategy: 182 if len(shapes) != len(dataset_strategy): 183 raise ValueError("The input shapes size {} is not equal to " 184 "dataset strategy size {}".format(len(shapes), len(dataset_strategy))) 185 for index, shape in enumerate(shapes): 186 if len(shape) != len(dataset_strategy[index]): 187 raise ValueError("The input shapes item size {} is not equal to " 188 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index]))) 189 new_shape = [] 190 for i, item in enumerate(shape): 191 if item > 0: 192 new_shape += (item // dataset_strategy[index][i],) # static shape 193 else: 194 new_shape += (item,) # dynamic shape 195 new_shapes.append(new_shape) 196 return new_shapes 197 for shape in shapes: 198 shape_v = [] 199 for i, item in enumerate(shape): 200 if i == 0 and item > 0: 201 shape_v += (item // device_num,) # only for static shape 202 else: 203 shape_v += (item,) 204 new_shapes.append(shape_v) 205 return new_shapes 206 207 208def _dynamic_shape_for_dataset(dataset_shapes, dynamic_shapes): 209 """convert static dataset shapes to dynamic shape""" 210 if len(dataset_shapes) != len(dynamic_shapes): 211 raise ValueError("The dataset shapes size of {} is not equal to " 212 "dynamic shapes size of {}".format(dataset_shapes, dynamic_shapes)) 213 ret = dataset_shapes 214 for i in range(len(dynamic_shapes)): 215 if len(dataset_shapes[i]) != len(dynamic_shapes[i]): 216 raise ValueError("The dataset shapes size of {} is not equal to " 217 "dynamic shapes size of {}".format(dataset_shapes, dynamic_shapes)) 218 for j in range(len(dynamic_shapes[i])): 219 if dynamic_shapes[i][j] == -1: 220 ret[i][j] = -1 221 return ret 222 223 224def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None): 225 """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data 226 from host solution. 227 """ 228 lst = [] 229 device_num = global_device_num // _get_pipeline_stages() 230 stage_rank = global_rank % device_num 231 if not isinstance(elem, (tuple, list)): 232 elem = [elem] 233 if stage_rank >= device_num: 234 raise ValueError("The global rank must be smaller than device number, the global rank is {}, " 235 "the device num is {}".format(stage_rank, device_num)) 236 dataset_strategy = () 237 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 238 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 239 if elem and dataset_strategy: 240 if len(elem) != len(dataset_strategy): 241 raise ValueError("The input size {} is not equal to " 242 "dataset strategy size {}".format(len(elem), len(dataset_strategy))) 243 for index, data in enumerate(elem): 244 if isinstance(data, np.ndarray): 245 data = Tensor(data) 246 if not isinstance(data, Tensor): 247 raise ValueError("elements in tensors must be Tensor") 248 shape_ = data.shape 249 type_ = data.dtype 250 new_shape = () 251 if not dataset_strategy: 252 batchsize_per_device = 1 253 for i, item in enumerate(shape_): 254 if i == 0: 255 new_shape += (item * device_num,) 256 batchsize_per_device = item 257 else: 258 new_shape += (item,) 259 new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_)) 260 start = stage_rank * batchsize_per_device 261 new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy() 262 else: 263 if len(shape_) != len(dataset_strategy[index]): 264 raise ValueError("The input shapes item size {} is not equal to " 265 "dataset strategy item size {}".format(len(shape_), len(dataset_strategy[index]))) 266 slice_index = () 267 for i, item in enumerate(shape_): 268 new_shape += (item * dataset_strategy[index][i],) 269 start = (stage_rank % dataset_strategy[index][i]) * item 270 end = (stage_rank % dataset_strategy[index][i] + 1) * item 271 s = slice(start, end, 1) 272 slice_index += (s,) 273 new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_)) 274 new_tensor_numpy[slice_index] = data.asnumpy() 275 new_tensor = Tensor(new_tensor_numpy, dtype=type_) 276 lst.append(new_tensor) 277 if scaling_sens: 278 lst.append(Tensor(scaling_sens, mstype.float32)) 279 return tuple(lst) 280 281 282def _get_gradients_mean(): 283 """Get if using gradients_mean.""" 284 return auto_parallel_context().get_gradients_mean() 285 286 287def _get_device_num(): 288 """Get the device num.""" 289 parallel_mode = auto_parallel_context().get_parallel_mode() 290 if parallel_mode == "stand_alone": 291 device_num = 1 292 return device_num 293 294 if auto_parallel_context().get_device_num_is_set() is False: 295 device_num = get_group_size() 296 else: 297 device_num = auto_parallel_context().get_device_num() 298 return device_num 299 300 301def _get_stage_device_num(): 302 """Get the device number of each pipeline stage""" 303 return _get_device_num() // _get_pipeline_stages() 304 305 306def _get_global_rank(): 307 """Get the global rank.""" 308 parallel_mode = auto_parallel_context().get_parallel_mode() 309 if parallel_mode == "stand_alone": 310 global_rank = 0 311 return global_rank 312 313 if auto_parallel_context().get_global_rank_is_set() is False: 314 global_rank = get_rank() 315 else: 316 global_rank = auto_parallel_context().get_global_rank() 317 return global_rank 318 319 320def _get_parameter_broadcast(): 321 """Get the parameter broadcast.""" 322 parallel_mode = auto_parallel_context().get_parallel_mode() 323 parameter_broadcast = auto_parallel_context().get_parameter_broadcast() 324 325 if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None: 326 logger.warning("You are suggested to use mindspore.context.set_auto_parallel_context(parameter_broadcast=True)" 327 " or mindspore.common.set_seed() to share parameters among multi-devices.") 328 329 return parameter_broadcast 330 331 332def _get_enable_parallel_optimizer(): 333 """Get if using parallel optimizer.""" 334 return auto_parallel_context().get_enable_parallel_optimizer() 335 336 337def _get_grad_accumulation_shard(): 338 """Get if using parallel shard.""" 339 return auto_parallel_context().get_grad_accumulation_shard() 340 341 342def _device_number_check(parallel_mode, device_number): 343 """ 344 Check device num. 345 346 Args: 347 parallel_mode (str): The parallel mode. 348 device_number (int): The device number. 349 """ 350 if parallel_mode == "stand_alone" and device_number != 1: 351 raise ValueError("If parallel_mode is stand_alone, device_number must be 1, " 352 "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode)) 353 354 355def _parameter_broadcast_check(parallel_mode, parameter_broadcast): 356 """ 357 Check parameter broadcast. 358 359 Note: 360 If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Using the same 361 random seed to make sure parameters on multiple devices are the same. 362 363 Args: 364 parallel_mode (str): The parallel mode. 365 parameter_broadcast (bool): The parameter broadcast. 366 367 Raises: 368 ValueError: If parameter is broadcasted 369 but the parallel mode is "stand_alone" or "semi_auto_parallel" or "auto_parallel"). 370 """ 371 if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"): 372 raise ValueError("stand_alone, semi_auto_parallel and auto_parallel " 373 "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}" 374 .format(parallel_mode, parameter_broadcast)) 375 376 377def _get_python_op(op_name, op_path, instance_name, arglist): 378 """Get python operator.""" 379 module = import_module(op_path) 380 cls = getattr(module, op_name) 381 if op_path != "mindspore.ops.functional": 382 # The AllGather attrs contains group_name and group_ranks, pop group_ranks. 383 if op_name == "AllGather" and len(arglist) == 2: 384 arglist.pop() 385 op = cls(*arglist) 386 else: 387 op = cls 388 op.set_prim_instance_name(instance_name) 389 return op 390 391 392def _reset_op_id(): 393 """Reset op id.""" 394 reset_op_id() 395 396 397def _reset_op_id_with_offset(): 398 """Reset op id with offset.""" 399 reset_op_id_with_offset() 400 401 402def _parallel_predict_check(): 403 """validate parallel model prediction""" 404 if _is_in_auto_parallel_mode(): 405 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 406 is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch")) 407 if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp: 408 logger.warning('Using non full-batch dataset in model prediction may lead to incorrect data.') 409 410 411def _check_similar_layout(tensor_layout1, tensor_layout2): 412 """check if two tensor layouts are same""" 413 if tensor_layout1[1] != tensor_layout2[1]: 414 return False 415 for i in tensor_layout1[1]: 416 if i == -1: 417 continue 418 if tensor_layout1[0][-1 - i] != tensor_layout2[0][-1 - i]: 419 return False 420 return True 421 422 423def _check_same_layout(tensor_layout1, tensor_layout2): 424 """check if two tensor layouts are same""" 425 return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1] 426 427 428def _remove_repeated_slices(tensor_layout): 429 """generate unrepeated tensor layout""" 430 import copy 431 new_tensor_layout = copy.deepcopy(tensor_layout) 432 dev_mat = tensor_layout[0][:] 433 tensor_map = tensor_layout[1] 434 for dim in range(len(dev_mat)): 435 if dim not in tensor_map: 436 dev_mat[-1 - dim] = 1 437 new_tensor_layout[0] = dev_mat 438 return new_tensor_layout 439 440 441def _infer_rank_list(train_map, predict_map=None): 442 """ 443 infer checkpoint slices to be loaded. 444 map value format: [dev_mat, tensor_map, param_split_shape, field_size, opt_shard_stride, opt_shard_size] 445 """ 446 ret = {} 447 if _get_pipeline_stages() > 1: 448 local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages())) 449 else: 450 local_rank = _get_global_rank() 451 for param_name in train_map: 452 train_layout = train_map[param_name] 453 train_dev_mat = train_layout[0] 454 dev_num = np.array(train_dev_mat).prod() 455 new_train_layout = _remove_repeated_slices(train_layout) 456 array = np.arange(dev_num).reshape(train_dev_mat) 457 index = () 458 for i in new_train_layout[0]: 459 if i == 1: 460 index = index + (0,) 461 else: 462 index = index + (slice(None),) 463 rank_list = array[index].flatten() 464 if not predict_map: 465 ret[param_name] = (rank_list, False) 466 continue 467 if param_name not in predict_map: 468 logger.warning("predict_map does not contain %s", param_name) 469 continue 470 predict_layout = predict_map[param_name] 471 dev_num = np.array(predict_layout[0]).prod() 472 # optimization pass 473 if _check_same_layout(train_layout, predict_layout): 474 ret[param_name] = ([local_rank], True) 475 continue 476 if _check_similar_layout(train_layout, predict_layout): 477 if len(rank_list) == 1: 478 ret[param_name] = (rank_list, True) 479 elif len(rank_list) == dev_num: 480 ret[param_name] = ([rank_list[local_rank]], True) 481 else: 482 ret[param_name] = (rank_list, False) 483 else: 484 ret[param_name] = (rank_list, False) 485 return ret 486 487 488def _handle_symbol_inputs(symbol_inputs): 489 """handle symbol inputs""" 490 dataset_strategy = () 491 divisor_key = "divisor" 492 # dataset strategy is set 493 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 494 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 495 if dataset_strategy: 496 if len(symbol_inputs) != len(dataset_strategy): 497 raise ValueError("The symbol_inputs size {} is not equal to " 498 "dataset strategy size {}".format(len(symbol_inputs), len(dataset_strategy))) 499 for index, shape in enumerate(symbol_inputs): 500 dataset_ele_s = dataset_strategy[index] 501 if len(shape) != len(dataset_ele_s): 502 raise ValueError("The symbol_inputs item size {} is not equal to " 503 "dataset strategy item size {}".format(len(shape), len(dataset_ele_s))) 504 505 for i, item in enumerate(shape): 506 if isinstance(item, dict): # symbol 507 symbol_inputs[index][i][divisor_key] = symbol_inputs[index][i][divisor_key] * dataset_ele_s[i] 508 else: # common shape 509 symbol_inputs[index][i] = item * dataset_ele_s[i] 510 511 return symbol_inputs 512 513 # full batch is set 514 device_num = _get_device_num() // _get_pipeline_stages() 515 for index, shape in enumerate(symbol_inputs): 516 for i, item in enumerate(shape): 517 if i == 0 and isinstance(item, dict): # symbol 518 symbol_inputs[index][i][divisor_key] = symbol_inputs[index][i][divisor_key] * device_num 519 520 return symbol_inputs 521 522 523def _no_need_to_change_symbols(shapes): 524 """no need to handle the symbol if full_batch is true or it's not parallel mode""" 525 if not _need_to_full(): 526 return True 527 528 # if static shape, return 529 is_dynamic_shape = False 530 for shape in shapes: 531 if any(i < 0 for i in shape): 532 is_dynamic_shape = True 533 break 534 if is_dynamic_shape is False: 535 return True 536 537 return False 538 539 540def _change_symbols_for_parallel(shapes, symbol_inputs=None): 541 """create or modify symbol inputs""" 542 if _no_need_to_change_symbols(shapes) is True: 543 return symbol_inputs 544 # the symbol_inputs is [[{'divisor': 8}, 16], [{'divisor': 8}, 16]] 545 # the dataset_shapes is [(-1, 16), (-1, 16)] 546 # if symbol_inputs is [None, None, ..., None], reset it 547 if symbol_inputs is not None and all(s is None for s in symbol_inputs): 548 symbol_inputs = [] 549 550 # if symbol inputs is none or empty, create default symbol inputs 551 # if symbol inputs is not none, handle the empty symbol 552 divisor_key = "divisor" 553 if symbol_inputs is None or bool(symbol_inputs) is False: 554 symbol_inputs = [list(shape) for shape in shapes] # tuple to list 555 for i, s in enumerate(symbol_inputs): 556 for j, item in enumerate(s): 557 if item == -1: 558 symbol_inputs[i][j] = {divisor_key: 1} 559 else: 560 for i, s in enumerate(symbol_inputs): 561 # the symbol_inputs may be [None, [{'divisor': 8}, 16]] 562 # and the dataset_shapes is [(-1, 16), (-1, 16)], need to handle None 563 if s is None: 564 symbol_inputs[i] = shapes[i] 565 for k, item in enumerate(symbol_inputs[i]): 566 if item == -1: 567 symbol_inputs[i][k] = {divisor_key: 1} 568 s = symbol_inputs[i] 569 for j, item in enumerate(s): 570 if isinstance(item, dict) and bool(item) is False: # the item is empty 571 symbol_inputs[i][j] = {divisor_key: 1} 572 573 return _handle_symbol_inputs(symbol_inputs) 574 575 576def _grads_divided_by_device_num_if_recomputation(grads): 577 """ 578 If in pynative parallel and full_batch is True, divide grads by device num to ensure that the gradients is correct. 579 """ 580 if not _is_pynative_parallel() or not _get_full_batch(): 581 return grads 582 583 device_num = _get_device_num() 584 logger.info(f"In PyNative mode, when parallel mode is in " 585 f"({context.ParallelMode.SEMI_AUTO_PARALLEL}, {context.ParallelMode.AUTO_PARALLEL}) and " 586 f"full_batch is Ture, the gradients will be automatically divided by device_num({device_num}).") 587 588 if not isinstance(grads, (tuple, Tensor)): 589 raise ValueError(f"The type of grads must be either Tuple[Tensor] or Tensor, but got {type(grads)}.") 590 591 if isinstance(grads, tuple): 592 new_grads = () 593 if grads: 594 device_num_tensor = Tensor(device_num, grads[0].dtype) 595 for grad in grads: 596 new_grads += (grad / device_num_tensor,) 597 else: 598 device_num_tensor = Tensor(device_num, grads.dtype) 599 new_grads = grads / device_num_tensor 600 return new_grads 601