# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Load tensor and combine tensor."""
import numpy as np

from mindspore.common.tensor import Tensor
from ..communication.management import get_rank, get_group_size


def _get_tensor_strategy(dev_mat, tensor_map):
    """
    Get the split strategy by device arrangement and tensor map.

    Args:
        dev_mat (list): The device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the split strategy with one element per tensor dimension.
    """
    tensor_strategy = []
    for dim in tensor_map:
        if dim == -1:
            tensor_strategy.append(1)
        else:
            tensor_strategy.append(dev_mat[-dim - 1])
    return tensor_strategy
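# Worked example for _get_tensor_strategy (illustrative values, not from the original
# module): with dev_mat = [2, 4], a tensor_map value of 1 refers to the dev_mat
# dimension of size 2 (dimensions are counted from the right), and -1 means "not split":
#
#     _get_tensor_strategy([2, 4], [1, -1])   # -> [2, 1]
#     _get_tensor_strategy([2, 4], [0, -1])   # -> [4, 1]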


def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
    """
    Get the tensor slice index for the local device.

    Args:
        device_arrangement (list): The device matrix.
        tensor_strategy (list): The split strategy with one element per tensor dimension.
        tensor_map (list): The map relation between tensor and devices.
        rank_index (int): The rank of the local device.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    device_coordinate = _rank_to_coordinate(rank_index, device_arrangement)
    device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map)
    tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy)
    return tensor_slice_index
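# Worked example for _get_tensor_slice_index (illustrative): with
# device_arrangement = [2, 4], tensor_strategy = [2, 1] and tensor_map = [1, -1],
# ranks 0..3 map to slice 0 and ranks 4..7 map to slice 1 (the result is a float
# because the coordinate comes from a numpy array, hence the int() casts at the
# call sites):
#
#     _get_tensor_slice_index([2, 4], [2, 1], [1, -1], 2)   # -> 0.0
#     _get_tensor_slice_index([2, 4], [2, 1], [1, -1], 5)   # -> 1.0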


def _rank_to_coordinate(rank_index, device_arrangement):
    """
    Convert a rank index to a device coordinate.

    Args:
        rank_index (int): The index of the local device.
        device_arrangement (list): The device matrix.

    Returns:
        numpy.ndarray, the coordinate of the local device in the device matrix.
    """
    dim_len = len(device_arrangement)
    device_coordinate = np.zeros(dim_len)
    for i in range(dim_len):
        size = device_arrangement[dim_len - 1 - i]
        device_coordinate[dim_len - 1 - i] = rank_index % size
        rank_index = int(rank_index / size)
    return device_coordinate
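# Worked example for _rank_to_coordinate (illustrative): the rank is decomposed in
# the mixed radix given by the device matrix, rightmost dimension varying fastest:
#
#     _rank_to_coordinate(5, [2, 4])   # -> array([1., 1.])
#     _rank_to_coordinate(6, [2, 4])   # -> array([1., 2.])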


def _coordinate_to_rank(device_coordinate, device_arrangement):
    """
    Convert a device coordinate to a rank index.

    Args:
        device_coordinate (list): The coordinate of the local device in the device matrix.
        device_arrangement (list): The device matrix.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    rank_index = 0
    size = 1
    for i in range(len(device_coordinate)):
        rank_index += size * device_coordinate[len(device_coordinate) - 1 - i]
        size *= device_arrangement[len(device_coordinate) - 1 - i]
    return rank_index
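# Worked example for _coordinate_to_rank (illustrative): this is the inverse of
# _rank_to_coordinate, accumulating the coordinate in mixed radix:
#
#     _coordinate_to_rank([1, 1], [2, 4])   # -> 5
#     _coordinate_to_rank([0, 3], [2, 4])   # -> 3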


def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
    """
    Convert device_coordinate according to the tensor map.

    Args:
        device_coordinate (list): The coordinate of the local device in the device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the converted coordinate.
    """
    device_coordinate_new = []
    for i in range(len(tensor_map)):
        if tensor_map[len(tensor_map) - 1 - i] != -1:
            device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 -
                                                              tensor_map[len(tensor_map) - 1 - i]])
        else:
            device_coordinate_new.insert(0, 0)
    return device_coordinate_new
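# Worked example for _convert_to_new_device_coordinate (illustrative): with
# tensor_map = [1, -1], only dev_mat dimension 1 matters; a device at coordinate
# [1, 3] in a [2, 4] device matrix therefore maps to coordinate [1, 0] in the
# slice grid:
#
#     _convert_to_new_device_coordinate([1, 3], [1, -1])   # -> [1, 0]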


def _chunk_tensor(np_tensor, strategy, depth):
    """
    Recursive function to chunk a tensor.

    Args:
        np_tensor (numpy.ndarray): The matrix to be split.
        strategy (list): The split strategy for the remaining dimensions.
        depth (int): Recursion depth.

    Returns:
        List, the split blocks of the matrix.

    Raises:
        ValueError: If np_tensor cannot be split evenly by strategy.
    """
    output = []
    axis = len(np_tensor.shape) - depth
    if np_tensor.shape[axis] % strategy[0] != 0:
        raise ValueError("np_tensor can not be split by strategy!")
    ret = list(np.split(np_tensor, strategy[0], axis))
    if depth == 1:
        return ret
    for ret_ in ret:
        output.extend(
            _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))

    return output
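# Worked example for _chunk_tensor (illustrative): a (4, 6) matrix with strategy
# [2, 3] is first split into 2 blocks along axis 0, then each block into 3 along
# axis 1, giving 6 blocks of shape (2, 2) in row-major order of the slice grid:
#
#     _chunk_tensor(np.arange(24).reshape(4, 6), [2, 3], 2)
#     # -> list of 6 arrays, each of shape (2, 2)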


def _chunk_tensor_by_strategy(np_tensor, strategy):
    """
    Split the input by strategy.

    Args:
        np_tensor (numpy.ndarray): The matrix to be split.
        strategy (list): The split strategy with one element per tensor dimension.

    Returns:
        List, the split blocks of the matrix.

    Raises:
        TypeError: If np_tensor is not a numpy.ndarray.
        ValueError: If the number of dimensions of np_tensor does not match the length of strategy.
    """
    if not isinstance(np_tensor, np.ndarray):
        raise TypeError("np_tensor should be ndarray!")
    if len(strategy) != len(np_tensor.shape):
        raise ValueError("The number of dimensions of np_tensor does not match the length of strategy!")
    return _chunk_tensor(np_tensor, strategy, len(strategy))
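# Worked example for _chunk_tensor_by_strategy (illustrative): the slice index of a
# block is its row-major position in the slice grid, so with strategy [2, 3] on a
# (4, 6) matrix, block 0 holds rows 0:2 / cols 0:2 and block 5 holds rows 2:4 / cols 4:6:
#
#     slices = _chunk_tensor_by_strategy(np.arange(24).reshape(4, 6), [2, 3])
#     len(slices)        # -> 6
#     slices[0].shape    # -> (2, 2)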


def _get_slice_index(dev_mat, tensor_map):
    """
    Get the slice index for the current slice.

    Args:
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Integer, the slice index for the slice on this device.
    """
    rank = get_rank()
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    return tensor_slice_index


def _load_tensor(tensor, dev_mat, tensor_map):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map.

    Args:
        tensor (Tensor): The tensor to be split.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        numpy.ndarray, the sliced array.

    Examples:
        >>> tensor = Tensor(np.ones([32, 32]))
        >>> dev_mat = [2, 4]
        >>> tensor_map = [1, -1]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    """
    rank = get_rank()
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    np_tensor = tensor.asnumpy()
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
    return np_tensor_slice
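# Note on the docstring example above (illustrative; get_rank() requires the
# communication framework to be initialized): with dev_mat = [2, 4] and
# tensor_map = [1, -1], the (32, 32) tensor is split into two (16, 32) row blocks;
# ranks 0..3 receive rows 0:16 and ranks 4..7 receive rows 16:32.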


def _load_tensor_by_layout(tensor, layout):
    """
    Load tensor by layout.

    Args:
        tensor (Tensor): The input tensor.
        layout (tuple): The tensor layout in auto parallel.

    Returns:
        Tensor, the sliced tensor.

    Raises:
        TypeError: If layout is not a tuple.
        ValueError: If the length of layout is less than 6.
        RuntimeError: If the tensor is not split uniformly.
    """
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 6:
        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
    dev_mat = layout[0]
    tensor_map = layout[1]
    uniform_split = layout[4]
    group = layout[5]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only supports uniform split now")
    if tensor.size == 1:
        return tensor
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    if group:
        # get a fully sharded tensor slice for the parallel optimizer
        rank = get_rank(group)
        size = get_group_size(group)
        tensor_slice = np.split(tensor_slice, size)[rank]
    return Tensor(tensor_slice)
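# Hedged usage sketch for _load_tensor_by_layout (illustrative; the layout tuple
# normally comes from the auto parallel framework, and indices 2 and 3 are not used
# by this function, so they are shown as None placeholders):
#
#     layout = ([2, 4], [1, -1], None, None, 1, "")
#     sliced = _load_tensor_by_layout(Tensor(np.ones([32, 32])), layout)
#     # on each of the 8 devices, `sliced` is a (16, 32) Tensor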


def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine parameter slices by the device matrix and the tensor map; used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, gathered from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Tensor, the combined tensor holding the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(param_data, dev_mat, tensor_map)
    """

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

    # get the actual number of slices, as different devices may load the same slice
    slice_count = 1
    for dim in tensor_strategy:
        slice_count *= dim

    # reorder slices and remove duplicates based on the device matrix and tensor_map
    tensor_slices_new = list(range(slice_count))
    for i in range(device_count):
        slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)
        tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i])

    # combine slices to generate the complete parameter
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
        tensor_slices_new_inner = []
        for j in range(ele_count):
            new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
            for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                           (j + 1) * tensor_strategy[dim_len - 1 - i]):
                new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)

            tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
        tensor_slices_new = tensor_slices_new_inner

    return Tensor(tensor_slices_new[0])
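# Worked example for _reshape_param_data (illustrative): with dev_mat = [2, 2] and
# tensor_map = [1, 0], each of the 4 devices holds a (16, 16) slice of a (32, 32)
# parameter, so the gathered param_data has shape (64, 16); the function reorders
# the 4 slices on the 2 x 2 slice grid and concatenates them back into a
# (32, 32) Tensor.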


def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine parameter slices by the device matrix; used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearranged,
            gathered from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        field_size (int): The field size used to regroup each column.

    Returns:
        Tensor, the combined tensor holding the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = 39
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_slices_col = []
    for i in range(len(tensor_slices[0][0])):
        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
        for j in range(1, device_count):
            tensor_slices_new = np.concatenate((tensor_slices_new,
                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
        tensor_slices_col.append(tensor_slices_new)
    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
    for i in range(1, len(tensor_slices_col)):
        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
    return Tensor(new_tensor)