1# Copyright 2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""load tensor and combine tensor""" 16from __future__ import division 17from __future__ import absolute_import 18 19import copy 20import numpy as np 21from mindspore.common.tensor import Tensor 22from mindspore.communication.management import get_rank, get_group_size 23from mindspore._c_expression import TensorTransform 24 25_tensor_transform = TensorTransform.get_instance() 26 27 28def _get_tensor_strategy(dev_mat, tensor_map): 29 """ 30 Get split strategy by device arrangement and tensor map. 31 32 Args: 33 dev_mat (list): The device matrix. 34 tensor_map (list): The map relation between tensor and devices. 35 36 Returns: 37 List, the split strategy with the same size of np_tensor. 38 """ 39 tensor_strategy = [] 40 for dim in tensor_map: 41 if dim == -1: 42 tensor_strategy.append(1) 43 else: 44 tensor_strategy.append(dev_mat[-dim - 1]) 45 return tensor_strategy 46 47 48def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index): 49 """ 50 Get the tensor slice index for the local device. 51 52 Args: 53 device_arrangement (list): The device matrix. 54 tensor_strategy (list): The split strategy with the same size of np_tensor. 55 tensor_map (list): The map relation between tensor and devices. 56 rank_index (int): The rank of local device. 57 58 Returns: 59 Integer, the index of the local device for tensor slices. 60 """ 61 device_coordinate = _rank_to_coordinate(rank_index, device_arrangement) 62 device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map) 63 tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy) 64 return tensor_slice_index 65 66 67def _rank_to_coordinate(rank_index, device_arrangement): 68 """ 69 Convert rank index to device coordinate. 70 71 Args: 72 rank_index (int): The index of the local device. 73 device_arrangement (list): The device matrix. 74 75 Returns: 76 List, the coordinate for local device in the device matrix 77 """ 78 dim_len = len(device_arrangement) 79 device_coordinate = np.zeros(dim_len) 80 for i in range(dim_len): 81 size = device_arrangement[dim_len - 1 - i] 82 device_coordinate[dim_len - 1 - i] = rank_index % size 83 rank_index = int(rank_index / size) 84 return device_coordinate 85 86 87def _coordinate_to_rank(device_coordinate, device_arrangement): 88 """ 89 Convert device coordinate to rank index. 90 91 Args: 92 device_coordinate (list): The coordinate for local device in the device matrix. 93 device_arrangement (list): The device matrix. 94 95 Returns: 96 Integer, the index of the local device for tensor slices. 97 """ 98 rank_index = 0 99 size = 1 100 for i in range(len(device_coordinate)): 101 rank_index += size * device_coordinate[len(device_coordinate) - 1 - i] 102 size *= device_arrangement[len(device_coordinate) - 1 - i] 103 return rank_index 104 105 106def _convert_to_new_device_coordinate(device_coordinate, tensor_map): 107 """ 108 Convert device_coordinate according to the tensor map. 109 110 Args: 111 device_coordinate (list): The coordinate for local device in the device matrix. 112 tensor_map (list): The map relation between tensor and devices. 113 114 Returns: 115 List, the converted coordinate. 116 """ 117 device_coordinate_new = [] 118 for i in range(len(tensor_map)): 119 if tensor_map[len(tensor_map) - 1 - i] != -1: 120 device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 - 121 tensor_map[len(tensor_map) - 1 - i]]) 122 else: 123 device_coordinate_new.insert(0, 0) 124 return device_coordinate_new 125 126 127def _chunk_tensor(np_tensor, strategy, depth): 128 """ 129 Recursive function to chunk tensor. 130 131 Args: 132 np_tensor (NDarray): The matrix to be split. 133 strategy (list): The split strategy with the same size of np_tensor. 134 depth (int): Recursion depth. 135 136 Returns: 137 NDarray, the splited matrix. 138 139 Raises: 140 ValueError: If np_tensor can not be split by strategy. 141 """ 142 output = [] 143 axis = len(np_tensor.shape) - depth 144 if np_tensor.shape[axis] % strategy[0] != 0: 145 raise ValueError("np_tensor can not be split by strategy!") 146 ret = list(np.split(np_tensor, strategy[0], axis)) 147 if depth == 1: 148 return ret 149 for ret_ in ret: 150 output.extend( 151 _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1)) 152 153 return output 154 155 156def _chunk_tensor_by_strategy(np_tensor, strategy): 157 """ 158 Split the input by strategy. 159 160 Args: 161 np_tensor (NDarray): The matrix to be split. 162 strategy (list): The split strategy with the same size of np_tensor. 163 164 Returns: 165 NDarray, the splited matrix. 166 167 Raises: 168 TypeError: If np_tensor is not ndarray 169 ValueError: If the length of np_tensor does not match the length of strategy. 170 """ 171 if not isinstance(np_tensor, np.ndarray): 172 raise TypeError("np_tensor should be ndarray!") 173 if len(strategy) != len(np_tensor.shape): 174 raise ValueError("The length of np_tensor does not match the length of strategy!") 175 return _chunk_tensor(np_tensor, strategy, len(strategy)) 176 177 178def _get_slice_index(dev_mat, tensor_map, opt_shard_group): 179 """ 180 Get the slice index for current slice. 181 182 Args: 183 dev_mat (list): The device matrix of devices. 184 tensor_map (list): The split strategy of tensor. 185 opt_shard_group(string): The group of optimizer shard 186 187 Returns: 188 Integer, the slice index for slice on this device. 189 """ 190 rank = get_rank() 191 dev_num = get_group_size() 192 tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) 193 tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) 194 if opt_shard_group: 195 tensor_slice_index += dev_num 196 opt_rank = get_rank(opt_shard_group) 197 tensor_slice_index += opt_rank 198 return tensor_slice_index 199 200 201def _load_tensor(tensor, dev_mat, tensor_map, full_shape=None, rank_id=-1): 202 """ 203 Get the tensor slice of the local device by the device matrix and the tensor map 204 205 Args: 206 tensor (Tensor): The tensor to be split. 207 dev_mat (list): The device matrix of devices. 208 tensor_map (list): The split strategy of tensor. 209 210 Returns: 211 numpy.array, the sliced array. 212 213 Examples: 214 >>> tensor = Tensor(np.ones([32, 32])) 215 >>> dev_mat = [2, 4] 216 >>> tensor_map = [1, -1] 217 >>> full_shape = [32, 32] 218 >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape) 219 """ 220 if rank_id == -1: 221 rank = get_rank() 222 else: 223 rank = rank_id 224 tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) 225 tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) 226 np_tensor = tensor.asnumpy() 227 if full_shape: 228 np_tensor = np_tensor.reshape(full_shape) 229 np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy) 230 np_tensor_slice = np_tensor_list[int(tensor_slice_index)] 231 return np_tensor_slice 232 233 234def _load_tensor_by_layout(tensor, layout, rank_id): 235 """ 236 Load tensor by layout. 237 238 Args: 239 tensor (Tensor): The input tensor. 240 layout (list): The tensor layout in auto parallel. 241 242 Returns: 243 Tensor, the sliced tensor. 244 245 Raises: 246 TypeError: If layout is not list. 247 ValueError: If the length of layout is not 3. 248 """ 249 if not isinstance(layout, tuple): 250 raise TypeError("The layout should be tuple! layout is {}".format(layout)) 251 if len(layout) < 7: 252 raise ValueError("The length of layout must be larger than 6! layout is {}".format(layout)) 253 dev_mat = layout[0] 254 tensor_map = layout[1] 255 slice_shape = layout[2] 256 if not tensor_map: 257 return tensor 258 uniform_split = layout[4] 259 group = layout[5] 260 full_shape = layout[6] 261 if uniform_split == 0: 262 raise RuntimeError("The load tensor only support uniform split now") 263 tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape, rank_id) 264 if tensor_slice.shape != slice_shape and not group: 265 tensor_slice = tensor_slice.reshape(slice_shape) 266 if group: 267 # get a totally shard tensor slice for parallel optimizer 268 rank = get_rank(group) 269 size = get_group_size(group) 270 if tensor_slice.shape != tuple(slice_shape) and slice_shape: 271 slice_shape_extend = copy.deepcopy(slice_shape) 272 slice_shape_extend[0] = slice_shape[0] * size 273 tensor_slice = tensor_slice.reshape(slice_shape_extend) 274 tensor_slice = np.split(tensor_slice, size)[rank] 275 return Tensor(tensor_slice, tensor.dtype) 276 277 278def _reshape_param_data(param_data, dev_mat, tensor_map): 279 """ 280 Combine param slice by the device matrix and the tensor map, used in model parallel scenario. 281 282 Args: 283 param_data (Tensor): The tensor to be reshaped, generated from all the device from AllGatherParamNet. 284 dev_mat (list): The device matrix of devices. 285 tensor_map (list): The split strategy of tensor. 286 287 Returns: 288 Tensor, the combined tensor which with the whole data value. 289 290 Examples: 291 >>> param_data = _allgather_param_net(param_data) 292 >>> dev_mat = [2, 2] 293 >>> tensor_map = [1, 0] 294 >>> tensor = _reshape_param_data(tensor_slices, dev_mat, tensor_map) 295 """ 296 297 device_count = 1 298 for dim in dev_mat: 299 device_count *= dim 300 301 tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0) 302 tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) 303 304 # get the actual number of slices,as: different devices may load the same slice 305 slice_count = 1 306 for dim in tensor_strategy: 307 slice_count *= dim 308 309 # reorder slices and remove duplicates based on device matrix and tensor_map 310 tensor_slices_new = list(range(slice_count)) 311 for i in range(device_count): 312 slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i) 313 tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i]) 314 315 # combine slices to generate complete parameter 316 dim_len = len(tensor_strategy) 317 for i in range(dim_len): 318 ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i]) 319 tensor_slices_new_inner = [] 320 for j in range(ele_count): 321 new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]] 322 for k in range(j * tensor_strategy[dim_len - 1 - i] + 1, 323 (j + 1) * tensor_strategy[dim_len - 1 - i]): 324 new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i) 325 326 tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor)) 327 tensor_slices_new = tensor_slices_new_inner 328 329 return Tensor(tensor_slices_new[0]) 330 331 332def _extract_layout_item(layout_item): 333 dev_matrix = layout_item[0] 334 tensor_map = layout_item[1] 335 opt_shard_step = layout_item[4] 336 opt_shard_size = layout_item[5] 337 if opt_shard_size == -1: 338 opt_shard_size = np.prod(dev_matrix) // opt_shard_step 339 return dev_matrix, tensor_map, opt_shard_step, opt_shard_size 340 341 342def _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id): 343 """ 344 Transform tensor from source layout to the destination layout. 345 346 Args: 347 from_layout (tuple(tuple)): Source tensor layout 348 to_layout (tuple(tuple)): Destination tensor layout 349 device_list (tuple): The rank list of the tensor distributed. 350 rank_id (number): The tensor slice in which rank. 351 Returns: 352 transform operator list. 353 """ 354 if not isinstance(from_layout, tuple) or not isinstance(to_layout, tuple): 355 raise TypeError("The layout should be tuple! layout is {} and {}".format(from_layout, to_layout)) 356 return _tensor_transform.transform_tensor_sharding(from_layout, to_layout, device_list, rank_id) 357 358 359def _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix, 360 from_tensor_map, to_full_tensor_shape, 361 to_dev_matrix, to_tensor_map): 362 """construct from_layout and to_layout to the same device num""" 363 from_full_tensor_shape = list(from_full_tensor_shape) 364 to_full_tensor_shape = list(to_full_tensor_shape) 365 from_dev_matrix = list(from_dev_matrix) 366 from_tensor_map = list(from_tensor_map) 367 to_dev_matrix = list(to_dev_matrix) 368 to_tensor_map = list(to_tensor_map) 369 from_dev_prod = np.prod(from_dev_matrix) 370 to_dev_prod = np.prod(to_dev_matrix) 371 if len(from_full_tensor_shape) != len(from_tensor_map) or len(to_full_tensor_shape) != len(to_tensor_map): 372 raise ValueError("The tensor map dimensions should be equal to tensor shape dimensions, " 373 "please check strategy file.") 374 if from_dev_prod > to_dev_prod: 375 if from_dev_prod % to_dev_prod != 0: 376 raise ValueError("Cannot transform device_num from {} to {}".format(from_dev_prod, to_dev_prod)) 377 repeat_dim_size = from_dev_prod // to_dev_prod 378 to_dev_matrix.insert(0, repeat_dim_size) 379 elif from_dev_prod < to_dev_prod: 380 if to_dev_prod % from_dev_prod != 0: 381 raise ValueError("Cannot transform device_num from {} to {}".format(from_dev_prod, to_dev_prod)) 382 repeat_dim_size = to_dev_prod // from_dev_prod 383 from_dev_matrix.insert(0, repeat_dim_size) 384 from_tensor_layout = (from_dev_matrix, from_tensor_map, from_full_tensor_shape) 385 to_tensor_layout = (to_dev_matrix, to_tensor_map, to_full_tensor_shape) 386 return from_tensor_layout, to_tensor_layout 387 388 389def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_step, opt_shard_size, 390 origin_full_tensor_shape): 391 """ 392 dev_mat = [4, 2, 2] 393 tensor_map = [2, 1, 0] 394 opt_size = 2 395 => 396 dev_mat = [opt_size, 4, 2, 2] = [2, 4, 2, 2] 397 tensor_map = [2, 3, 1, 0] 398 thus new_strategy = [4, 2, 2, 2] 399 the tensor_shape should reshape to (model_parallel_size, -1, xx, xx) 400 first 4 means the model parallel sharding of data_dim 401 second 2 means the opt sharding of data_dim 402 And the model parallel sharding dim is the right of opt sharding dim, so it would be 0-1-2-3 model parallel sharding 403 then 0-4 optimizer sharding. 404 """ 405 406 if opt_shard_step == 0 or opt_shard_size == 0: 407 return dev_matrix, tensor_map, list(origin_full_tensor_shape) 408 tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map) 409 model_parallel_shard_size = np.prod(tensor_strategy) 410 if model_parallel_shard_size != opt_shard_step: 411 raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.". 412 format(opt_shard_step, model_parallel_shard_size)) 413 414 first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0] 415 full_tensor_shape = list(origin_full_tensor_shape) 416 full_tensor_shape[0] = tensor_strategy[0] 417 full_tensor_shape.insert(1, first_dim_no_sharding_size) 418 new_dev_matrix = tensor_strategy 419 repeat_dim = np.prod(dev_matrix) // (opt_shard_step * opt_shard_size) 420 421 new_tensor_map = [] 422 for idx, val in enumerate(tensor_strategy): 423 if val == 1: 424 new_tensor_map.append(-1) 425 else: 426 new_tensor_map.append(len(tensor_strategy) - 1 - idx) 427 new_tensor_map.insert(1, len(tensor_strategy)) 428 new_dev_matrix.insert(0, opt_shard_size) 429 if repeat_dim > 1: 430 new_dev_matrix.insert(0, repeat_dim) 431 return new_dev_matrix, new_tensor_map, full_tensor_shape 432 433 434def _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, self_rank): 435 """ 436 AllGather op: {op_name, group_ranks + axis} 437 """ 438 result_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, device_list, 439 self_rank) 440 result_list = list(result_map.keys()) 441 result_list.sort() 442 return result_list 443 444 445def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, device_list, self_rank): 446 """ 447 AllGather op: {op_name, group_ranks + axis} 448 """ 449 stack = [] 450 index = 0 451 transform_operators = _transform_tensor_by_layout(from_tensor_layout, to_tensor_layout, device_list, self_rank) 452 result_map = {self_rank: transform_operators} 453 for operators in transform_operators: 454 op_name = operators[0] 455 if op_name == "AllGather": 456 groups = operators[1][:-1] 457 stack.append((index, groups)) 458 index += 1 459 while stack: 460 group_info = stack.pop() 461 for rank in group_info[1]: 462 if rank not in result_map: 463 new_transform_operators = _transform_tensor_by_layout(from_tensor_layout, to_tensor_layout, 464 device_list, rank) 465 result_map[rank] = new_transform_operators 466 index = 0 467 for operators in new_transform_operators: 468 op_name = operators[0] 469 if op_name == "AllGather" and index < group_info[0]: 470 groups = operators[1][:-1] 471 stack.insert(0, (index, groups)) 472 index += 1 473 return result_map 474 475 476def _generate_transform_operator_stack(transform_operators_map, self_rank): 477 """ 478 return (rank_id, index, operator) 479 """ 480 if self_rank not in transform_operators_map: 481 raise ValueError("The transform operators of rank id {} is required.".format(self_rank)) 482 if not transform_operators_map[self_rank]: 483 return [] 484 init_level = len(transform_operators_map[self_rank]) - 1 485 handle_queue = [(self_rank, init_level, transform_operators_map[self_rank][init_level])] 486 result_queue = [] 487 while handle_queue: 488 queue_front = handle_queue.pop(0) 489 result_queue.append(queue_front) 490 current_rank_id = queue_front[0] 491 level = queue_front[1] 492 current_operator = queue_front[2] 493 if level >= 1: 494 if current_operator[0] == "AllGather": 495 current_group = current_operator[1][:-1] 496 for rank_id in current_group: 497 handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1])) 498 else: 499 handle_queue.append((current_rank_id, level - 1, transform_operators_map[current_rank_id][level - 1])) 500 return result_queue 501 502 503def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num): 504 """ 505 transform_operator_stack: [...(rank_id, index, operator)] 506 """ 507 if not transform_operator_stack: 508 return 509 level = transform_operator_stack[-1][1] 510 level_operators = [] 511 while True: 512 if not transform_operator_stack or (level != transform_operator_stack[-1][1]): 513 tmp_tensor_dict = {} 514 if not level_operators: 515 continue 516 op_name = level_operators[0][2][0] 517 for operator_pair in level_operators: 518 rank_id = operator_pair[0] 519 if rank_id % device_num not in tensor_dict: 520 raise ValueError("The checkpoint file of rank {} is missing.".format(rank_id % device_num)) 521 cur_level = operator_pair[1] 522 operator = operator_pair[2] 523 if operator[0] != op_name: 524 raise ValueError("The operator in the same level should be equal in the transform tensor operator " 525 "list, but the find {} and {} in level {}".format(op_name, operator[0], cur_level)) 526 if operator[0] != "AllGather": 527 tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num], 528 operator) 529 continue 530 for rank in operator[1][:-1]: 531 if rank % device_num not in tensor_dict: 532 raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num)) 533 allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]] 534 tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator) 535 if op_name == "AllGather": 536 for rank, value in tmp_tensor_dict.items(): 537 tensor_dict[rank % device_num] = value 538 level_operators.clear() 539 if not transform_operator_stack: 540 break 541 operator_pair = transform_operator_stack.pop() 542 level = operator_pair[1] 543 level_operators.append(operator_pair) 544 545 546def _check_operator(operator): 547 if not isinstance(operator, tuple): 548 raise TypeError("The operator should be a list.") 549 if len(operator) != 2: 550 raise TypeError("The operator should contains 2 item.") 551 if not isinstance(operator[1], list): 552 raise TypeError("The operator[1] should be list.") 553 554 555def _apply_operator(operator_name): 556 """apply transform operator""" 557 558 def _apply_reshape_operator(numpy_data, reshape_op): 559 """ 560 Apply reshape operator. 561 562 Args: 563 numpy_data (numpy.ndarray): The data of tensor to apply operator. 564 reshape_op (tuple): reshape operator information, the second item is the destination shape. 565 Returns: 566 The data of tensor after apply operator. 567 """ 568 if not isinstance(numpy_data, np.ndarray): 569 raise TypeError("The data should be a numpy.ndarray.") 570 _check_operator(reshape_op) 571 return np.reshape(numpy_data, reshape_op[1]) 572 573 def _apply_allconcat_operator(numpy_data_list, allgather_op): 574 """ 575 Apply allconcat operator. 576 577 Args: 578 numpy_data (numpy.ndarray): The data of tensor to apply operator. 579 allgather_op (tuple): allgather operator information. 580 the second item is the allgather info, contains group and axis. 581 Returns: 582 The data of tensor after apply operator. 583 """ 584 if not isinstance(numpy_data_list, list): 585 raise TypeError("The data_list should be a list.") 586 for numpy_data in numpy_data_list: 587 if not isinstance(numpy_data, np.ndarray): 588 raise TypeError("The data should be a numpy.ndarray.") 589 _check_operator(allgather_op) 590 concat_group = allgather_op[1][:-1] 591 if len(concat_group) != len(numpy_data_list): 592 raise ValueError("The length of data_list {} should be equal to concat_group size {}". 593 format(len(numpy_data_list), len(concat_group))) 594 concat_axis = allgather_op[1][-1] 595 return np.concatenate(numpy_data_list, concat_axis) 596 597 def _apply_slice_operator(numpy_data, slice_op): 598 """ 599 Apply reshape operator. 600 601 Args: 602 numpy_data (numpy.ndarray): The data of tensor to apply operator. 603 slice_op (tuple): slice operator information, the second item is the slice information. 604 Returns: 605 The data of tensor after apply operator. 606 """ 607 if not isinstance(numpy_data, np.ndarray): 608 raise TypeError("The data should be a numpy.ndarray.") 609 _check_operator(slice_op) 610 if len(slice_op[1]) % 3 != 0: 611 raise ValueError("The slice operator information is wrong.") 612 shape_size = len(slice_op[1]) // 3 613 begin = slice_op[1][:shape_size] 614 end = slice_op[1][shape_size:shape_size * 2] 615 stride = slice_op[1][shape_size * 2:] 616 slice_index = [] 617 for begin_i, end_i, strides_i in zip(begin, end, stride): 618 s = slice(begin_i, end_i, strides_i) 619 slice_index.append(s) 620 slice_index = tuple(slice_index) 621 return numpy_data[slice_index] 622 623 _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator, 624 "AllGather": _apply_allconcat_operator} 625 return _apply_operator_map.get(operator_name) 626 627 628def _reshape_param_data_with_weight(param_data, dev_mat, field_size): 629 """ 630 Combine param slice by the device matrix, used in model parallel scenario. 631 632 Args: 633 param_data (Tensor): The tensor to be reshaped and rearrangement, 634 generated from all the device from AllGatherParamNet. 635 dev_mat (list): The device matrix of devices. 636 Returns: 637 Tensor, the combined tensor which with the whole data value. 638 639 Examples: 640 >>> param_data = _allgather_param_net(param_data) 641 >>> dev_mat = [2, 2] 642 >>> field_size = [39] 643 >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size) 644 """ 645 device_count = 1 646 for dim in dev_mat: 647 device_count *= dim 648 649 tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0) 650 tensor_slices_col = [] 651 for i in range(len(tensor_slices[0][0])): 652 tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1) 653 for j in range(1, device_count): 654 tensor_slices_new = np.concatenate((tensor_slices_new, \ 655 np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1) 656 tensor_slices_col.append(tensor_slices_new) 657 new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1) 658 for i in range(1, len(tensor_slices_col)): 659 new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1) 660 return Tensor(new_tensor) 661