# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from __future__ import absolute_import

from mindspore import context
from mindspore import log as logger
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Identity
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.sparse_tensor import Tensor
from mindspore.common.api import jit
from mindspore.common.parameter import Parameter
from mindspore.parallel._utils import _get_enable_parallel_optimizer

reduce_opt = C.MultitypeFuncGraph("reduce_opt")
grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()


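# Pipeline-parallel gradient scaling (no optimizer shard): scale the accumulated
# gradient buffer by 1/scale, then reset the buffer to zero. The F.depend calls pin
# the read -> scale -> reset ordering so the graph optimizer cannot reorder it.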
@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    accu_grad = F.depend(accu_grad, grad)
    new_grad = accu_grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
    return new_grad


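# Variant used when the parallel optimizer (optimizer shard) is enabled: scale the
# incoming gradient directly instead of the accumulation buffer, then reset the buffer.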
@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
    new_grad = grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
    return new_grad


def _init_allreduce_operators(length, split_indices, group=GlobalComm.WORLD_COMM_GROUP):
    """ initialize allreduce communication operators"""
    for indices in split_indices:
        if indices >= length:
            logger.warning(f"AllReduce's split index {indices} is greater than or equal to "
                           f"the total number of gradients {length}")
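    # Bucket the gradients by split_indices: every gradient in a bucket shares one
    # fusion id (starting from 2**10), so the backend can fuse the AllReduce operators
    # of a bucket into a single communication.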
    fusion_type = 2 ** 10
    split = 0
    fusion = ()
    for i in range(length):
        fusion = fusion + (fusion_type,)
        if split >= len(split_indices):
            continue
        if split_indices[split] <= i:
            fusion_type += 1
            split += 1

    index = tuple(range(1, length + 1))
    op_list = ()
    for i in range(length):
        op = AllReduce('sum', group)
        op_fusion_id = fusion[i]
        op.add_prim_attr('fusion', op_fusion_id)
        op.add_prim_attr('index', index[i])
        op_list = op_list + (op,)
    return op_list


def _init_allreduce_operators_by_parameters(parameters, split_indices, group, fusion_type=1):
    """ initialize allreduce communication operators by parameters"""
    op_list = ()
    param_fusion = False
    last_comm_fusion = None
    first_parameter_flag = True
    index = 1
    for parameter in parameters:
        comm_fusion = parameter.comm_fusion
        if first_parameter_flag:
            last_comm_fusion = comm_fusion
            first_parameter_flag = False
        elif not param_fusion:
            if comm_fusion != last_comm_fusion:
                param_fusion = True
                last_comm_fusion = comm_fusion
        op = AllReduce('sum', group)
        op.add_prim_attr('fusion', comm_fusion)
        op.add_prim_attr('index', index)
        index += 1
        op_list = op_list + (op,)

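    # If every parameter carried the same comm_fusion id, discard the per-parameter
    # operators built above: either rebuild them from split_indices, or return an empty
    # tuple so the caller falls back to a single fused AllReduce.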
    if not param_fusion:
        if split_indices and fusion_type == 1:
            op_list = _init_allreduce_operators(len(parameters), split_indices, group)
            param_fusion = True
        else:
            op_list = ()
    return op_list, param_fusion


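# reduce_opt dispatches on the argument signature: dense Tensor vs. RowTensor (sparse)
# gradients, with an optional trailing ps_parameter flag for parameter server training.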
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allreduce on gradient.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor")
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
    """
    Apply allreduce on gradient in PyNative mode.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if allreduce_filter:
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
            return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allreduce on gradient.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allreduce would apply.
        grad (Tensor): The gradient tensor before operation.
        ps_parameter (bool): Use parameter server or not.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        grad = allreduce(grad)
        if mean:
            grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
        return grad
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
        grad (tuple): The indices, gradient tensor and tensor_shape before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensorInner(indices, dout, grad.dense_shape)
    return grad


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
    """
    Apply allgather on gradient instead of allreduce for sparse feature.
    Allgather is a communication operation used for distributed deep learning.

    Args:
        degree (int): The mean coefficient.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
        allgather (Primitive): The communication operator for sparse gradients.
        allreduce (Primitive): The communication operator for gradients.
        allreduce_filter (bool): When it is true, allgather would apply.
        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
        ps_parameter (bool): Use parameter server or not.

    Returns:
        RowTensor, the gradient after operation.
    """
    if ps_parameter:
        return grad

    if allreduce_filter:
        indices = allgather(grad.indices)
        dout = allgather(grad.values)
        if mean:
            dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
        grad = RowTensorInner(indices, dout, grad.dense_shape)
    return grad


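# Helpers used by DistributedGradReducer.construct to record each gradient's dtype,
# cast the gradients to float32 for AllReduce, and cast the results back afterwards.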
_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (Tensor): The gradient tensor before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad)


@_get_datatype.register("RowTensor")
def _tensors_get_datatype_with_sparse(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (RowTensor): The gradient before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad.values)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): the destination datatype of gradient.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    return F.cast(grad, datatype)


@_cast_datatype.register("TypeType", "RowTensor")
def _tensors_cast_datatype_with_sparse(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): the destination datatype of gradient.
        grad (RowTensor): The gradient before operation.

    Returns:
        RowTensor, the gradient after operation.
    """
    dout = F.cast(grad.values, datatype)
    return RowTensorInner(grad.indices, dout, grad.dense_shape)


class DistributedGradReducer(Cell):
    """
    A distributed gradient reducer.

    Aggregate the gradients across all devices by using AllReduce in data parallel mode.

    Args:
        parameters (list): the parameters to be updated.
        mean (bool): When mean is true, the mean coefficient (degree) is applied to the gradients.
                     When it is not specified, the `gradients_mean` configuration in auto_parallel_context is used.
                     Default: ``None`` .
        degree (int): The mean coefficient. Usually it equals the device number. Default: ``None`` .
        fusion_type (int): The type of AllReduce fusion. Default: ``1`` .
        group (str): The communication group to work on. Normally, the group should be created by create_group;
                     otherwise, the default group is used. Default: ``GlobalComm.WORLD_COMM_GROUP`` .

    Raises:
        ValueError: If degree is not an int or is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
            Please see the `rank table Startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
            for more details.

            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .

            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
            Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .

            This example should be run with multiple devices.

        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore.communication import init
        >>> from mindspore import Parameter, Tensor, ops, nn
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> init()
        >>> ms.reset_auto_parallel_context()
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
        >>>
        >>> class TrainingWrapper(nn.Cell):
        ...     def __init__(self, network, optimizer, sens=1.0):
        ...         super(TrainingWrapper, self).__init__(auto_prefix=False)
        ...         self.network = network
        ...         self.network.add_flags(defer_inline=True)
        ...         self.weights = optimizer.parameters
        ...         self.optimizer = optimizer
        ...         self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        ...         self.sens = sens
        ...         self.reducer_flag = False
        ...         self.grad_reducer = None
        ...         self.parallel_mode = ms.get_auto_parallel_context("parallel_mode")
        ...         self.depend = ops.Depend()
        ...         if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
        ...             self.reducer_flag = True
        ...         if self.reducer_flag:
        ...             mean = ms.get_auto_parallel_context("gradients_mean")
        ...             degree = ms.get_auto_parallel_context("device_num")
        ...             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        ...
        ...     def construct(self, *args):
        ...         weights = self.weights
        ...         loss = self.network(*args)
        ...         sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
        ...         grads = self.grad(self.network, weights)(*args, sens)
        ...         if self.reducer_flag:
        ...             # apply grad reducer on grads
        ...             grads = self.grad_reducer(grads)
        ...         return self.depend(loss, self.optimizer(grads))
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        >>>
        >>> size, in_features, out_features = 16, 16, 10
        >>> network = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> net_with_loss = nn.WithLossCell(network, loss)
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> train_cell = TrainingWrapper(net_with_loss, optimizer)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> grads = train_cell(inputs, label)
        >>> print(grads)
        256.0
    """

    def __init__(self, parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self._check_parallel_mode()
        self.map_ = C.Map()
        self.mean = mean
        if mean is None:
            self.mean = auto_parallel_context().get_gradients_mean()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("For 'DistributedGradReducer', the parameter 'degree' "
                                 "should be a positive int, but got degree: {}.".format(degree))
            self.degree = degree
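        # Store 1/degree as a float32 tensor so reduce_opt can apply the mean with a
        # single multiplication after AllReduce.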
        self.degree = Tensor(1.0 / self.degree, mstype.float32)

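        # Only reduce gradients of parameters that are neither layerwise-parallel nor
        # marked as in-shard.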
        self.allreduce_filter = tuple((x.layerwise_parallel is False) and (x.is_in_shard is False) for x in parameters)
        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
        if is_parallel_optimizer and split_indices:
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters), split_indices, group)
        else:
            self.split_fusion = True
            self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices, group,
                                                                                 fusion_type)
            if not param_fusion:
                self.split_fusion = False
                self.allreduce = AllReduce('sum', group).add_prim_attr('fusion', fusion_type)
        self.allgather = AllGather(group)
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in parameters)
        self.enable_parameter_server = any(self.ps_parameters)
        self.mode = context.get_context("mode")
        self.enable_tuple_broaden = True

    @jit
    def construct(self, grads):
        """
        Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
        result of AllReduce is unreliable. To solve the problem, grads must be cast to float32 before AllReduce,
        and cast back after the operation.

        Args:
            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.

        Returns:
            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
        """
        datatypes = self.map_(F.partial(_get_datatype), grads)
        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)

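        # With split_fusion, one AllReduce operator per gradient carries its bucket's
        # fusion id; otherwise the same AllReduce primitive (a single fusion id) is
        # reused for every gradient.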
        if self.split_fusion:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
                                     self.op_list, self.allreduce_filter, grads)
        else:
            if self.enable_parameter_server:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads, self.ps_parameters)
            else:
                new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
                                               self.allreduce), self.allreduce_filter, grads)
        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
        return new_grad

    def _check_parallel_mode(self):
        """check parallel mode"""
        parallel_mode = context.get_auto_parallel_context('parallel_mode')
        if context.get_context('mode') == context.GRAPH_MODE and parallel_mode in (
                context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL):
            raise RuntimeError("DistributedGradReducer cannot be used in {} under graph mode.".format(parallel_mode))


class PipelineGradReducer(Cell):
    """
    PipelineGradReducer is a gradient reducer for pipeline parallelism.

    Args:
        parameters (list): the parameters to be updated.
        scale_sense (float): the scale sense of the gradient. Default: ``1.0`` .

    Raises:
        RuntimeError: If the mode is not graph mode.
        RuntimeError: If the parallel mode is not semi auto parallel or auto parallel.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
            Please see the `rank table Startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
            for more details.

            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .

            This example should be run with multiple devices.

        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, ops, Tensor
        >>> from mindspore.communication import init
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> ms.reset_auto_parallel_context()
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self, in_features, out_features, sens=1.0):
        ...         super().__init__()
        ...         self.layer1 = nn.Dense(in_features, 16)
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(16, 16)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(16, out_features)
        ...
        ...     def construct(self, x):
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> size, in_features, out_features = 16, 32, 10
        >>> net = Network(in_features, out_features)
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 2)
        >>> net_with_loss.set_train()
        >>> def forward_fn(inputs, target):
        ...     loss = net_with_loss(inputs, target)
        ...     return loss
        >>>
        >>> grad_fn = ops.value_and_grad(forward_fn, None, net_with_loss.trainable_params())
        >>> pp_grad_reducer = nn.PipelineGradReducer(optimizer.parameters)
        >>>
        >>> @ms.jit
        ... def train_one_step(inputs, target):
        ...     loss, grads = grad_fn(inputs, target)
        ...     grads = pp_grad_reducer(grads)
        ...     optimizer(grads)
        ...     return loss, grads
        >>>
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.ones([size, out_features]).astype(np.float32))
        >>> loss, _ = train_one_step(inputs, label)
        >>> print(loss)
        46.36721
    """
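    # accu_grads mirrors the given parameters and holds the gradients accumulated over
    # pipeline micro-batches; the buffers are reset to zero after each reduction in construct.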
    def __init__(self, parameters, scale_sense=1.0):
        super(PipelineGradReducer, self).__init__(auto_prefix=False)
        self._check_mode()
        self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
        self.grad_reducer = Identity()
        self.degree = Tensor(1, mstype.float32)
        self.scale_sense = Parameter(scale_sense, name='scale_sense')
        self.hyper_map = C.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()

    @jit
    def construct(self, grads):
        new_grads = None
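        # With the parallel optimizer enabled, scale the incoming gradients directly;
        # otherwise scale the accumulated gradients. Both paths reset the accumulation
        # buffers to zero via shard_grad_scale / grad_scale.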
        if self.opt_shard:
            grads = self.grad_reducer(grads)
            new_grads = self.hyper_map(F.partial(shard_grad_scale, self.scale_sense * self.degree),
                                       grads, self.accu_grads)
        else:
            accu_grads = self.grad_reducer(self.accu_grads)
            new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
        return new_grads

    def _check_mode(self):
        """check parallel mode"""
        mode = context.get_context('mode')
        if mode != context.GRAPH_MODE:
            raise RuntimeError(f"PipelineGradReducer only supports graph mode, but got {mode}.")
        parallel_mode = context.get_auto_parallel_context('parallel_mode')
        if parallel_mode not in (context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL):
            raise RuntimeError(f"PipelineGradReducer does not support the {parallel_mode} parallel mode; "
                               f"it requires semi_auto_parallel or auto_parallel.")