• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Loss scale cell for loss scale training."""
16from __future__ import absolute_import
17
18import os
19import mindspore.context as context
20from mindspore.context import ParallelMode
21from mindspore.parallel._utils import _get_enable_parallel_optimizer
22from mindspore import nn
23from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
24from mindspore.nn.cell import Cell
25from mindspore.common import Tensor
26from mindspore.common.sparse_tensor import RowTensorInner
27from mindspore.common.parameter import Parameter
28from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
29from mindspore.ops import functional as F
30from mindspore.ops import composite as C
31from mindspore.ops import operations as P
32from mindspore.ops.operations.nn_ops import AllFinite
33from mindspore.common import dtype as mstype
34from mindspore.common.api import jit
35from mindspore._c_expression import MSContext
36
# Multitype dispatch graph used to unscale gradients; the per-type implementations
# (dense Tensor and RowTensor gradients) are registered on it right below.
_grad_scale = C.MultitypeFuncGraph("grad_scale")
# Shared Reciprocal primitive, used by the registered functions to turn the loss scale into a multiplier.
reciprocal = P.Reciprocal()
39
40
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """
    Unscale a dense gradient.

    Args:
        scale (Tensor): The loss scale that was applied to the loss.
        grad (Tensor): The scaled gradient.

    Returns:
        Tensor, `grad` multiplied by the reciprocal of `scale`, where the reciprocal is cast to the dtype of `grad`.
    """
    inv_scale = F.cast(reciprocal(scale), F.dtype(grad))
    return grad * inv_scale
44
45
@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    """
    Unscale a RowTensor gradient.

    Only the `values` of the gradient are multiplied by the reciprocal of `scale` (cast to the dtype of the
    values); `indices` and `dense_shape` are passed through unchanged.

    Args:
        scale (Tensor): The loss scale that was applied to the loss.
        grad (RowTensor): The scaled sparse gradient.

    Returns:
        RowTensorInner, the unscaled gradient.
    """
    values = grad.values
    unscaled_values = values * F.cast(reciprocal(scale), F.dtype(values))
    return RowTensorInner(grad.indices, unscaled_values, grad.dense_shape)
51
# Multitype dispatch graph for the FloatStatus-based overflow check; Tensor and RowTensor
# implementations are registered below.
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
# Shared FloatStatus primitive applied to each gradient by the registered functions.
grad_overflow = P.FloatStatus()
54
55
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    """Return the FloatStatus flag computed on a dense gradient."""
    overflow_flag = grad_overflow(grad)
    return overflow_flag
59
60
@_grad_overflow.register("RowTensor")
def _tensor_grad_overflow_row_tensor(grad):
    """Return the FloatStatus flag computed on the values of a RowTensor gradient."""
    overflow_flag = grad_overflow(grad.values)
    return overflow_flag
64
65
# Multitype dispatch graph for the IsFinite-based overflow check; Tensor and RowTensor
# implementations are registered below.
_ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow")
# Shared IsFinite primitive applied to each gradient by the registered functions.
ascend_grad_overflow = P.IsFinite()
68
69
@_ascend_grad_overflow.register("Tensor")
def _tensor_ascend_grad_overflow(grad):
    """
    Compute an overflow flag for a dense gradient from its finiteness.

    The flag is `1.0 - all(isfinite(grad))`, reshaped with target shape `(-1,)`.

    Args:
        grad (Tensor): The gradient to check.

    Returns:
        Tensor, the overflow flag.
    """
    all_finite = ascend_grad_overflow(grad).all()
    overflow_flag = Tensor(1.0, dtype=mstype.float32) - all_finite
    return P.Reshape()(overflow_flag, (-1,))
77
78
@_ascend_grad_overflow.register("RowTensor")
def _tensor_ascend_grad_overflow_row_tensor(grad):
    """
    Compute an overflow flag for a RowTensor gradient from the finiteness of its values.

    The flag is `1.0 - all(isfinite(grad.values))`, reshaped with target shape `(1,)`.

    Args:
        grad (RowTensor): The sparse gradient to check.

    Returns:
        Tensor, the overflow flag.
    """
    all_finite = ascend_grad_overflow(grad.values).all()
    overflow_flag = Tensor(1.0, dtype=mstype.float32) - all_finite
    return P.Reshape()(overflow_flag, (1,))
86
87
class DynamicLossScaleUpdateCell(Cell):
    r"""
    Dynamic Loss scale update cell.

    For loss scaling training, the initial loss scaling value will be set to be `loss_scale_value`.
    In each training step, the loss scaling value will be decreased to `loss_scale` / `scale_factor`
    (but never below 1.0) when there is an overflow. And it will be increased to `loss_scale` * `scale_factor`
    if there is no overflow for `scale_window` continuous steps.

    `get_update_cell` method of :class:`mindspore.amp.DynamicLossScaleManager` will return this class. It will be called
    by :class:`mindspore.nn.TrainOneStepWithLossScaleCell` during training to update loss scale.

    Args:
        loss_scale_value (float): Initializes loss scale.
        scale_factor (int): Coefficient of increase and decrease.
        scale_window (int): Maximum continuous training steps that do not have overflow to increase the loss scale.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`.
          It is updated in place by this cell.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self,
                 loss_scale_value,
                 scale_factor,
                 scale_window):
        super(DynamicLossScaleUpdateCell, self).__init__()

        # Hyper-parameters are stored as constant tensors so they can be used directly by the ops in `construct`.
        self.scale_window = Tensor(scale_window, dtype=mstype.int32)
        self.scale_factor = Tensor(scale_factor, dtype=mstype.float32)
        # Kept as the raw Python value; only returned by `get_loss_scale`.
        self.loss_scale_value = loss_scale_value

        # Step counter (starts at 1) and the step recorded at the last overflow / scale increase (starts at 0).
        self.cur_iter = Parameter(Tensor(1, dtype=mstype.int32), name="current_iterator_step")
        self.last_overflow_iter = Parameter(Tensor(0, dtype=mstype.int32), name="last_overflow_iterator_step")
        self.select = P.Select()
        self.max = P.Maximum()
        # Lower bound applied to the loss scale when it is reduced on overflow.
        self.minimum_loss_scale = Tensor(1.0, dtype=mstype.float32)
        self.reciprocal = P.Reciprocal()
        self.less_equal = P.LessEqual()
        self.logic_and = P.LogicalAnd()
        self.logic_not = P.LogicalNot()
        self.logic_or = P.LogicalOr()
        # NOTE(review): `const_true` is not referenced by any method of this class as shown here.
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def get_loss_scale(self):
        """
        Get Loss Scale value.

        Returns:
            float, the loss scale value.

        Examples:
            >>> from mindspore import nn
            >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=212, scale_factor=2, scale_window=1000)
            >>> output = manager.get_loss_scale()
            >>> print(output)
            212
        """
        return self.loss_scale_value

    def construct(self, loss_scale, overflow):
        overflow_cond = overflow
        # On overflow: loss_scale / scale_factor, clamped from below by `minimum_loss_scale`; otherwise unchanged.
        loss_scale_on_overflow = self.select(overflow_cond, self.max(loss_scale * self.reciprocal(self.scale_factor),
                                                                     self.minimum_loss_scale), loss_scale)
        # True once at least `scale_window` steps have passed since `last_overflow_iter`.
        should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter)
        # The reference step is reset to the current step both on overflow and when the window has elapsed.
        last_iter_cond = self.logic_or(overflow_cond, should_inc)
        last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter)
        last_iter = F.assign(self.last_overflow_iter, last_overflow_iter)
        # The scale grows only when the window has elapsed and the current step did not overflow.
        update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond))
        scale_mul_res = loss_scale_on_overflow * self.scale_factor
        scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow)
        # Write the new scale back into the tensor/parameter passed in by the caller.
        F.assign(loss_scale, scaled_loss_scale)
        inc_cur_iter = self.cur_iter + 1
        # Tie the incremented counter to the `last_overflow_iter` assignment above
        # (presumably so the counter update is ordered after it - confirm against graph semantics).
        inc_cur_iter = F.depend(inc_cur_iter, last_iter)
        F.assign(self.cur_iter, inc_cur_iter)
        return overflow
197
198
class FixedLossScaleUpdateCell(Cell):
    """
    Update cell that keeps the loss scaling value fixed.

    `get_update_cell` method of :class:`mindspore.amp.FixedLossScaleManager` will return this class. It will be called
    by :class:`mindspore.nn.TrainOneStepWithLossScaleCell` during training.

    Args:
        loss_scale_value (float): Initializes loss scale.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`, it is ignored in this
          class.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self, loss_scale_value):
        super().__init__()
        # The value is stored untouched and only handed back by `get_loss_scale`.
        self.loss_scale_value = loss_scale_value

    def get_loss_scale(self):
        """
        Return the loss scale value this cell was created with.

        Returns:
            float, the loss scale value.

        Examples:
            >>> from mindspore import nn
            >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=212)
            >>> output = manager.get_loss_scale()
            >>> print(output)
            212
        """
        return self.loss_scale_value

    def construct(self, _, overflow):
        # The loss scale is fixed: the first input is ignored and the overflow flag is passed through.
        return overflow
