# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor
from mindspore.ops import functional as F, composite as C
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor


reduce_opt = C.MultitypeFuncGraph("reduce_opt")


def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
    """initialize allreduce communication operators"""
    fusion_type = 2 ** 10
    split = 0
    fusion = ()
    for i in range(length):
        fusion = fusion + (fusion_type,)
        if split >= len(split_indices):
            continue
        if split_indices[split] <= i:
            fusion_type += 1
            split += 1
    index = tuple(range(1, length + 1))
    op_list = ()
    for i in range(length):
        op = AllReduce('sum', group)
        op.add_prim_attr('fusion', fusion[i])
        op.add_prim_attr('index', index[i])
        op_list = op_list + (op,)
    return op_list


def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fusion_type=1):
    """initialize allreduce communication operators by parameters"""
    op_list = ()
    param_fusion = False
    last_comm_fusion = None
    first_parameter_flag = True
    index = 1
    for parameter in parameters:
        comm_fusion = parameter.comm_fusion
        if first_parameter_flag:
            last_comm_fusion = comm_fusion
            first_parameter_flag = False
        elif not param_fusion:
            if comm_fusion != last_comm_fusion:
                param_fusion = True
            last_comm_fusion = comm_fusion
        op = AllReduce('sum', group)
        op.add_prim_attr('fusion', comm_fusion)
        op.add_prim_attr('index', index)
        index += 1
        op_list = op_list + (op,)
    if not param_fusion:
        if split_indices and fusion_type == 1:
            op_list = _init_allreduce_operators(len(parameters), split_indices, group)
            param_fusion = True
        else:
            op_list = ()
    return op_list, param_fusion
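

# Illustrative note (comments only; a hand-traced example of the loop in
# _init_allreduce_operators above, not output produced by MindSpore itself):
# with length=5 and split_indices=[2, 4], the fusion ids come out as
# (1024, 1024, 1024, 1025, 1025) and the indices as (1, 2, 3, 4, 5), i.e.
# gradients 0-2 are fused into one AllReduce group and gradients 3-4 into the next:
#
#     ops = _init_allreduce_operators(5, [2, 4])
#     # ops[0] carries prim attr 'fusion' == 1024, ops[3] carries 'fusion' == 1025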


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allreduce on gradient.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor")
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
    """
    Apply allreduce on gradient in PyNative mode.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allreduce_filter (bool): When it is true, the mean coefficient would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allreduce on gradient.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.
        ps_parameter (bool): Use parameter server or not.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad
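

# Illustrative note (comments only): for dense gradients the reduce_opt overloads above first
# AllReduce-sum the gradient across devices and then, when `mean` is True, multiply by `degree`,
# which DistributedGradReducer.__init__ below stores as the tensor 1.0 / device_count. For
# example, with 8 devices the summed gradient is scaled by 0.125, i.e. averaged:
#
#     grad = allreduce(grad)                                    # element-wise sum over 8 devices
#     grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))  # degree == Tensor(0.125, float32)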


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
        grad (RowTensor): The gradient (indices, values and dense_shape) before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensor(indices, dout, grad.dense_shape)
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
        grad (RowTensor): The gradient (indices, values and dense_shape) before operation.
        ps_parameter (bool): Use parameter server or not.

    Returns:
        RowTensor, the gradient after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensor(indices, dout, grad.dense_shape)
    return grad


_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (Tensor): The gradient tensor before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad)


@_get_datatype.register("RowTensor")
def _tensors_get_datatype_with_sparse(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (RowTensor): The gradient before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad.values)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): The destination datatype of gradient.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    return F.cast(grad, datatype)


@_cast_datatype.register("TypeType", "RowTensor")
def _tensors_cast_datatype_with_sparse(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): The destination datatype of gradient.
        grad (RowTensor): The gradient before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    dout = F.cast(grad.values, datatype)
    return RowTensor(grad.indices, dout, grad.dense_shape)
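

# Illustrative note (comments only): DistributedGradReducer.construct() below uses the two
# MultitypeFuncGraphs above to do a dtype round trip around the communication, so that mixed
# float16/float32 gradients are always reduced in float32, e.g. for a float16 gradient:
#
#     datatypes = self.map_(F.partial(_get_datatype), grads)                # remembers float16
#     grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)   # reduce in float32
#     ...                                                                   # AllReduce / AllGather
#     new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)  # cast back to float16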


class DistributedGradReducer(Cell):
    """
    A distributed optimizer.

    Constructs a gradient reducer Cell, which applies communication and average operations on
    single-process gradient values.

    Args:
        parameters (list): the parameters to be updated.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: True.
        degree (int): The mean coefficient. Usually it equals to device number. Default: None.
        fusion_type (int): The type of all reduce fusion. Default: 1.
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.

    Raises:
        ValueError: If degree is not an int or is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import ops
        >>> from mindspore import context
        >>> from mindspore.context import ParallelMode
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore import nn
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> context.reset_auto_parallel_context()
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
        >>>
        >>> class TrainingWrapper(nn.Cell):
        ...     def __init__(self, network, optimizer, sens=1.0):
        ...         super(TrainingWrapper, self).__init__(auto_prefix=False)
        ...         self.network = network
        ...         self.network.add_flags(defer_inline=True)
        ...         self.weights = optimizer.parameters
        ...         self.optimizer = optimizer
        ...         self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        ...         self.sens = sens
        ...         self.reducer_flag = False
        ...         self.grad_reducer = None
        ...         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        ...         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
        ...             self.reducer_flag = True
        ...         if self.reducer_flag:
        ...             mean = context.get_auto_parallel_context("gradients_mean")
        ...             degree = context.get_auto_parallel_context("device_num")
        ...             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        ...
        ...     def construct(self, *args):
        ...         weights = self.weights
        ...         loss = self.network(*args)
        ...         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        ...         grads = self.grad(self.network, weights)(*args, sens)
        ...         if self.reducer_flag:
        ...             # apply grad reducer on grads
        ...             grads = self.grad_reducer(grads)
        ...         return ops.depend(loss, self.optimizer(grads))
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>>
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> net_with_loss = nn.WithLossCell(network, loss)
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> train_cell = TrainingWrapper(net_with_loss, optimizer)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> grads = train_cell(inputs, label)
        >>> print(grads)
        256.0
    """

    def __init__(self, parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("Parameter 'degree' in DistributedGradReducer should be an int larger than 0.")
            self.degree = degree
        self.degree = Tensor(1.0 / self.degree, mstype.float32)
        self.mean = mean
        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
        if is_parallel_optimizer and split_indices:
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters), split_indices, group)
        else:
            self.split_fusion = True
            self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices, group,
                                                                                 fusion_type)
            if not param_fusion:
                self.split_fusion = False
        self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
        self.allgather = AllGather(group)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
        self.enable_parameter_server = any(self.ps_parameters)
        self.mode = context.get_context("mode")
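
    # Note (hand-written summary of __init__ above, not generated by MindSpore): when
    # self.split_fusion is True, construct() pairs each gradient with its own AllReduce
    # instance from self.op_list, so gradients are fused according to split_indices or the
    # per-parameter comm_fusion values; otherwise all gradients share the single
    # self.allreduce primitive whose 'fusion' attribute equals fusion_type. In PYNATIVE_MODE
    # the communication dispatch in construct() is skipped.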

    def construct(self, grads):
        """
        Apply gradient reduction. In some cases the gradients may mix float16 and float32 precision,
        which would make the result of AllReduce unreliable. To solve the problem, grads are cast to
        float32 before AllReduce and cast back to their original datatypes after the operation.

        Args:
            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.

        Returns:
            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
        """
        datatypes = self.map_(F.partial(_get_datatype), grads)
        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
        if self.mode == context.PYNATIVE_MODE:
            new_grad = grads
        elif self.split_fusion:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads)
        else:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads)
        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
        return new_grad
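

# Minimal usage sketch (comments only; assumes communication has been initialized with
# mindspore.communication.init() and the auto-parallel context is set to DATA_PARALLEL,
# mirroring the class docstring example above):
#
#     mean = context.get_auto_parallel_context("gradients_mean")
#     degree = context.get_auto_parallel_context("device_num")
#     grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
#     grads = grad_reducer(grads)   # summed (and averaged when mean is True) across devices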