1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Utils of auto parallel""" 16 17import numpy as np 18from mindspore import context, log as logger 19from mindspore.context import ParallelMode 20from mindspore._c_expression import reset_op_id 21from mindspore.common.tensor import Tensor 22from mindspore.common.dtype import dtype_to_nptype 23from mindspore.common import dtype as mstype 24from mindspore.communication.management import get_group_size, get_rank 25from mindspore.parallel._auto_parallel_context import auto_parallel_context 26from mindspore.common.seed import get_seed 27 28 29def _get_parallel_mode(): 30 """Get parallel mode.""" 31 return auto_parallel_context().get_parallel_mode() 32 33 34def _is_in_auto_parallel_mode(): 35 return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL] 36 37 38def _get_full_batch(): 39 """Get whether to use full_batch.""" 40 return auto_parallel_context().get_full_batch() 41 42 43def _get_pipeline_stages(): 44 """Get pipeline stages""" 45 return auto_parallel_context().get_pipeline_stages() 46 47 48def _check_task_sink_envs(): 49 """ 50 Check whether task_sink environment variables have been exported or not. 51 52 return True if task_sink environment variables have been exported, False otherwise. 53 """ 54 import os 55 task_sink = os.getenv("GRAPH_OP_RUN") 56 if task_sink and task_sink.isdigit() and int(task_sink) == 1: 57 return False 58 return True 59 60 61def _check_full_batch(): 62 """ 63 full_batch could only be used under semi_auto_parallel or auto_parallel, check it. 64 65 Raises: 66 RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel. 67 """ 68 parallel_mode = _get_parallel_mode() 69 full_batch = _get_full_batch() 70 if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch): 71 raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.") 72 73 74def _need_to_full(): 75 """Check whether to convert input to full shape or tensor.""" 76 if _get_parallel_mode() not in ("semi_auto_parallel", "auto_parallel"): 77 return False 78 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 79 if dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"): 80 return True 81 return not _get_full_batch() 82 83 84def _to_full_shapes(shapes, device_num): 85 """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution.""" 86 new_shapes = [] 87 dataset_strategy = () 88 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 89 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 90 if dataset_strategy: 91 if len(shapes) != len(dataset_strategy): 92 raise ValueError("The input shapes size {} is not equal to " 93 "dataset strategy size {}".format(len(shapes), len(dataset_strategy))) 94 for index, shape in enumerate(shapes): 95 if len(shape) != len(dataset_strategy[index]): 96 raise ValueError("The input shapes item size {} is not equal to " 97 "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index]))) 98 new_shape = () 99 for i, item in enumerate(shape): 100 new_shape += (item * dataset_strategy[index][i],) 101 new_shapes.append(new_shape) 102 return new_shapes 103 for shape in shapes: 104 new_shape = () 105 for i, item in enumerate(shape): 106 if i == 0: 107 new_shape += (item * device_num,) 108 else: 109 new_shape += (item,) 110 new_shapes.append(new_shape) 111 return new_shapes 112 113 114def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None): 115 """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data 116 from host solution. 117 """ 118 lst = [] 119 device_num = global_device_num // _get_pipeline_stages() 120 stage_rank = global_rank % device_num 121 if not isinstance(elem, (tuple, list)): 122 elem = [elem] 123 if stage_rank >= device_num: 124 raise ValueError("The global rank must be smaller than device number, the global rank is {}, " 125 "the device num is {}".format(stage_rank, device_num)) 126 dataset_strategy = () 127 if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"): 128 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 129 if elem and dataset_strategy: 130 if len(elem) != len(dataset_strategy): 131 raise ValueError("The input size {} is not equal to " 132 "dataset strategy size {}".format(len(elem), len(dataset_strategy))) 133 for index, data in enumerate(elem): 134 if isinstance(data, np.ndarray): 135 data = Tensor(data) 136 if not isinstance(data, Tensor): 137 raise ValueError("elements in tensors must be Tensor") 138 shape_ = data.shape 139 type_ = data.dtype 140 new_shape = () 141 if not dataset_strategy: 142 batchsize_per_device = 1 143 for i, item in enumerate(shape_): 144 if i == 0: 145 new_shape += (item * device_num,) 146 batchsize_per_device = item 147 else: 148 new_shape += (item,) 149 new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_)) 150 start = stage_rank * batchsize_per_device 151 new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy() 152 else: 153 if len(shape_) != len(dataset_strategy[index]): 154 raise ValueError("The input shapes item size {} is not equal to " 155 "dataset strategy item size {}".format(len(shape_), len(dataset_strategy[index]))) 156 slice_index = () 157 for i, item in enumerate(shape_): 158 new_shape += (item * dataset_strategy[index][i],) 159 start = (stage_rank % dataset_strategy[index][i]) * item 160 end = (stage_rank % dataset_strategy[index][i] + 1) * item 161 s = slice(start, end, 1) 162 slice_index += (s,) 163 new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_)) 164 new_tensor_numpy[slice_index] = data.asnumpy() 165 new_tensor = Tensor(new_tensor_numpy) 166 lst.append(new_tensor) 167 if scaling_sens: 168 lst.append(Tensor(scaling_sens, mstype.float32)) 169 return tuple(lst) 170 171 172def _get_gradients_mean(): 173 """Get if using gradients_mean.""" 174 return auto_parallel_context().get_gradients_mean() 175 176 177def _get_device_num(): 178 """Get the device num.""" 179 parallel_mode = auto_parallel_context().get_parallel_mode() 180 if parallel_mode == "stand_alone": 181 device_num = 1 182 return device_num 183 184 if auto_parallel_context().get_device_num_is_set() is False: 185 device_num = get_group_size() 186 else: 187 device_num = auto_parallel_context().get_device_num() 188 return device_num 189 190 191def _get_global_rank(): 192 """Get the global rank.""" 193 parallel_mode = auto_parallel_context().get_parallel_mode() 194 if parallel_mode == "stand_alone": 195 global_rank = 0 196 return global_rank 197 198 if auto_parallel_context().get_global_rank_is_set() is False: 199 global_rank = get_rank() 200 else: 201 global_rank = auto_parallel_context().get_global_rank() 202 return global_rank 203 204 205def _get_parameter_broadcast(): 206 """Get the parameter broadcast.""" 207 parallel_mode = auto_parallel_context().get_parallel_mode() 208 parameter_broadcast = auto_parallel_context().get_parameter_broadcast() 209 210 if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None: 211 logger.warning("You are suggested to use mindspore.context.set_auto_parallel_context(parameter_broadcast=True)" 212 " or mindspore.common.set_seed() to share parameters among multi-devices.") 213 214 return parameter_broadcast 215 216 217def _get_enable_parallel_optimizer(): 218 """Get if using parallel optimizer.""" 219 return auto_parallel_context().get_enable_parallel_optimizer() 220 221 222def _device_number_check(parallel_mode, device_number): 223 """ 224 Check device num. 225 226 Args: 227 parallel_mode (str): The parallel mode. 228 device_number (int): The device number. 229 """ 230 if parallel_mode == "stand_alone" and device_number != 1: 231 raise ValueError("If parallel_mode is stand_alone, device_number must be 1, " 232 "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode)) 233 234 235def _parameter_broadcast_check(parallel_mode, parameter_broadcast): 236 """ 237 Check parameter broadcast. 238 239 Note: 240 If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Using the same 241 random seed to make sure parameters on multiple devices are the same. 242 243 Args: 244 parallel_mode (str): The parallel mode. 245 parameter_broadcast (bool): The parameter broadcast. 246 247 Raises: 248 ValueError: If parameter is broadcasted 249 but the parallel mode is "stand_alone" or "semi_auto_parallel" or "auto_parallel"). 250 """ 251 if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"): 252 raise ValueError("stand_alone, semi_auto_parallel and auto_parallel " 253 "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}" 254 .format(parallel_mode, parameter_broadcast)) 255 256 257def _get_python_op(op_name, op_path, instance_name, arglist): 258 """Get python operator.""" 259 module = __import__(op_path, fromlist=["None"]) 260 cls = getattr(module, op_name) 261 if op_path != "mindspore.ops.functional": 262 op = cls(*arglist) 263 else: 264 op = cls 265 op.set_prim_instance_name(instance_name) 266 return op 267 268 269def _reset_op_id(): 270 """Reset op id.""" 271 reset_op_id() 272 273 274def _parallel_predict_check(): 275 """validate parallel model prediction""" 276 if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): 277 dataset_strategy = context.get_auto_parallel_context("dataset_strategy") 278 is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch")) 279 if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp: 280 raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.') 281 282 283def _check_similar_layout(tensor_layout1, tensor_layout2): 284 """check if two tensor layouts are same""" 285 if tensor_layout1[1] != tensor_layout2[1]: 286 return False 287 for i in tensor_layout1[1]: 288 if i == -1: 289 continue 290 if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]: 291 return False 292 return True 293 294 295def _check_same_layout(tensor_layout1, tensor_layout2): 296 """check if two tensor layouts are same""" 297 return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1] 298 299 300def _remove_repeated_slices(tensor_layout): 301 """generate unrepeated tensor layout""" 302 import copy 303 new_tensor_layout = copy.deepcopy(tensor_layout) 304 dev_mat = tensor_layout[0][:] 305 tensor_map = tensor_layout[1] 306 for dim in range(len(dev_mat)): 307 if dim not in tensor_map: 308 dev_mat[-1-dim] = 1 309 new_tensor_layout[0] = dev_mat 310 return new_tensor_layout 311 312 313def _infer_rank_list(train_map, predict_map=None): 314 """infer checkpoint slices to be loaded""" 315 ret = {} 316 if _get_pipeline_stages() > 1: 317 local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages())) 318 else: 319 local_rank = _get_global_rank() 320 for param_name in train_map: 321 train_layout = train_map[param_name] 322 train_dev_mat = train_layout[0] 323 dev_num = np.array(train_dev_mat).prod() 324 new_train_layout = _remove_repeated_slices(train_layout) 325 array = np.arange(dev_num).reshape(train_dev_mat) 326 index = () 327 for i in new_train_layout[0]: 328 if i == 1: 329 index = index + (0,) 330 else: 331 index = index + (slice(None),) 332 rank_list = array[index].flatten() 333 if not predict_map: 334 ret[param_name] = (rank_list, False) 335 continue 336 if param_name not in predict_map: 337 logger.warning("predict_map does not contain %s", param_name) 338 continue 339 predict_layout = predict_map[param_name] 340 dev_num = np.array(predict_layout[0]).prod() 341 # optimization pass 342 if _check_same_layout(train_layout, predict_layout): 343 ret[param_name] = ([local_rank], True) 344 continue 345 if _check_similar_layout(train_layout, predict_layout): 346 if len(rank_list) == 1: 347 ret[param_name] = (rank_list, True) 348 elif len(rank_list) == dev_num: 349 ret[param_name] = ([rank_list[local_rank]], True) 350 else: 351 ret[param_name] = (rank_list, False) 352 else: 353 ret[param_name] = (rank_list, False) 354 return ret 355