270
271
class TrainOneStepWithLossScaleCell(TrainOneStepCell):
    r"""
    Network training with loss scaling.

    This is a training step with loss scaling. It takes a network, an optimizer and a scale update Cell(or a Tensor) as
    args. The loss scale value can be updated in both host side or device side. If you want to update it on
    host side, using a value of Tensor type as `scale_sense`, otherwise, using a Cell instance for updating loss
    scale as `scale_sense`.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the network parameters.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called by `TrainOneStepWithLossScaleCell`
            to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
            the shape should be :math:`()` or :math:`(1,)`.

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tuple of 3 Tensor, the loss, overflow flag and current loss scale value.

        - **loss** (Tensor) -  A scalar, the loss value.
        - **overflow** (Tensor) -  A scalar, whether overflow occur or not, the type is bool.
        - **loss scale** (Tensor) -  The loss scale value, the shape is :math:`()` or :math:`(1,)`.

    Raises:
        TypeError: If `scale_sense` is neither Cell nor Tensor.
        ValueError: If shape of `scale_sense` is neither :math:`(1,)` nor :math:`()`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> #1) when the type of scale_sense is Cell:
        >>> net = Net(in_features, out_features)
        >>> loss_fn = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss_fn)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> loss = net_with_loss(input, labels)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> status = Tensor([0] * 8, mindspore.int32)
        >>> scaling_sens = train_network.scale_sense
        >>> scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss))
        >>> grads = train_network.grad(train_network.network, train_network.weights)(input, labels, scaling_sens_filled)
        >>> grads = train_network.grad_reducer(grads)
        >>> cond = train_network.get_overflow_status(status, grads)
        >>> overflow = train_network.process_loss_scale(cond)
        >>>
        >>> #2) when the type of scale_sense is Tensor:
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> scaling_sens = Tensor([1024], dtype=mindspore.float32)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
        >>> scaling_sens = Tensor([1], dtype=mindspore.float32)
        >>> train_network.set_sense_scale(scaling_sens)
        >>> output = train_network(inputs, label)
        >>>
        >>> # update scaling sens and train the network
        >>> scaling_sens = Tensor([1], dtype=mindspore.float32)
        >>> train_network.set_sense_scale(scaling_sens)
        >>> output = train_network(inputs, label)
    """
    def __init__(self, network, optimizer, scale_sense):
        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.hyper_map = C.HyperMap()
        self.base = Tensor(1, mstype.float32)
        self.base0 = Tensor(0, mstype.int32)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reduce_all = P.ReduceAll(keep_dims=False)
        self.less_equal = P.LessEqual()
        self.equal = P.Equal()
        self.logic_not = P.LogicalNot()
        self.allreduce = P.AllReduce()
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.gpu_target = (context.get_context("device_target") == "GPU")
        # Query the SoC version once; it selects the overflow-detection strategy used later.
        soc_version = MSContext.get_instance().get_ascend_soc_version()
        self.ascend_910a_target = (soc_version == 'ascend910')
        self.ascend_910bc_target = (soc_version in ['ascend910b', 'ascend910c'])
        self.loss_scaling_manager = None
        self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')

        # Decide whether the AllFinite kernel is used for overflow detection.
        # Priority: explicit MS_DEV_RUNTIME_CONF setting, then the global jit level, else disabled.
        self.enable_allfinite = False
        runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
        global_jit_config = context.get_jit_config()
        if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
            self.enable_allfinite = True
        elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
            self.enable_allfinite = False
        elif global_jit_config:
            # Use `.get` so that a jit config without a "jit_level" entry does not raise KeyError.
            self.enable_allfinite = global_jit_config.get("jit_level") in ("O0", "O1")

        if isinstance(scale_sense, Cell):
            # A Cell updates the loss scale on the device; its initial value seeds the parameter.
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("For 'TrainOneStepWithLossScaleCell', "
                                 "the shape of 'scale_sense' must be (1,) or (), but got {}."
                                 .format(scale_sense.shape))
        else:
            raise TypeError("For 'TrainOneStepWithLossScaleCell', "
                            "the 'scale_sense' must be Cell or Tensor, but got 'scale_sense' type: {}."
                            .format(type(scale_sense)))
        self.enable_tuple_broaden = True
        self._get_attr_from_cell(network)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        status = Tensor([0] * 8, mstype.int32)

        # The sensitivity fed to the backward pass is the loss scale broadcast to the shape/dtype of the loss.
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        # undo the loss scaling on every gradient
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)

        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens

    def set_sense_scale(self, sens):
        """
        If the user has set the `scale_sense` of Tensor type, he can call this function to reassign the value.

        Args:
            sens(Tensor): The new sense whose shape and type are the same with original `scale_sense`.

        Raises:
            TypeError: If `sens` is not a Tensor (or if `scale_sense` is None).
        """
        # Compare against None instead of relying on the truthiness of the Parameter, which would depend on
        # the current loss scale value rather than on whether the parameter exists.
        if self.scale_sense is not None and isinstance(sens, Tensor):
            self.scale_sense.set_data(sens)
        else:
            raise TypeError("For 'TrainOneStepWithLossScaleCell', "
                            "the type of 'sens' must be Tensor, but got {}".format(type(sens)))

    def start_overflow_check(self, pre_cond, compute_input):
        """
        Start floating-point overflow detection. Create and clear the overflow detection state.

        Specify the argument 'pre_cond' and 'compute_input' to make sure overflow status is cleared at the right time.
        Taking this situation as an example, we need to execute state clearing after loss calculation and then detect
        overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss
        function, and compute_input should be the input of gradients-computing function. User-defined training network
        based on this class can also call this interface to process the overflow.

        Args:
            pre_cond(Tensor): A precondition for starting overflow detection. It determines the executing order
              of overflow state clearing and prior processions. It makes sure that the function
              'start_overflow_check' clears status after finishing the process of precondition.
            compute_input(object): The input of subsequent process. Overflow detection should be performed on a
              certain computation. Set `compute_input` as the input of the computation, to ensure overflow status is
              cleared before executing the computation.

        Returns:
            Tuple[object, object], the first output is used to control the execution sequence. To ensure that the
            `start_overflow_check` is executed before get_overflow_status after compilation optimization is performed.
            This value should be used as the first input of get_overflow_status. The second output is the same as
            the input of compute_input, used to control the execution sequence, and to ensure that the overflow flag
            is cleaned up when the function returns.
        """
        status = Tensor([0] * 8, mstype.int32)
        # The status register is only cleared on 910A, or on 910B/910C when saturation mode is requested.
        if self.ascend_910a_target or (self.ascend_910bc_target and \
                                       self._ascend_check_overflow_mode == "SATURATION_MODE"):
            status = F.depend(status, pre_cond)
            # clear overflow buffer
            clear_status = NPUClearFloatStatusV2()(status)
            compute_input = F.depend(compute_input, clear_status)
        return status, compute_input

    def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
        """check overflow status on infnan mode."""
        # one flag per gradient, summed into a single value
        flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output)
        flag_sum = P.AddN()(flag_sum)
        # convert flag_sum to scalar
        flag_sum = P.Reshape()(flag_sum, (()))
        return flag_sum

    def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
        """converge the distributed overflow status on infnan mode."""
        flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output)

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            overflow = self.less_equal(self.base, flag_reduce)
        else:
            overflow = self.less_equal(self.base, flag_sum)
        return overflow

    def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output):
        """check overflow status on infnan kernel mode."""
        overflow = AllFinite()(compute_output)

        if self.is_distributed:
            # cast to int8 for the allreduce, then back to bool
            overflow = P.Cast()(overflow, mstype.int8)
            overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
        return overflow

    def _get_gpu_overflow_status(self, compute_output):
        """get overflow status of gpu."""
        overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
        return overflow

    def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
        """get overflow status of ascend on infnan mode."""
        if self.enable_allfinite:
            overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output)
        else:
            overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
        return overflow

    def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
        """get overflow status of ascend on saturation mode"""
        # read the status only after the monitored computation has produced its output
        status = F.depend(status, compute_output)
        get_status = NPUGetFloatStatusV2()(status)

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(get_status)
            # get_status not equal to [0]*8 means overflow
            flag = self.equal(self.base0, flag_reduce)
            status = F.depend(status, flag)
            # distributed needs to skip allreduce to avoid its overflow affecting the next step
            clear_status = NPUClearFloatStatusV2()(status)
            flag = F.depend(flag, clear_status)
            overall_finite = self.reduce_all(flag)
        else:
            status = F.depend(status, get_status)
            clear_status = NPUClearFloatStatusV2()(status)
            get_status = F.depend(get_status, clear_status)
            flag = self.equal(self.base0, get_status)
            overall_finite = self.reduce_all(flag)
        overflow = self.logic_not(overall_finite)
        return overflow

    @jit
    def get_overflow_status(self, status, compute_output):
        """
        Get floating-point overflow status.

        Get overflow results after executing the target process for overflow detection. User-defined training network
        based on this class can also call this interface to process the overflow.

        Args:
            status (object): To control the execution sequence with start_overflow_check, it should be set as the first
              output of start_overflow_check.
            compute_output: Overflow detection should be performed in a certain computation process. Set
              `compute_output` as the output of the computation process.

        Returns:
            bool, whether the overflow occurs or not.
        """
        if self.gpu_target:
            overflow = self._get_gpu_overflow_status(compute_output)
        elif self.ascend_910bc_target:
            if self._ascend_check_overflow_mode == "SATURATION_MODE":
                overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
            else:
                overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output)
        else:  # ascend_910a_target
            overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
        return overflow

    def process_loss_scale(self, overflow):
        """
        Calculate loss scale according to the overflow.

        User-defined training network based on this class can also call this interface to process the overflow.

        Args:
            overflow(bool): Whether the overflow occurs or not.

        Returns:
            bool, the input overflow value.
        """
        # Only a Cell-type `scale_sense` updates the loss scale here; a Tensor-type one is left untouched.
        if self.loss_scaling_manager is not None:
            return self.loss_scaling_manager(self.scale_sense, overflow)
        return overflow
