1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""grad reducer cell for distributed training""" 16from __future__ import absolute_import 17 18from mindspore import context 19from mindspore import log as logger 20from mindspore.nn.cell import Cell 21from mindspore.nn.layer import Identity 22from mindspore.communication.management import GlobalComm, get_group_size 23from mindspore.common.sparse_tensor import RowTensorInner 24from mindspore.ops import functional as F, composite as C, operations as P 25from mindspore.ops.operations.comm_ops import AllReduce, AllGather 26from mindspore.parallel._auto_parallel_context import auto_parallel_context 27import mindspore.common.dtype as mstype 28from mindspore.common.sparse_tensor import Tensor 29from mindspore.common.api import jit 30from mindspore.common.parameter import Parameter 31from mindspore.parallel._utils import _get_enable_parallel_optimizer 32 33reduce_opt = C.MultitypeFuncGraph("reduce_opt") 34grad_scale = C.MultitypeFuncGraph("grad_scale") 35shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale") 36reciprocal = P.Reciprocal() 37 38 39@grad_scale.register("Tensor", "Tensor", "Tensor") 40def tensor_grad_scale_pipeline(scale, grad, accu_grad): 41 accu_grad = F.depend(accu_grad, grad) 42 new_grad = accu_grad * reciprocal(scale) 43 accu_grad = F.depend(accu_grad, new_grad) 44 zeros = F.tensor_mul(accu_grad, 0.0) 45 new_grad = F.depend(new_grad, F.assign(accu_grad, zeros)) 46 return new_grad 47 48 49@shard_grad_scale.register("Tensor", "Tensor", "Tensor") 50def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad): 51 new_grad = grad * reciprocal(scale) 52 accu_grad = F.depend(accu_grad, new_grad) 53 new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad))) 54 return new_grad 55 56 57def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP): 58 """ initialize allreduce communication operators""" 59 for indices in split_indices: 60 if indices >= length: 61 logger.warning(f"AllReduce's split index {indices} is greater than or equal to" 62 f"the total gradient's number of {length}") 63 fusion_type = 2 ** 10 64 split = 0 65 fusion = () 66 for i in range(length): 67 fusion = fusion + (fusion_type,) 68 if split >= len(split_indices): 69 continue 70 if split_indices[split] <= i: 71 fusion_type += 1 72 split += 1 73 74 index = tuple(range(1, length + 1)) 75 op_list = () 76 for i in range(length): 77 op = AllReduce('sum', group) 78 op_fusion_id = fusion[i] 79 op.add_prim_attr('fusion', op_fusion_id) 80 op.add_prim_attr('index', index[i]) 81 op_list = op_list + (op,) 82 return op_list 83 84 85def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fusion_type=1): 86 """ initialize allreduce communication operators by parameters""" 87 op_list = () 88 param_fusion = False 89 last_comm_fusion = None 90 first_parameter_flag = True 91 index = 1 92 for parameter in parameters: 93 comm_fusion = parameter.comm_fusion 94 if first_parameter_flag: 95 last_comm_fusion = comm_fusion 96 first_parameter_flag = False 97 elif not param_fusion: 98 if comm_fusion != last_comm_fusion: 99 param_fusion = True 100 last_comm_fusion = comm_fusion 101 op = AllReduce('sum', group) 102 op.add_prim_attr('fusion', comm_fusion) 103 op.add_prim_attr('index', index) 104 index += 1 105 op_list = op_list + (op,) 106 107 if not param_fusion: 108 if split_indices and fusion_type == 1: 109 op_list = _init_allreduce_operators(len(parameters), split_indices, group) 110 param_fusion = True 111 else: 112 op_list = () 113 return op_list, param_fusion 114 115 116@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor") 117def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad): 118 """ 119 Apply allreduce on gradient. 120 121 Args: 122 degree (int): The mean coefficient. 123 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 124 allgather (Primitive): The communication operator for sparse gradients. 125 allreduce (Primitive): The communication operator for gradients. 126 allreduce_filter (bool): When it is true, allreduce would apply. 127 grad (Tensor): The gradient tensor before operation. 128 129 Returns: 130 Tensor, the gradient tensor after operation. 131 """ 132 if allreduce_filter: 133 grad = allreduce(grad) 134 if mean: 135 grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) 136 return grad 137 return grad 138 139 140@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor") 141def _tensors_allreduce_post(degree, mean, allreduce_filter, grad): 142 """ 143 Apply allreduce on gradient in PyNative mode. 144 145 Args: 146 degree (int): The mean coefficient. 147 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 148 allreduce_filter (bool): When it is true, allreduce would apply. 149 grad (Tensor): The gradient tensor before operation. 150 151 Returns: 152 Tensor, the gradient tensor after operation. 153 """ 154 if allreduce_filter: 155 if mean: 156 grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) 157 return grad 158 return grad 159 160 161@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") 162def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): 163 """ 164 Apply allreduce on gradient. 165 166 Args: 167 degree (int): The mean coefficient. 168 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 169 allgather (Primitive): The communication operator for sparse gradients. 170 allreduce (Primitive): The communication operator for gradients. 171 allreduce_filter (bool): When it is true, allreduce would apply. 172 grad (Tensor): The gradient tensor before operation. 173 ps_parameter (bool): Use parameter server or not. 174 175 Returns: 176 Tensor, the gradient tensor after operation. 177 """ 178 if ps_parameter: 179 return grad 180 181 if allreduce_filter: 182 grad = allreduce(grad) 183 if mean: 184 grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) 185 return grad 186 return grad 187 188 189@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor") 190def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): 191 """ 192 Apply allgather on gradient instead of allreduce for sparse feature. 193 Allgather is a communication operation used for distributed deep learning. 194 195 Args: 196 degree (int): The mean coefficient. 197 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 198 allgather (Primitive): The communication operator for sparse gradients. 199 allreduce (Primitive): The communication operator for gradients. 200 allreduce_filter (bool): When it is true, allgather would apply. 201 grad (tuple): The indices, gradient tensor and tensor_shape before operation. 202 203 Returns: 204 RowTensor, the gradient after operation. 205 """ 206 if allreduce_filter: 207 indices = allgather(grad.indices) 208 dout = allgather(grad.values) 209 if mean: 210 dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout))) 211 grad = RowTensorInner(indices, dout, grad.dense_shape) 212 return grad 213 214 215@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") 216def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): 217 """ 218 Apply allgather on gradient instead of allreduce for sparse feature. 219 Allgather is a communication operation used for distributed deep learning. 220 221 Args: 222 degree (int): The mean coefficient. 223 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 224 allgather (Primitive): The communication operator for sparse gradients. 225 allreduce (Primitive): The communication operator for gradients. 226 allreduce_filter (bool): When it is true, allgather would apply. 227 grad (tuple): The indices, gradient tensor and tensor_shape before operation. 228 ps_parameter (bool): Use parameter server or not. 229 230 Returns: 231 RowTensor, the gradient after operation. 232 """ 233 if ps_parameter: 234 return grad 235 236 if allreduce_filter: 237 indices = allgather(grad.indices) 238 dout = allgather(grad.values) 239 if mean: 240 dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout))) 241 grad = RowTensorInner(indices, dout, grad.dense_shape) 242 return grad 243 244 245_get_datatype = C.MultitypeFuncGraph("_get_datatype") 246 247 248@_get_datatype.register("Tensor") 249def _tensors_get_datatype(grad): 250 """ 251 Acquire gradient datatype. 252 253 Args: 254 grad (Tensor): The gradient tensor before operation. 255 256 Returns: 257 mstype, the datatype of gradient. 258 """ 259 return F.dtype(grad) 260 261 262@_get_datatype.register("RowTensor") 263def _tensors_get_datatype_with_sparse(grad): 264 """ 265 Acquire gradient datatype. 266 267 Args: 268 grad (RowTensor): The gradient before operation. 269 270 Returns: 271 mstype, the datatype of gradient. 272 """ 273 return F.dtype(grad.values) 274 275 276_cast_datatype = C.MultitypeFuncGraph("_cast_datatype") 277 278 279@_cast_datatype.register("TypeType", "Tensor") 280def _tensors_cast_datatype(datatype, grad): 281 """ 282 Cast gradient to datatype. 283 284 Args: 285 datatype (mstype): the destination datatype of gradient. 286 grad (Tensor): The gradient tensor before operation. 287 288 Returns: 289 Tensor, the gradient tensor after operation. 290 """ 291 return F.cast(grad, datatype) 292 293 294@_cast_datatype.register("TypeType", "RowTensor") 295def _tensors_cast_datatype_with_sparse(datatype, grad): 296 """ 297 Cast gradient to datatype. 298 299 Args: 300 datatype (mstype): the destination datatype of gradient. 301 grad (RowTensor): The gradient before operation. 302 303 Returns: 304 RowTensor, the gradient after operation. 305 """ 306 dout = F.cast(grad.values, datatype) 307 return RowTensorInner(grad.indices, dout, grad.dense_shape) 308 309 310class DistributedGradReducer(Cell): 311 """ 312 A distributed optimizer. 313 314 Aggregate the gradients for all cards by using AllReduce in data parallel. 315 316 Args: 317 parameters (list): the parameters to be updated. 318 mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. 319 When it is not specified, using the configuration `gradients_mean` in auto_parallel_context. 320 Default: ``None`` . 321 degree (int): The mean coefficient. Usually it equals to device number. Default: ``None`` . 322 fusion_type (int): The type of all reduce fusion. Default: ``1`` . 323 group (str): The communication group to work on. Normally, the group should be created by create_group, 324 otherwise, using the default group. Default: ``GlobalComm.WORLD_COMM_GROUP`` . 325 326 Raises: 327 ValueError: If degree is not an int or less than 0. 328 329 Supported Platforms: 330 ``Ascend`` ``GPU`` 331 332 Examples: 333 .. note:: 334 Before running the following examples, you need to configure the communication environment variables. 335 336 For the Ascend devices, users need to prepare the rank table, set rank_id and device_id. 337 Please see the `rank table Startup 338 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_ 339 for more details. 340 341 For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup 342 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ . 343 344 For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster 345 Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ . 346 347 This example should be run with multiple devices. 348 349 >>> import numpy as np 350 >>> import mindspore as ms 351 >>> from mindspore.communication import init 352 >>> from mindspore import Parameter, Tensor, ops, nn 353 >>> 354 >>> ms.set_context(mode=ms.GRAPH_MODE) 355 >>> init() 356 >>> ms.reset_auto_parallel_context() 357 >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL) 358 >>> 359 >>> class TrainingWrapper(nn.Cell): 360 ... def __init__(self, network, optimizer, sens=1.0): 361 ... super(TrainingWrapper, self).__init__(auto_prefix=False) 362 ... self.network = network 363 ... self.network.add_flags(defer_inline=True) 364 ... self.weights = optimizer.parameters 365 ... self.optimizer = optimizer 366 ... self.grad = ops.GradOperation(get_by_list=True, sens_param=True) 367 ... self.sens = sens 368 ... self.reducer_flag = False 369 ... self.grad_reducer = None 370 ... self.parallel_mode = context.get_auto_parallel_context("parallel_mode") 371 ... self.depend = ops.Depend() 372 ... if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: 373 ... self.reducer_flag = True 374 ... if self.reducer_flag: 375 ... mean = context.get_auto_parallel_context("gradients_mean") 376 ... degree = context.get_auto_parallel_context("device_num") 377 ... self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) 378 ... 379 ... def construct(self, *args): 380 ... weights = self.weights 381 ... loss = self.network(*args) 382 ... sens = F.fill(ops.DType()(loss), ops.Shape()(loss), self.sens) 383 ... grads = self.grad(self.network, weights)(*args, sens) 384 ... if self.reducer_flag: 385 ... # apply grad reducer on grads 386 ... grads = self.grad_reducer(grads) 387 ... return self.depend(loss, self.optimizer(grads)) 388 >>> 389 >>> class Net(nn.Cell): 390 ... def __init__(self, in_features, out_features): 391 ... super(Net, self).__init__() 392 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 393 ... name='weight') 394 ... self.matmul = ops.MatMul() 395 ... 396 ... def construct(self, x): 397 ... output = self.matmul(x, self.weight) 398 ... return output 399 >>> 400 >>> size, in_features, out_features = 16, 16, 10 401 >>> network = Net(in_features, out_features) 402 >>> loss = nn.MSELoss() 403 >>> net_with_loss = nn.WithLossCell(network, loss) 404 >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9) 405 >>> train_cell = TrainingWrapper(net_with_loss, optimizer) 406 >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32)) 407 >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32)) 408 >>> grads = train_cell(inputs, label) 409 >>> print(grads) 410 256.0 411 """ 412 413 def __init__(self, parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP): 414 super(DistributedGradReducer, self).__init__(auto_prefix=False) 415 self._check_parallel_mode() 416 self.map_ = C.Map() 417 self.mean = mean 418 if mean is None: 419 self.mean = auto_parallel_context().get_gradients_mean() 420 if degree is None: 421 self.degree = get_group_size() 422 else: 423 if not isinstance(degree, int) or degree <= 0: 424 raise ValueError("For 'DistributedGradReducer', " 425 "parameter 'degree' in DistributedGradReducer " 426 "should large than 0 and be int, degree: {}.".format(degree)) 427 self.degree = degree 428 self.degree = Tensor(1.0 / self.degree, mstype.float32) 429 430 self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters) 431 is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") 432 split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() 433 if is_parallel_optimizer and split_indices: 434 self.split_fusion = True 435 self.op_list = _init_allreduce_operators(len(parameters), split_indices, group) 436 else: 437 self.split_fusion = True 438 self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices, group, 439 fusion_type) 440 if not param_fusion: 441 self.split_fusion = False 442 self.allreduce = AllReduce('sum', group).add_prim_attr('fusion', fusion_type) 443 self.allgather = AllGather(group) 444 ps_filter = lambda x: x.is_param_ps 445 self.ps_parameters = tuple(ps_filter(x) for x in parameters) 446 self.enable_parameter_server = any(self.ps_parameters) 447 self.mode = context.get_context("mode") 448 self.enable_tuple_broaden = True 449 450 @jit 451 def construct(self, grads): 452 """ 453 Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the 454 result of AllReduce is unreliable. To solve the problem, grads must be cast to float32 before AllReduce, 455 and cast back after the operation. 456 457 Args: 458 grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation. 459 460 Returns: 461 new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation. 462 """ 463 datatypes = self.map_(F.partial(_get_datatype), grads) 464 grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) 465 466 if self.split_fusion: 467 if self.enable_parameter_server: 468 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), 469 self.op_list, self.allreduce_filter, grads, self.ps_parameters) 470 else: 471 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), 472 self.op_list, self.allreduce_filter, grads) 473 else: 474 if self.enable_parameter_server: 475 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, 476 self.allreduce), self.allreduce_filter, grads, self.ps_parameters) 477 else: 478 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, 479 self.allreduce), self.allreduce_filter, grads) 480 new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) 481 return new_grad 482 483 def _check_parallel_mode(self): 484 """check parallel mode""" 485 parallel_mode = context.get_auto_parallel_context('parallel_mode') 486 if context.get_context('mode') == context.GRAPH_MODE and parallel_mode in ( 487 context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL): 488 raise RuntimeError("{} can not use DistributedGradReducer in graph mode".format(parallel_mode)) 489 490 491class PipelineGradReducer(Cell): 492 """ 493 PipelineGradReducer is a gradient reducer for pipeline parallelism. 494 495 Args: 496 parameters (list): the parameters to be updated. 497 scale_sense (float): the scale sense of the gradient. Default: 1.0. 498 499 Raise: 500 RuntimeError: If the mode is not graph mode. 501 RuntimeError: If the parallel mode is not semi auto parallel or auto parallel. 502 503 Supported Platforms: 504 ``Ascend`` ``GPU`` 505 506 Examples: 507 .. note:: 508 Before running the following examples, you need to configure the communication environment variables. 509 510 For the Ascend devices, users need to prepare the rank table, set rank_id and device_id. 511 Please see the `rank table Startup 512 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_ 513 for more details. 514 515 For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup 516 <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ . 517 518 This example should be run with multiple devices. 519 520 >>> import numpy as np 521 >>> import mindspore as ms 522 >>> from mindspore import nn, ops, Tensor 523 >>> from mindspore.communication import init 524 >>> 525 >>> ms.set_context(mode=ms.GRAPH_MODE) 526 >>> ms.reset_auto_parallel_context() 527 >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2) 528 >>> init() 529 >>> ms.set_seed(1) 530 >>> 531 >>> class Network(nn.Cell): 532 ... def __init__(self, in_features, out_features, sens=1.0): 533 ... super().__init__() 534 ... self.layer1 = nn.Dense(in_features, 16) 535 ... self.relu1 = nn.ReLU() 536 ... self.layer2 = nn.Dense(16, 16) 537 ... self.relu2 = nn.ReLU() 538 ... self.layer3 = nn.Dense(16, out_features) 539 ... 540 ... def construct(self, x): 541 ... x = self.layer1(x) 542 ... x = self.relu1(x) 543 ... x = self.layer2(x) 544 ... x = self.relu2(x) 545 ... logits = self.layer3(x) 546 ... return logits 547 >>> 548 >>> size, in_features, out_features = 16, 32, 10 549 >>> net = Network(in_features, out_features) 550 >>> net.layer1.pipeline_stage = 0 551 >>> net.relu1.pipeline_stage = 0 552 >>> net.layer2.pipeline_stage = 0 553 >>> net.relu2.pipeline_stage = 1 554 >>> net.layer3.pipeline_stage = 1 555 >>> loss_fn = nn.CrossEntropyLoss() 556 >>> optimizer = nn.SGD(net.trainable_params(), 1e-2) 557 >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 2) 558 >>> net_with_loss.set_train() 559 >>> def forward_fn(inputs, target): 560 ... loss = net_with_loss(inputs, target) 561 ... return loss 562 >>> 563 >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params()) 564 >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters) 565 >>> 566 >>> @ms.jit 567 >>> def train_one_step(inputs, target): 568 ... loss, grads = grad_fn(inputs, target) 569 ... grads = pp_grad_reducer(grads) 570 ... optimizer(grads) 571 ... return loss, grads 572 >>> 573 >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32)) 574 >>> label = Tensor(np.ones([size, out_features]).astype(np.float32)) 575 >>> loss, _ = train_one_step(inputs, label) 576 >>> print(loss) 577 46.36721 578 """ 579 def __init__(self, parameters, scale_sense=1.0): 580 super(PipelineGradReducer, self).__init__(auto_prefix=False) 581 self._check_mode() 582 self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros") 583 self.grad_reducer = Identity() 584 self.degree = Tensor(1, mstype.float32) 585 self.scale_sense = Parameter(scale_sense, name='scale_sense') 586 self.hyper_map = C.HyperMap() 587 self.opt_shard = _get_enable_parallel_optimizer() 588 589 @jit 590 def construct(self, grads): 591 new_grads = None 592 if self.opt_shard: 593 grads = self.grad_reducer(grads) 594 new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree), 595 grads, self.accu_grads) 596 else: 597 accu_grads = self.grad_reducer(self.accu_grads) 598 new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads) 599 return new_grads 600 601 def _check_mode(self): 602 """check parallel mode""" 603 mode = context.get_context('mode') 604 if mode != context.GRAPH_MODE: 605 raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}") 606 parallel_mode = context.get_auto_parallel_context('parallel_mode') 607 if parallel_mode not in (context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL): 608 raise RuntimeError(f"{parallel_mode} can not use PipelineGradReducer in graph mode") 609