# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""load tensor and combine tensor"""
import numpy as np

from mindspore.common.tensor import Tensor
from ..communication.management import get_rank, get_group_size


def _get_tensor_strategy(dev_mat, tensor_map):
    """
    Get the split strategy by device arrangement and tensor map.

    Args:
        dev_mat (list): The device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the split strategy, with one element for each dimension of the tensor.
    """
    tensor_strategy = []
    for dim in tensor_map:
        if dim == -1:
            tensor_strategy.append(1)
        else:
            tensor_strategy.append(dev_mat[-dim - 1])
    return tensor_strategy


def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
    """
    Get the tensor slice index for the local device.

    Args:
        device_arrangement (list): The device matrix.
        tensor_strategy (list): The split strategy, with one element for each dimension of the tensor.
        tensor_map (list): The map relation between tensor and devices.
        rank_index (int): The rank of the local device.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    device_coordinate = _rank_to_coordinate(rank_index, device_arrangement)
    device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map)
    tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy)
    return tensor_slice_index


def _rank_to_coordinate(rank_index, device_arrangement):
    """
    Convert rank index to device coordinate.

    Args:
        rank_index (int): The index of the local device.
        device_arrangement (list): The device matrix.

    Returns:
        List, the coordinate of the local device in the device matrix.
    """
    dim_len = len(device_arrangement)
    device_coordinate = np.zeros(dim_len)
    for i in range(dim_len):
        size = device_arrangement[dim_len - 1 - i]
        device_coordinate[dim_len - 1 - i] = rank_index % size
        rank_index = int(rank_index / size)
    return device_coordinate


def _coordinate_to_rank(device_coordinate, device_arrangement):
    """
    Convert device coordinate to rank index.

    Args:
        device_coordinate (list): The coordinate of the local device in the device matrix.
        device_arrangement (list): The device matrix.

    Returns:
        Integer, the index of the local device for tensor slices.
    """
    rank_index = 0
    size = 1
    for i in range(len(device_coordinate)):
        rank_index += size * device_coordinate[len(device_coordinate) - 1 - i]
        size *= device_arrangement[len(device_coordinate) - 1 - i]
    return rank_index
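

# A minimal illustration (hypothetical helper, assuming a 2 x 4 device matrix): it only
# shows that _rank_to_coordinate and _coordinate_to_rank invert each other, e.g. rank 5
# maps to coordinate [1, 1] because 5 = 1 * 4 + 1, and converting back yields rank 5.
def _example_rank_coordinate_roundtrip():
    """Illustrative sketch: round-trip every rank of a small device matrix."""
    dev_mat = [2, 4]
    for rank in range(2 * 4):
        coordinate = _rank_to_coordinate(rank, dev_mat)
        assert int(_coordinate_to_rank(coordinate, dev_mat)) == rank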


def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
    """
    Convert device_coordinate according to the tensor map.

    Args:
        device_coordinate (list): The coordinate of the local device in the device matrix.
        tensor_map (list): The map relation between tensor and devices.

    Returns:
        List, the converted coordinate.
    """
    device_coordinate_new = []
    for i in range(len(tensor_map)):
        if tensor_map[len(tensor_map) - 1 - i] != -1:
            device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 -
                                                              tensor_map[len(tensor_map) - 1 - i]])
        else:
            device_coordinate_new.insert(0, 0)
    return device_coordinate_new


def _chunk_tensor(np_tensor, strategy, depth):
    """
    Recursive function to chunk the tensor.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split strategy, with one element for each dimension of np_tensor.
        depth (int): Recursion depth.

    Returns:
        List, the split sub-matrices.

    Raises:
        ValueError: If np_tensor can not be split by strategy.
    """
    output = []
    axis = len(np_tensor.shape) - depth
    if np_tensor.shape[axis] % strategy[0] != 0:
        raise ValueError("np_tensor can not be split by strategy!")
    ret = list(np.split(np_tensor, strategy[0], axis))
    if depth == 1:
        return ret
    for ret_ in ret:
        output.extend(
            _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))

    return output


def _chunk_tensor_by_strategy(np_tensor, strategy):
    """
    Split the input by strategy.

    Args:
        np_tensor (NDarray): The matrix to be split.
        strategy (list): The split strategy, with one element for each dimension of np_tensor.

    Returns:
        List, the split sub-matrices.

    Raises:
        TypeError: If np_tensor is not an ndarray.
        ValueError: If the length of strategy does not match the number of dimensions of np_tensor.
    """
    if not isinstance(np_tensor, np.ndarray):
        raise TypeError("np_tensor should be ndarray!")
    if len(strategy) != len(np_tensor.shape):
        raise ValueError("The length of np_tensor does not match the length of strategy!")
    return _chunk_tensor(np_tensor, strategy, len(strategy))


def _get_slice_index(dev_mat, tensor_map):
    """
    Get the slice index for the current slice.

    Args:
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of the tensor.

    Returns:
        Integer, the slice index for the slice on this device.
    """
    rank = get_rank()
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    return tensor_slice_index


def _load_tensor(tensor, dev_mat, tensor_map):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map.

    Args:
        tensor (Tensor): The tensor to be split.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of the tensor.

    Returns:
        numpy.array, the sliced array.

    Examples:
        >>> tensor = Tensor(np.ones([32, 32]))
        >>> dev_mat = [2, 4]
        >>> tensor_map = [1, -1]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    """
    rank = get_rank()
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    np_tensor = tensor.asnumpy()
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
    return np_tensor_slice
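

# A minimal sketch of the slicing pipeline, assuming a hypothetical explicit rank
# argument instead of get_rank(), so it can run without a communication backend.
# With dev_mat [2, 4] and tensor_map [1, -1], rows are split into 2 blocks and
# columns are not split, so ranks 0-3 load the first row block and ranks 4-7 the
# second; for a 32 x 32 input each slice has shape (16, 32).
def _example_slice_without_backend(np_tensor, dev_mat, tensor_map, rank):
    """Illustrative sketch mirroring _load_tensor, but taking the rank explicitly."""
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    return np_tensor_list[int(tensor_slice_index)]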


def _load_tensor_by_layout(tensor, layout):
    """
    Load tensor by layout.

    Args:
        tensor (Tensor): The input tensor.
        layout (tuple): The tensor layout in auto parallel.

    Returns:
        Tensor, the sliced tensor.

    Raises:
        TypeError: If layout is not a tuple.
        ValueError: If the length of layout is less than 6.
    """
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 6:
        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
    dev_mat = layout[0]
    tensor_map = layout[1]
    uniform_split = layout[4]
    group = layout[5]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only supports uniform split now")
    if tensor.size == 1:
        return tensor
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    if group:
        # get a fully sharded tensor slice for the parallel optimizer
        rank = get_rank(group)
        size = get_group_size(group)
        tensor_slice = np.split(tensor_slice, size)[rank]
    return Tensor(tensor_slice)


def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine param slices by the device matrix and the tensor map, used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, gathered from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of the tensor.

    Returns:
        Tensor, the combined tensor holding the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(param_data, dev_mat, tensor_map)
    """

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

    # get the actual number of slices, as different devices may load the same slice
    slice_count = 1
    for dim in tensor_strategy:
        slice_count *= dim

    # reorder slices and remove duplicates based on the device matrix and tensor_map
    tensor_slices_new = list(range(slice_count))
    for i in range(device_count):
        slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)
        tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i])

    # combine slices to generate the complete parameter
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
        tensor_slices_new_inner = []
        for j in range(ele_count):
            new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
            for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                           (j + 1) * tensor_strategy[dim_len - 1 - i]):
                new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)

            tensor_slices_new_inner.append(np.array(new_tensor))
        tensor_slices_new = tensor_slices_new_inner

    return Tensor(tensor_slices_new[0])
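

# A minimal round-trip sketch, assuming a hypothetical 4-device setup with
# dev_mat [2, 2] and tensor_map [1, 0]: chunk a parameter the way each rank would
# load it, stack the per-rank slices along axis 0 as an AllGather would, and check
# that _reshape_param_data recovers the original value.
def _example_reshape_param_data_roundtrip():
    """Illustrative sketch: slice, gather and recombine a small parameter."""
    dev_mat, tensor_map = [2, 2], [1, 0]
    np_tensor = np.arange(16).reshape(4, 4).astype(np.float32)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    slices = []
    for rank in range(4):
        index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
        slices.append(np_tensor_list[int(index)])
    gathered = Tensor(np.concatenate(slices, axis=0))
    restored = _reshape_param_data(gathered, dev_mat, tensor_map)
    assert np.array_equal(restored.asnumpy(), np_tensor)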


def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine param slices by the device matrix, used in the model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearranged,
        gathered from all the devices by AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        field_size (int): The field size; each column of a device slice is reshaped to (field_size, -1).

    Returns:
        Tensor, the combined tensor holding the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = 39
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_slices_col = []
    for i in range(len(tensor_slices[0][0])):
        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
        for j in range(1, device_count):
            tensor_slices_new = np.concatenate((tensor_slices_new,
                                                np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
        tensor_slices_col.append(tensor_slices_new)
    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
    for i in range(1, len(tensor_slices_col)):
        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
    return Tensor(new_tensor)
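

# A minimal shape sketch, assuming a hypothetical 2-device setup (dev_mat [2]) and
# field_size 3: each device contributes a (6, 2) block of the (12, 2) gathered
# parameter, each column of a block is viewed as a (3, -1) field block, and the
# per-device views are laid side by side, so the combined result is again (12, 2)
# with rows regrouped field by field across devices.
def _example_reshape_with_weight_shapes():
    """Illustrative sketch: check the output shape of _reshape_param_data_with_weight."""
    dev_mat, field_size = [2], 3
    gathered = Tensor(np.arange(24).reshape(12, 2).astype(np.float32))
    combined = _reshape_param_data_with_weight(gathered, dev_mat, field_size)
    assert combined.asnumpy().shape == (12, 2)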