# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""load tensor and combine tensor"""
from __future__ import division
from __future__ import absolute_import

import copy
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.communication.management import get_rank, get_group_size
from mindspore._c_expression import TensorTransform

_tensor_transform = TensorTransform.get_instance()


def _get_tensor_strategy(dev_mat, tensor_map):
    """
    Get the split strategy by device arrangement and tensor map.

    Args:
        dev_mat (list): The device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the split strategy with one element per tensor dimension.
    """
    tensor_strategy = []
    for dim in tensor_map:
        if dim == -1:
            tensor_strategy.append(1)
        else:
            tensor_strategy.append(dev_mat[-dim - 1])
    return tensor_strategy
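
# Illustrative usage (a minimal sketch; values are assumed, not from the original module):
# with device matrix [2, 4], a tensor dimension mapped to device dim 1 is split into
# dev_mat[-2] = 2 parts, while a dimension mapped to -1 stays unsplit.
#     >>> _get_tensor_strategy([2, 4], [1, -1])
#     [2, 1]
#     >>> _get_tensor_strategy([2, 4], [0, 1])
#     [4, 2]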


def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
    """
    Get the tensor slice index for the local device.

    Args:
        device_arrangement (list): The device matrix.
        tensor_strategy (list): The split strategy, one element per tensor dimension.
        tensor_map (list): The map relation between tensor and devices.
        rank_index (int): The rank of the local device.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    device_coordinate = _rank_to_coordinate(rank_index, device_arrangement)
    device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map)
    tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy)
    return tensor_slice_index
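
# Illustrative usage (a minimal sketch; values are assumed): with device matrix [2, 4],
# strategy [2, 1] and tensor map [1, -1], rank 5 sits at device coordinate (1, 1) and
# therefore owns slice 1 (the second row block).
#     >>> int(_get_tensor_slice_index([2, 4], [2, 1], [1, -1], 5))
#     1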


def _rank_to_coordinate(rank_index, device_arrangement):
    """
    Convert a rank index to a device coordinate.

    Args:
        rank_index (int): The index of the local device.
        device_arrangement (list): The device matrix.

    Returns:
        numpy.ndarray, the coordinate of the local device in the device matrix.
    """
    dim_len = len(device_arrangement)
    device_coordinate = np.zeros(dim_len)
    for i in range(dim_len):
        size = device_arrangement[dim_len - 1 - i]
        device_coordinate[dim_len - 1 - i] = rank_index % size
        rank_index = int(rank_index / size)
    return device_coordinate


def _coordinate_to_rank(device_coordinate, device_arrangement):
    """
    Convert device coordinate to rank index.

    Args:
        device_coordinate (list): The coordinate for local device in the device matrix.
        device_arrangement (list): The device matrix.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    rank_index = 0
    size = 1
    for i in range(len(device_coordinate)):
        rank_index += size * device_coordinate[len(device_coordinate) - 1 - i]
        size *= device_arrangement[len(device_coordinate) - 1 - i]
    return rank_index
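
# Illustrative round trip (a minimal sketch; values are assumed): rank 5 in a [2, 4]
# device matrix has coordinate (1, 1), and converting that coordinate back yields rank 5.
#     >>> _rank_to_coordinate(5, [2, 4])
#     array([1., 1.])
#     >>> _coordinate_to_rank([1, 1], [2, 4])
#     5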


def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
    """
    Convert device_coordinate according to the tensor map.

    Args:
        device_coordinate (list): The coordinate for local device in the device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the converted coordinate.
    """
    device_coordinate_new = []
    for i in range(len(tensor_map)):
        if tensor_map[len(tensor_map) - 1 - i] != -1:
            device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 -
                                                              tensor_map[len(tensor_map) - 1 - i]])
        else:
            device_coordinate_new.insert(0, 0)
    return device_coordinate_new


def _chunk_tensor(np_tensor, strategy, depth):
    """
    Recursive function to chunk a tensor.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split strategy with the same size as np_tensor.
        depth (int): Recursion depth.

    Returns:
        List, the split sub-matrices.

    Raises:
        ValueError: If np_tensor can not be split by strategy.
    """
    output = []
    axis = len(np_tensor.shape) - depth
    if np_tensor.shape[axis] % strategy[0] != 0:
        raise ValueError("np_tensor can not be split by strategy!")
    ret = list(np.split(np_tensor, strategy[0], axis))
    if depth == 1:
        return ret
    for ret_ in ret:
        output.extend(
            _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))

    return output


def _chunk_tensor_by_strategy(np_tensor, strategy):
    """
    Split the input by strategy.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split strategy with the same size as np_tensor.

    Returns:
        List, the split sub-matrices.

    Raises:
        TypeError: If np_tensor is not an ndarray.
        ValueError: If the length of np_tensor does not match the length of strategy.
    """
    if not isinstance(np_tensor, np.ndarray):
        raise TypeError("np_tensor should be ndarray!")
    if len(strategy) != len(np_tensor.shape):
        raise ValueError("The length of np_tensor does not match the length of strategy!")
    return _chunk_tensor(np_tensor, strategy, len(strategy))
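
# Illustrative usage (a minimal sketch; values are assumed): splitting a 4x4 array with
# strategy [2, 2] yields four 2x2 blocks in row-major order.
#     >>> chunks = _chunk_tensor_by_strategy(np.arange(16).reshape(4, 4), [2, 2])
#     >>> len(chunks), chunks[0].shape
#     (4, (2, 2))
#     >>> chunks[0]
#     array([[0, 1],
#            [4, 5]])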


def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
    """
    Get the slice index for the current slice.

    Args:
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.
        opt_shard_group (str): The group of the optimizer shard.

    Returns:
        Integer, the slice index for the slice on this device.
    """
    rank = get_rank()
    dev_num = get_group_size()
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    if opt_shard_group:
        tensor_slice_index += dev_num
        opt_rank = get_rank(opt_shard_group)
        tensor_slice_index += opt_rank
    return tensor_slice_index


def _load_tensor(tensor, dev_mat, tensor_map, full_shape=None, rank_id=-1):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map.

    Args:
        tensor (Tensor): The tensor to be split.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.
        full_shape (list): The full shape to reshape the tensor to before slicing. Default: None.
        rank_id (int): The rank to slice for; -1 means the local rank. Default: -1.

    Returns:
        numpy.ndarray, the sliced array.

    Examples:
        >>> tensor = Tensor(np.ones([32, 32]))
        >>> dev_mat = [2, 4]
        >>> tensor_map = [1, -1]
        >>> full_shape = [32, 32]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape)
    """
    if rank_id == -1:
        rank = get_rank()
    else:
        rank = rank_id
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    np_tensor = tensor.asnumpy()
    if full_shape:
        np_tensor = np_tensor.reshape(full_shape)
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
    return np_tensor_slice


def _load_tensor_by_layout(tensor, layout, rank_id):
    """
    Load a tensor slice by layout.

    Args:
        tensor (Tensor): The input tensor.
        layout (tuple): The tensor layout in auto parallel.
        rank_id (int): The rank to slice for.

    Returns:
        Tensor, the sliced tensor.

    Raises:
        TypeError: If layout is not a tuple.
        ValueError: If the length of layout is less than 7.
    """
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 7:
        raise ValueError("The length of layout must be larger than 6! layout is {}".format(layout))
    dev_mat = layout[0]
    tensor_map = layout[1]
    slice_shape = layout[2]
    if not tensor_map:
        return tensor
    uniform_split = layout[4]
    group = layout[5]
    full_shape = layout[6]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only supports uniform split now")
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape, rank_id)
    if tensor_slice.shape != slice_shape and not group:
        tensor_slice = tensor_slice.reshape(slice_shape)
    if group:
        # get a fully sharded tensor slice for the parallel optimizer
        rank = get_rank(group)
        size = get_group_size(group)
        if tensor_slice.shape != tuple(slice_shape) and slice_shape:
            slice_shape_extend = copy.deepcopy(slice_shape)
            slice_shape_extend[0] = slice_shape[0] * size
            tensor_slice = tensor_slice.reshape(slice_shape_extend)
        tensor_slice = np.split(tensor_slice, size)[rank]
    return Tensor(tensor_slice, tensor.dtype)
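
# Illustrative sketch (values are assumed; layout[3] is not read by this function): the layout
# tuple consumed above is (dev_mat, tensor_map, slice_shape, ..., uniform_split, opt_shard_group,
# full_shape). With no optimizer-shard group, rank 5 of device matrix [2, 4] gets a 16x32 row block.
#     >>> layout = ([2, 4], [1, -1], [16, 32], 0, 1, "", [32, 32])
#     >>> _load_tensor_by_layout(Tensor(np.ones([32, 32])), layout, 5).shape
#     (16, 32)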


def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine param slices by the device matrix and the tensor map, used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, generated from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Tensor, the combined tensor with the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(param_data, dev_mat, tensor_map)
    """

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

    # get the actual number of slices, since different devices may load the same slice
    slice_count = 1
    for dim in tensor_strategy:
        slice_count *= dim

    # reorder slices and remove duplicates based on device matrix and tensor_map
    tensor_slices_new = list(range(slice_count))
    for i in range(device_count):
        slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)
        tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i])

    # combine slices to generate the complete parameter
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
        tensor_slices_new_inner = []
        for j in range(ele_count):
            new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
            for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                           (j + 1) * tensor_strategy[dim_len - 1 - i]):
                new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)

            tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
        tensor_slices_new = tensor_slices_new_inner

    return Tensor(tensor_slices_new[0])
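
# Illustrative round trip with _load_tensor (a minimal sketch; values are assumed): a 4x4
# parameter sharded over dev_mat [2, 2] with tensor_map [1, 0] gives four 2x2 slices;
# stacking them along axis 0 in rank order and feeding the result back recovers the full shape.
#     >>> full = Tensor(np.arange(16).reshape(4, 4).astype(np.float32))
#     >>> slices = [_load_tensor(full, [2, 2], [1, 0], rank_id=r) for r in range(4)]
#     >>> gathered = Tensor(np.concatenate(slices, axis=0))   # shape (8, 2)
#     >>> _reshape_param_data(gathered, [2, 2], [1, 0]).shape
#     (4, 4)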


def _extract_layout_item(layout_item):
    """Extract the device matrix, tensor map, optimizer shard step and size from a layout item."""
    dev_matrix = layout_item[0]
    tensor_map = layout_item[1]
    opt_shard_step = layout_item[4]
    opt_shard_size = layout_item[5]
    if opt_shard_size == -1:
        opt_shard_size = np.prod(dev_matrix) // opt_shard_step
    return dev_matrix, tensor_map, opt_shard_step, opt_shard_size


def _transform_tensor_by_layout(from_layout, to_layout, device_list, rank_id):
    """
    Transform a tensor from the source layout to the destination layout.

    Args:
        from_layout (tuple(tuple)): The source tensor layout.
        to_layout (tuple(tuple)): The destination tensor layout.
        device_list (tuple): The rank list over which the tensor is distributed.
        rank_id (int): The rank that holds the tensor slice.

    Returns:
        The transform operator list.
    """
    if not isinstance(from_layout, tuple) or not isinstance(to_layout, tuple):
        raise TypeError("The layout should be tuple! layout is {} and {}".format(from_layout, to_layout))
    return _tensor_transform.transform_tensor_sharding(from_layout, to_layout, device_list, rank_id)


def _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
                                     from_tensor_map, to_full_tensor_shape,
                                     to_dev_matrix, to_tensor_map):
    """Construct from_layout and to_layout with the same device number."""
    from_full_tensor_shape = list(from_full_tensor_shape)
    to_full_tensor_shape = list(to_full_tensor_shape)
    from_dev_matrix = list(from_dev_matrix)
    from_tensor_map = list(from_tensor_map)
    to_dev_matrix = list(to_dev_matrix)
    to_tensor_map = list(to_tensor_map)
    from_dev_prod = np.prod(from_dev_matrix)
    to_dev_prod = np.prod(to_dev_matrix)
    if len(from_full_tensor_shape) != len(from_tensor_map) or len(to_full_tensor_shape) != len(to_tensor_map):
        raise ValueError("The tensor map dimensions should be equal to tensor shape dimensions, "
                         "please check the strategy file.")
    if from_dev_prod > to_dev_prod:
        if from_dev_prod % to_dev_prod != 0:
            raise ValueError("Cannot transform device_num from {} to {}".format(from_dev_prod, to_dev_prod))
        repeat_dim_size = from_dev_prod // to_dev_prod
        to_dev_matrix.insert(0, repeat_dim_size)
    elif from_dev_prod < to_dev_prod:
        if to_dev_prod % from_dev_prod != 0:
            raise ValueError("Cannot transform device_num from {} to {}".format(from_dev_prod, to_dev_prod))
        repeat_dim_size = to_dev_prod // from_dev_prod
        from_dev_matrix.insert(0, repeat_dim_size)
    from_tensor_layout = (from_dev_matrix, from_tensor_map, from_full_tensor_shape)
    to_tensor_layout = (to_dev_matrix, to_tensor_map, to_full_tensor_shape)
    return from_tensor_layout, to_tensor_layout
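
# Illustrative usage (a minimal sketch; values are assumed): the source layout uses 8 devices
# and the destination uses 4, so a repeat dimension of size 2 is prepended to the destination
# device matrix to align the device numbers.
#     >>> _, to_layout = _construct_from_to_tensor_layout([32, 32], [4, 2], [1, -1],
#     ...                                                 [32, 32], [2, 2], [1, 0])
#     >>> list(map(int, to_layout[0]))
#     [2, 2, 2]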


def _construct_tensor_layout_for_opt_shard(dev_matrix, tensor_map, opt_shard_step, opt_shard_size,
                                           origin_full_tensor_shape):
    """
    Construct the tensor layout for optimizer sharding. For example:

    dev_mat = [4, 2, 2], tensor_map = [2, 1, 0], opt_size = 2
    =>
    dev_mat = [opt_size, 4, 2, 2] = [2, 4, 2, 2], tensor_map = [2, 3, 1, 0],
    thus new_strategy = [4, 2, 2, 2].

    The tensor shape is reshaped to (model_parallel_size, -1, xx, xx): the leading 4 is the
    model parallel sharding of the data dimension and the following 2 is its optimizer sharding.
    Since the model parallel sharding dim is to the right of the optimizer sharding dim, the data
    dimension is model-parallel sharded across ranks 0-1-2-3 and then optimizer sharded across
    ranks 0-4.
    """

    if opt_shard_step == 0 or opt_shard_size == 0:
        return dev_matrix, tensor_map, list(origin_full_tensor_shape)
    tensor_strategy = _get_tensor_strategy(dev_matrix, tensor_map)
    model_parallel_shard_size = np.prod(tensor_strategy)
    if model_parallel_shard_size != opt_shard_step:
        raise ValueError("The optimizer sharding step {} is not equal to the model parallel sharding size {}.".
                         format(opt_shard_step, model_parallel_shard_size))

    first_dim_no_sharding_size = origin_full_tensor_shape[0] // tensor_strategy[0]
    full_tensor_shape = list(origin_full_tensor_shape)
    full_tensor_shape[0] = tensor_strategy[0]
    full_tensor_shape.insert(1, first_dim_no_sharding_size)
    new_dev_matrix = tensor_strategy
    repeat_dim = np.prod(dev_matrix) // (opt_shard_step * opt_shard_size)

    new_tensor_map = []
    for idx, val in enumerate(tensor_strategy):
        if val == 1:
            new_tensor_map.append(-1)
        else:
            new_tensor_map.append(len(tensor_strategy) - 1 - idx)
    new_tensor_map.insert(1, len(tensor_strategy))
    new_dev_matrix.insert(0, opt_shard_size)
    if repeat_dim > 1:
        new_dev_matrix.insert(0, repeat_dim)
    return new_dev_matrix, new_tensor_map, full_tensor_shape
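
# Illustrative usage (a minimal sketch; values are assumed): two model-parallel dims of size 2
# (so opt_shard_step is 4) plus an optimizer shard of size 2 over an [8, 8] parameter.
#     >>> _construct_tensor_layout_for_opt_shard([2, 2, 2], [1, 0], 4, 2, [8, 8])
#     ([2, 2, 2], [1, 2, 0], [2, 4, 8])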


def _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, self_rank):
    """
    Get the sorted list of ranks whose slices are needed to transform the tensor for self_rank.

    AllGather op: {op_name, group_ranks + axis}
    """
    result_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, device_list,
                                                                    self_rank)
    result_list = list(result_map.keys())
    result_list.sort()
    return result_list


def _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, device_list, self_rank):
    """
    Get the map from each needed rank to its transform operator list.

    AllGather op: {op_name, group_ranks + axis}
    """
    stack = []
    index = 0
    transform_operators = _transform_tensor_by_layout(from_tensor_layout, to_tensor_layout, device_list, self_rank)
    result_map = {self_rank: transform_operators}
    for operators in transform_operators:
        op_name = operators[0]
        if op_name == "AllGather":
            groups = operators[1][:-1]
            stack.append((index, groups))
            index += 1
    while stack:
        group_info = stack.pop()
        for rank in group_info[1]:
            if rank not in result_map:
                new_transform_operators = _transform_tensor_by_layout(from_tensor_layout, to_tensor_layout,
                                                                      device_list, rank)
                result_map[rank] = new_transform_operators
                index = 0
                for operators in new_transform_operators:
                    op_name = operators[0]
                    if op_name == "AllGather" and index < group_info[0]:
                        groups = operators[1][:-1]
                        stack.insert(0, (index, groups))
                        index += 1
    return result_map


def _generate_transform_operator_stack(transform_operators_map, self_rank):
    """
    Generate the transform operator stack; each entry is a (rank_id, index, operator) tuple.
    """
    if self_rank not in transform_operators_map:
        raise ValueError("The transform operators of rank id {} are required.".format(self_rank))
    if not transform_operators_map[self_rank]:
        return []
    init_level = len(transform_operators_map[self_rank]) - 1
    handle_queue = [(self_rank, init_level, transform_operators_map[self_rank][init_level])]
    result_queue = []
    while handle_queue:
        queue_front = handle_queue.pop(0)
        result_queue.append(queue_front)
        current_rank_id = queue_front[0]
        level = queue_front[1]
        current_operator = queue_front[2]
        if level >= 1:
            if current_operator[0] == "AllGather":
                current_group = current_operator[1][:-1]
                for rank_id in current_group:
                    handle_queue.append((rank_id, level - 1, transform_operators_map[rank_id][level - 1]))
            else:
                handle_queue.append((current_rank_id, level - 1, transform_operators_map[current_rank_id][level - 1]))
    return result_queue


def _apply_tensor_transform_operators(transform_operator_stack, tensor_dict, device_num):
    """
    Apply transform operators; transform_operator_stack is a list of (rank_id, index, operator) tuples.
    """
    if not transform_operator_stack:
        return
    level = transform_operator_stack[-1][1]
    level_operators = []
    while True:
        if not transform_operator_stack or (level != transform_operator_stack[-1][1]):
            tmp_tensor_dict = {}
            if not level_operators:
                continue
            op_name = level_operators[0][2][0]
            for operator_pair in level_operators:
                rank_id = operator_pair[0]
                if rank_id % device_num not in tensor_dict:
                    raise ValueError("The checkpoint file of rank {} is missing.".format(rank_id % device_num))
                cur_level = operator_pair[1]
                operator = operator_pair[2]
                if operator[0] != op_name:
                    raise ValueError("The operator in the same level should be equal in the transform tensor operator "
                                     "list, but found {} and {} in level {}".format(op_name, operator[0], cur_level))
                if operator[0] != "AllGather":
                    tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(tensor_dict[rank_id % device_num],
                                                                                     operator)
                    continue
                for rank in operator[1][:-1]:
                    if rank % device_num not in tensor_dict:
                        raise ValueError("The checkpoint file of rank {} is missing.".format(rank % device_num))
                allgather_list = [tensor_dict[rank % device_num] for rank in operator[1][:-1]]
                tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
            if op_name == "AllGather":
                for rank, value in tmp_tensor_dict.items():
                    tensor_dict[rank % device_num] = value
            level_operators.clear()
        if not transform_operator_stack:
            break
        operator_pair = transform_operator_stack.pop()
        level = operator_pair[1]
        level_operators.append(operator_pair)


def _check_operator(operator):
    """Check that operator is a (name, params) tuple whose second item is a list."""
    if not isinstance(operator, tuple):
        raise TypeError("The operator should be a tuple.")
    if len(operator) != 2:
        raise TypeError("The operator should contain 2 items.")
    if not isinstance(operator[1], list):
        raise TypeError("The operator[1] should be a list.")


def _apply_operator(operator_name):
    """apply transform operator"""

    def _apply_reshape_operator(numpy_data, reshape_op):
        """
        Apply the reshape operator.

        Args:
            numpy_data (numpy.ndarray): The data of the tensor to apply the operator to.
            reshape_op (tuple): Reshape operator information; the second item is the destination shape.
        Returns:
            The data of the tensor after applying the operator.
        """
        if not isinstance(numpy_data, np.ndarray):
            raise TypeError("The data should be a numpy.ndarray.")
        _check_operator(reshape_op)
        return np.reshape(numpy_data, reshape_op[1])

    def _apply_allconcat_operator(numpy_data_list, allgather_op):
        """
        Apply the allconcat operator.

        Args:
            numpy_data_list (list): The list of numpy.ndarray data to apply the operator to.
            allgather_op (tuple): Allgather operator information; the second item is the allgather info,
              which contains the group ranks and the concat axis.
        Returns:
            The data of the tensor after applying the operator.
        """
        if not isinstance(numpy_data_list, list):
            raise TypeError("The data_list should be a list.")
        for numpy_data in numpy_data_list:
            if not isinstance(numpy_data, np.ndarray):
                raise TypeError("The data should be a numpy.ndarray.")
        _check_operator(allgather_op)
        concat_group = allgather_op[1][:-1]
        if len(concat_group) != len(numpy_data_list):
            raise ValueError("The length of data_list {} should be equal to concat_group size {}".
                             format(len(numpy_data_list), len(concat_group)))
        concat_axis = allgather_op[1][-1]
        return np.concatenate(numpy_data_list, concat_axis)

    def _apply_slice_operator(numpy_data, slice_op):
        """
        Apply the slice operator.

        Args:
            numpy_data (numpy.ndarray): The data of the tensor to apply the operator to.
            slice_op (tuple): Slice operator information; the second item is the slice information.
        Returns:
            The data of the tensor after applying the operator.
        """
        if not isinstance(numpy_data, np.ndarray):
            raise TypeError("The data should be a numpy.ndarray.")
        _check_operator(slice_op)
        if len(slice_op[1]) % 3 != 0:
            raise ValueError("The slice operator information is wrong.")
        shape_size = len(slice_op[1]) // 3
        begin = slice_op[1][:shape_size]
        end = slice_op[1][shape_size:shape_size * 2]
        stride = slice_op[1][shape_size * 2:]
        slice_index = []
        for begin_i, end_i, strides_i in zip(begin, end, stride):
            s = slice(begin_i, end_i, strides_i)
            slice_index.append(s)
        slice_index = tuple(slice_index)
        return numpy_data[slice_index]

    _apply_operator_map = {"Reshape": _apply_reshape_operator, "StridedSlice": _apply_slice_operator,
                           "AllGather": _apply_allconcat_operator}
    return _apply_operator_map.get(operator_name)
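
# Illustrative usage (a minimal sketch; operator tuples are assumed to follow the (name, params)
# shape checked by _check_operator): Reshape takes the target shape, StridedSlice takes begin/end/
# stride flattened into one list, and AllGather takes the group ranks plus the concat axis last.
#     >>> _apply_operator("Reshape")(np.arange(8), ("Reshape", [2, 4])).shape
#     (2, 4)
#     >>> _apply_operator("StridedSlice")(np.arange(16).reshape(4, 4), ("StridedSlice", [0, 0, 2, 2, 1, 1]))
#     array([[0, 1],
#            [4, 5]])
#     >>> halves = [np.zeros((2, 4)), np.ones((2, 4))]
#     >>> _apply_operator("AllGather")(halves, ("AllGather", [0, 1, 0])).shape
#     (4, 4)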


def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine param slices by the device matrix, used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearranged,
            generated from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        field_size: The field size used to reshape each slice.

    Returns:
        Tensor, the combined tensor with the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = [39]
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_slices_col = []
    for i in range(len(tensor_slices[0][0])):
        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
        for j in range(1, device_count):
            tensor_slices_new = np.concatenate((tensor_slices_new, \
                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
        tensor_slices_col.append(tensor_slices_new)
    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
    for i in range(1, len(tensor_slices_col)):
        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
    return Tensor(new_tensor)