581
582
# Multitype dispatch graphs for the gradient-accumulation training cell defined below:
# `grad_scale` unscales the accumulated gradient, `shard_grad_scale` unscales the current gradient.
grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
# Rebinds the module-level `reciprocal` name to a new Reciprocal primitive (same kind as the earlier binding).
reciprocal = P.Reciprocal()
586
587
@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    """
    Unscale the accumulated gradient and reset the accumulation buffer to zero.

    Args:
        scale (Tensor): The value the accumulated gradient is divided by.
        grad (Tensor): The current gradient; only used as a dependency of `accu_grad`.
        accu_grad (Tensor): The accumulation buffer; it is overwritten with zeros.

    Returns:
        Tensor, `accu_grad` multiplied by the reciprocal of `scale`.
    """
    # Make the read of the accumulation buffer depend on `grad`.
    accu_grad = F.depend(accu_grad, grad)
    new_grad = accu_grad * reciprocal(scale)
    # Make the zeroing below depend on `new_grad`
    # (presumably so the buffer is cleared only after it has been read - confirm against graph semantics).
    accu_grad = F.depend(accu_grad, new_grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    # Attach the buffer reset to the returned gradient so the assign is part of the result's dependencies.
    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
    return new_grad
596
597
@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
    """
    Unscale the current gradient and reset the accumulation buffer to zero.

    Unlike `tensor_grad_scale_pipeline`, the returned value is computed from `grad`, not from `accu_grad`.

    Args:
        scale (Tensor): The value the gradient is divided by.
        grad (Tensor): The gradient to unscale.
        accu_grad (Tensor): The accumulation buffer; it is overwritten with zeros.

    Returns:
        Tensor, `grad` multiplied by the reciprocal of `scale`.
    """
    new_grad = grad * reciprocal(scale)
    # Make the zeroing below depend on `new_grad`.
    accu_grad = F.depend(accu_grad, new_grad)
    # Attach the buffer reset to the returned gradient so the assign is part of the result's dependencies.
    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
    return new_grad
604
605
class _TrainGradAccuWithLossScaleCell(TrainOneStepCell):
    """
    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph.

    This cell keeps a zero-initialized `accu_grads` clone of the optimizer parameters, checks overflow with the
    NPU float-status operators and only runs the optimizer when no overflow is reported.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_sense (Union[Cell, Tensor]): Cell to do the loss scale, or a Tensor of shape () or (1,) holding
            the loss scale value.

    Raises:
        ValueError: If the parallel mode is neither SEMI_AUTO_PARALLEL nor AUTO_PARALLEL, or if `scale_sense`
            is a Tensor whose shape is neither (1,) nor ().
        TypeError: If `scale_sense` is neither Cell nor Tensor.
    """
    def __init__(self, network, optimizer, scale_sense):
        super(_TrainGradAccuWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        # Per-parameter accumulation buffers, cloned from the weights and initialized to zeros.
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        # The gradient reducer is an identity here.
        self.grad_reducer = nn.Identity()
        self.degree = 1
        self.cast = P.Cast()
        # NPU float-status operators used for overflow detection in `construct`.
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        # This cell only supports the (semi-)auto parallel modes.
        if self.parallel_mode not in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]:
            raise ValueError(f"ParallelMode must be one of "
                             f"[ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL], but found "
                             f"{self.parallel_mode}.")
        self.allreduce = P.AllReduce()
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.reshape = P.Reshape()
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            # A Cell updates the loss scale in `construct`; its initial value seeds the parameter.
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("The shape of 'scale_sense' must be (1,) or (), but got {}"
                                 .format(scale_sense.shape))
        else:
            raise TypeError("The 'scale_sense' must be Cell or Tensor, but got {}".format(type(scale_sense)))
        # Selects which of the two gradient-scaling paths `construct` takes.
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        init = self.alloc_status()
        # The sensitivity fed to the backward pass is the loss scale broadcast to the shape/dtype of the loss.
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        # Make the sensitivity depend on the status clear, so the clear is tied to the gradient computation input.
        scaling_sens_filled = F.depend(scaling_sens_filled, self.clear_before_grad(init))
        grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled)
        # Read the status with a dependency on the computed gradients.
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        if self.opt_shard:
            # Scale the current gradients and reset the accumulation buffers.
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
        else:
            # Scale the accumulated gradients and reset the accumulation buffers.
            accu_grads = self.grad_reducer(self.accu_grads)
            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        # sum overflow flag over devices
        flag_reduce = self.allreduce(flag_sum)
        # overflow is reported when the reduced flag is at least 1
        cond = self.less_equal(self.base, flag_reduce)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(self.scale_sense, cond)
        # the optimizer step is skipped when an overflow was detected
        if not overflow:
            self.optimizer(grads)
        return (loss, overflow, scaling_sens)
681