• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Boost Mode Cell Wrapper."""
16from __future__ import absolute_import
17
18import numpy as np
19from mindspore.nn.wrap import TrainOneStepCell
20import mindspore.context as context
21from mindspore.context import ParallelMode
22from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean
23from mindspore.communication.management import get_group_size, create_group
24from mindspore.nn.cell import Cell
25from mindspore.nn import SequentialCell
26from mindspore.common import Tensor
27from mindspore.common.sparse_tensor import RowTensorInner
28from mindspore.common.parameter import Parameter, ParameterTuple
29from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
30from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
31from mindspore.ops import functional as F
32from mindspore.ops import composite as C
33from mindspore.ops import operations as P
34from mindspore.common import dtype as mstype
35from mindspore.boost.boost import AutoBoost
36from mindspore.boost.grad_freeze import FreezeOpt, freeze_cell
37from mindspore.boost.adasum import AdaSum
38from mindspore.boost.dim_reduce import DimReduce
39from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
40from mindspore.boost.base import _load_local_pca_mat
41
# Public names exported by this module.
__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"]
43
# Multitype graph that computes a per-parameter difference; the Tensor/Tensor overload is registered below.
_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")


@_get_delta_weight.register("Tensor", "Tensor")
def _get_delta_weight_process(new_parameter, old_parameter):
    """Return ``old_parameter - new_parameter`` for a pair of tensors."""
    return old_parameter - new_parameter
51
52
# Multitype graph that copies one tensor into another; the Tensor/Tensor overload is registered below.
_save_weight = C.MultitypeFuncGraph("_save_weight")


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(new_parameter, old_parameter):
    """Assign the value of ``old_parameter`` to ``new_parameter`` and return ``new_parameter``."""
    assign = P.Assign()
    assign(new_parameter, old_parameter)
    return new_parameter
60
61
# Multitype graph that divides a gradient by the loss scale (implemented as multiplication by 1 / scale).
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Multiply a dense gradient by the reciprocal of ``scale``, cast to the gradient's dtype."""
    inv_scale = F.cast(reciprocal(scale), F.dtype(grad))
    return grad * inv_scale


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    """Multiply the values of a row-tensor gradient by the reciprocal of ``scale``; indices and shape are kept."""
    values = grad.values
    inv_scale = F.cast(reciprocal(scale), F.dtype(values))
    return RowTensorInner(grad.indices, values * inv_scale, grad.dense_shape)
78
79
# Multitype graph that applies the FloatStatus primitive to a gradient.
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    """Return the FloatStatus result for a dense gradient."""
    status = grad_overflow(grad)
    return status


@_grad_overflow.register("RowTensor")
def _tensor_grad_overflow_row_tensor(grad):
    """Return the FloatStatus result for the values of a row-tensor gradient."""
    status = grad_overflow(grad.values)
    return status
92
93
class _OutputToFloat16(Cell):
    """Wrapper cell for amp that casts the output of the wrapped op back to float16."""

    def __init__(self, op):
        super().__init__(auto_prefix=False)
        self._op = op

    def construct(self, *inputs):
        """Run the wrapped op on ``inputs`` and cast its result to float16."""
        output = self._op(*inputs)
        return F.cast(output, mstype.float16)
103
104
class BoostTrainOneStepCell(TrainOneStepCell):
    r"""
    Boost Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameter, and different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation.
            Default: ``None`` , which is ``1.0`` .

    Inputs:
        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, the loss value, the shape of which is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import boost
        >>> from mindspore import nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> #1) Using the WithLossCell existing provide
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
        >>>
        >>> #2) Using user-defined WithLossCell
        >>> class MyWithLossCell(nn.Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, y, label):
        ...        out = self._backbone(x, y)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=None):
        """Set up the boost algorithms (gradient freeze, gradient accumulation, dim reduce, adasum)."""
        # NOTE(review): parallel_mode, reducer_flag, sense_flag, sens, grad, grad_no_sens, grad_reducer,
        # mean and degree are not assigned in this class; presumably they come from TrainOneStepCell.
        super(BoostTrainOneStepCell, self).__init__(network, optimizer, sens)
        self.hyper_map = C.HyperMap()
        self.freeze = isinstance(optimizer, FreezeOpt)
        if not self.freeze:
            self.weights = self.optimizer.parameters
        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)

        self.auto_boost = AutoBoost()
        # Gradient accumulation needs data-parallel or stand-alone mode, the "grad_accumulation" boost
        # option, and an accumulation step greater than 1.
        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
        self.use_grad_accumulation = self.use_grad_accumulation & \
                                     self.auto_boost.boost_config.get("grad_accumulation", False)
        self.max_accumulation_step = 1
        if self.use_grad_accumulation:
            self.max_accumulation_step = self.auto_boost.grad_accumulation_step
            if self.max_accumulation_step <= 1:
                self.max_accumulation_step = 1
                self.use_grad_accumulation = False
        self.accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
        if self.use_grad_accumulation:
            # Zero-initialized buffers, one per weight, that hold the accumulated gradients.
            self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')

        self.enable_dim_reduce = self.check_dim_reduce_enable()
        if self.enable_dim_reduce:
            self.__init_dim_reduce()

        self.freeze_nets = None
        self.step = Parameter(Tensor(0, dtype=mstype.int32))
        if self.freeze:
            if self.reducer_flag:
                self.mean = _get_gradients_mean()
                self.degree = _get_device_num()
            else:
                self.mean = None
                self.degree = None
            self.freeze_nets = freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens,
                                           self.grad, self.use_grad_accumulation, self.mean, self.degree,
                                           self.max_accumulation_step)

        self.enable_adasum = self.check_adasum_enable()
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
        if self.enable_adasum:
            self.__init_adasum()

    def construct(self, *inputs):
        """Run one training step on ``inputs`` and return the loss."""
        if self.freeze:
            loss = self.gradient_freeze_process(*inputs)
        else:
            if not self.sense_flag:
                return self._no_sens_impl(*inputs)
            loss = self.network(*inputs)
            sens = F.fill(loss.dtype, loss.shape, self.sens)
            grads = self.grad(self.network, self.weights)(*inputs, sens)
            grads = self.grad_reducer(grads)
            if self.use_grad_accumulation:
                loss = self.gradient_accumulation_process(loss, grads, sens, *inputs)
            else:
                if self.enable_dim_reduce:
                    loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs))
                elif self.enable_adasum:
                    loss = F.depend(loss, self.adasum_process(loss, grads))
                else:
                    loss = F.depend(loss, self.optimizer(grads))
        return loss

    def gradient_freeze_process(self, *inputs):
        r"""
        Gradient freeze algorithm process.

        Runs the freeze network selected by the current step (directly, or through `train_strategy` when the
        optimizer provides one), then advances the step counter and wraps it back to 0 at the end.

        Args:
            inputs (tuple(Tensor)): Tuple of input tensors with shape :math:`(N, \ldots)`.

        Returns:
            - **loss** (Tensor) -  Network loss, tensor with shape :math:`()`.
        """
        if self.train_strategy is None:
            step = self.step
            max_index = len(self.freeze_nets)
        else:
            step = self.train_strategy[self.step]
            max_index = len(self.train_strategy)
        loss = self.freeze_nets[step](*inputs)
        if self.step + 1 >= max_index:
            self.step = 0
        else:
            self.step += 1
        return loss

    def gradient_accumulation_process(self, loss, grads, sens, *inputs):
        r"""
        Gradient accumulation algorithm process.

        Adds `grads` into the accumulation buffers. Once `max_accumulation_step` steps have been accumulated,
        the accumulated gradients are applied (through dim reduce, adasum or the plain optimizer) and the
        buffers are cleared.

        Args:
            loss (Tensor): Tensor with shape :math:`()`.
            grads (tuple(Tensor)): Tuple of gradient tensors.
            sens (Tensor): Tensor with shape :math:`()`.
            inputs (tuple(Tensor)): Tuple of input tensors with shape :math:`(N, \ldots)`.

        Returns:
            - **loss** (Tensor) - Network loss, tensor with shape :math:`()`.
        """
        loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step),
                                             self.grad_accumulation, grads))
        self.accumulation_step += 1

        if self.accumulation_step >= self.max_accumulation_step:
            if self.enable_dim_reduce:
                loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, sens, self.weights,
                                                      self.weights_clone, *inputs))
            elif self.enable_adasum:
                loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation))
            else:
                loss = F.depend(loss, self.optimizer(self.grad_accumulation))
            self.accumulation_step = 0

        # The counter is 0 only right after an update, so the buffers are cleared once per update.
        if self.accumulation_step == 0:
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self.grad_accumulation))

        return loss

    def adasum_process(self, loss, grads):
        r"""
        Adasum algorithm process.

        Runs the optimizer, computes the weight delta of the slice of weights owned by this rank, passes it
        to the AdaSum cell, and then broadcasts every slice of weights.

        Args:
            loss (Tensor): Tensor with shape :math:`()`.
            grads (tuple(Tensor)): Tuple of gradient tensors.

        Returns:
            - **loss** (Tensor) - Network loss, tensor with shape :math:`()`.
        """
        loss = F.depend(loss, self.optimizer(grads))
        rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
        # grad_clone must be read only after the optimizer step attached to loss.
        grad_clone = F.depend(self.grad_clone, loss)
        delta_w = self.hyper_map(F.partial(_get_delta_weight), rank_weights, grad_clone)
        adasum_res = self.adasum(delta_w, rank_weights, grad_clone)
        sync_tensor = F.depend(self.sync_tensor, adasum_res)
        sync_flag = self.adasum.sync_barrier(sync_tensor)
        for i in range(self.device_number):
            weight_tuple = self.weights[self.start[i]: self.end[i]]
            node_rank = F.depend(weight_tuple, sync_flag)
            update_weights = self.adasum.broadcast_list[i](node_rank)
            if i == self.server_rank:
                # This rank's own slice is stored into grad_clone rather than into the weights.
                self.hyper_map(F.partial(_save_weight), self.grad_clone, update_weights)
            else:
                self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights)
        return loss

    def check_adasum_enable(self):
        r"""
        Check adasum enable.

        Adasum is enabled only when the optimizer has a truthy `adasum` attribute, `reducer_flag` is set and
        the number of device groups is a power of two greater than 1.

        Returns:
            - **enable_adasum** (bool) - Check whether the Adasum algorithm is enabled.
        """
        if not getattr(self.optimizer, "adasum", None) or not self.reducer_flag:
            return False
        _rank_size = get_group_size()
        # NOTE(review): hard-coded to 8 here while __init_adasum reads self.auto_boost.device_number;
        # confirm both are meant to agree.
        _device_number = 8
        group_number = _rank_size // _device_number
        # x & (x - 1) == 0 holds exactly for powers of two (and for 0, which group_number > 1 excludes).
        is_enable = bool(group_number > 1 and group_number & (group_number - 1) == 0)
        return is_enable

    def check_dim_reduce_enable(self):
        r"""
        Check dim_reduce enable.

        Returns:
            - **enable_dim_reduce** (bool) - Check whether the dimensionality reduction second-order training
              algorithm is enabled, i.e. whether the optimizer has a truthy `dim_reduce` attribute.
        """
        return bool(getattr(self.optimizer, "dim_reduce", None))

    def _no_sens_impl(self, *inputs):
        """construct implementation used when the 'sens' parameter is not passed in (sense_flag is False)."""
        loss = self.network(*inputs)
        # sens is not fed to the gradient function here; it is only forwarded to the accumulation and
        # dim reduce processes.
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.gradient_accumulation_process(loss, grads, sens, *inputs)
        else:
            if self.enable_dim_reduce:
                loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs))
            elif self.enable_adasum:
                loss = F.depend(loss, self.adasum_process(loss, grads))
            else:
                loss = F.depend(loss, self.optimizer(grads))
        # construct() returns this value directly, so the loss must be returned here.
        return loss

    def __init_dim_reduce(self):
        """dim reduce algorithm init method."""
        local_pca_mat_path = self.auto_boost.local_pca_mat_path
        rho = self.auto_boost.rho
        gamma = self.auto_boost.gamma
        alpha = self.auto_boost.alpha
        sigma = self.auto_boost.sigma
        _rank = _get_global_rank()
        _rank_size = 1 if self.parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
        n_components = self.auto_boost.n_components
        timeout = self.auto_boost.timeout
        pca_mat = _load_local_pca_mat(local_pca_mat_path, timeout)
        self.weights_clone = ParameterTuple(self.weights).clone(prefix="weights_clone", init="same")
        self.dim_reduce = DimReduce(self.network, self.optimizer, self.weights, pca_mat, n_components, rho, gamma,
                                    alpha, sigma, _rank, _rank_size)

    def __init_adasum(self):
        """adasum algorithm init method."""
        _rank = _get_global_rank()
        _rank_size = get_group_size()
        _device_number = self.auto_boost.device_number
        self.device_number = _device_number
        group_number = _rank_size // _device_number

        # Split the weights into _device_number contiguous slices; the last slice takes the remainder.
        self.server_rank = _rank % _device_number
        parameter_rank_number = len(self.weights) // _device_number
        self.start = [x * parameter_rank_number for x in range(_device_number)]
        self.end = [(x + 1) * parameter_rank_number for x in range(_device_number)]
        self.end[-1] = len(self.weights)

        current_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
        self.grad_clone = ParameterTuple(current_weights).clone(prefix="delta_weight")
        self.adasum = AdaSum(_rank, _device_number, group_number, self.grad_clone)

        # Replace the gradient reducer with one that works inside this rank's own group.
        self.degree = int(self.degree // group_number)
        group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)]
        current_index = _rank // _device_number
        server_group_name = "allreduce_" + str(current_index)
        create_group(server_group_name, group_list[current_index])
        self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=server_group_name)
398
399
400class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell):
401    r"""
402    Boost Network training with loss scaling.
403
    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
    Cell as args. The loss scale value can be updated on either the host side or the device side. The
    BoostTrainOneStepWithLossScaleCell will be compiled into a graph which takes `*inputs` as input data.
    A Tensor-typed `scale_sense` acts as the loss scaling value. If you want to update it on the host side,
    the value must be provided. If a Tensor-typed `scale_sense` is not given, the loss scale update logic
    must be provided by a Cell-typed `scale_sense`.
410
411    Args:
412        network (Cell): The training network. The network only supports single output.
413        optimizer (Cell): Optimizer for updating the weights.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it is the loss scaling update logic cell. If this
            value is a Tensor, :func:`mindspore.nn.TrainOneStepWithLossScaleCell.set_sense_scale` can be called to
            update the loss scale factor; the Tensor must have shape :math:`()` or :math:`(1,)`.
417
418    Inputs:
419        - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
420
421    Outputs:
422        Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value.
423
424        - **loss** (Tensor) -  Tensor with shape :math:`()`.
425        - **overflow** (Tensor) -  Tensor with shape :math:`()`, type is bool.
426        - **loss scaling value** (Tensor) -  Tensor with shape :math:`()`
427
428    Raises:
429        TypeError: If `scale_sense` is neither Cell nor Tensor.
430        ValueError: If shape of `scale_sense` is neither :math:`(1,)` nor :math:`()`.
431
432    Supported Platforms:
433        ``Ascend`` ``GPU``
434
435    Examples:
436        >>> import numpy as np
437        >>> from mindspore import Tensor, Parameter, nn
438        >>> from mindspore import ops
439        >>> from mindspore.nn import WithLossCell
440        >>> from mindspore import dtype as mstype
441        >>> from mindspore import boost
442        >>>
443        >>> class Net(nn.Cell):
444        ...     def __init__(self, in_features, out_features):
445        ...         super(Net, self).__init__()
446        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
447        ...                                 name='weight')
448        ...         self.matmul = ops.MatMul()
449        ...
450        ...     def construct(self, x):
451        ...         output = self.matmul(x, self.weight)
452        ...         return output
453        ...
454        >>> size, in_features, out_features = 16, 16, 10
455        >>> #1) when the type of scale_sense is Cell:
456        >>> net = Net(in_features, out_features)
457        >>> loss = nn.MSELoss()
458        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
459        >>> net_with_loss = WithLossCell(net, loss)
460        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
461        >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
462        >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
463        >>> labels = Tensor(np.ones([out_features,]), mstype.float32)
464        >>> output = train_network(input, labels)
465        >>>
466        >>> #2) when the type of scale_sense is Tensor:
467        >>> net = Net(in_features, out_features)
468        >>> loss = nn.MSELoss()
469        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
470        >>> net_with_loss = WithLossCell(net, loss)
471        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
472        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
473        >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
474        >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
475        >>> output = train_network(inputs, label)
476    """
477
478    def __init__(self, network, optimizer, scale_sense):
479        super(BoostTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
480        self.base = Tensor(1, mstype.float32)
481        self.reduce_sum = P.ReduceSum(keep_dims=False)
482        self.less_equal = P.LessEqual()
483        self.allreduce = P.AllReduce()
484        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
485        self.gpu_target = (context.get_context("device_target") == "GPU")
486        self.loss_scaling_manager = None
487        self.base0 = Tensor(0, mstype.int32)
488        self.reduce_all = P.ReduceAll(keep_dims=False)
489        self.logic_not = P.LogicalNot()
490        self.equal = P.Equal()
491
492        if self.auto_boost.boost_config.get("loss_scale_group", False):
493            self.enable_enhanced_amp = True
494            if not isinstance(scale_sense, Cell) or not hasattr(scale_sense, "set_loss_scale_status"):
495                raise TypeError("The scale_sense must be enhanced amp Cell, bug got {}".format(type(scale_sense)))
496            self.loss_scaling_manager = scale_sense
497            self.loss_scale_groups = scale_sense.loss_scale_groups
498            self._init_enhanced_amp()
499            self._do_keep_mix_fp32(self.network)
500        else:
501            self.enable_enhanced_amp = False
502            if isinstance(scale_sense, Cell):
503                self.loss_scaling_manager = scale_sense
504                self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
505                                             name="scale_sense")
506            elif isinstance(scale_sense, Tensor):
507                if scale_sense.shape == (1,) or scale_sense.shape == ():
508                    self.scale_sense = Parameter(scale_sense, name='scale_sense')
509                else:
510                    raise ValueError("The shape of scale_sense must be (1,) or (), \
511                                     but got {}".format(scale_sense.shape))
512            else:
513                raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
514
    def construct(self, *inputs):
        """
        Run one training step with loss scaling.

        Inputs:
            - **inputs** (Tuple(Tensor)) - Inputs forwarded to the network.

        Outputs:
            Tuple of (loss, cond, scaling_sens). In the default path `cond` is the overflow status and
            `scaling_sens` is the loss scale used for this step. In the enhanced amp path both values are the
            ones returned by `_enhanced_amp_process_overflow_status`.
        """
        weights = self.weights
        loss = self.network(*inputs)

        if self.enable_enhanced_amp:
            # Enhanced amp: backpropagate with a sens of 1; gradient scaling and the optimizer step are
            # done inside _enhanced_amp_process_overflow_status.
            scaling_sens = F.fill(loss.dtype, loss.shape, 1)
            grads = self.grad(self.network, weights)(*inputs, scaling_sens)
            grads = self.grad_reducer(grads)
            cond, scaling_sens = self._enhanced_amp_process_overflow_status(grads)
        else:
            scaling_sens = self.scale_sense
            # Clear the overflow status after the loss is computed and before the gradients are computed.
            status, scaling_sens = self._start_overflow_check(loss, scaling_sens)
            scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))

            grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
            # Undo the loss scaling on the gradients before they are reduced.
            grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
            grads = self.grad_reducer(grads)

            # get the overflow buffer
            cond = self._get_overflow_status(status, grads)
            overflow = self._process_loss_scale(cond)
            # if there is no overflow, do optimize
            if not overflow:
                loss = self.__multi_update(loss, grads, scaling_sens_filled, *inputs)
        return loss, cond, scaling_sens
540
541    def __multi_update(self, loss, grads, scaling_sens_filled, *inputs):
542        """enable multi-algorithm's process"""
543        if self.use_grad_accumulation:
544            loss = self.gradient_accumulation_process(loss, grads, scaling_sens_filled, *inputs)
545        else:
546            if self.enable_dim_reduce:
547                loss = F.depend(loss, self.dim_reduce(loss, grads, scaling_sens_filled, self.weights,
548                                                      self.weights_clone, *inputs))
549            elif self.enable_adasum:
550                loss = F.depend(loss, self.adasum_process(loss, grads))
551            else:
552                loss = F.depend(loss, self.optimizer(grads))
553        return loss
554
555    def _get_dynamic_overflow_status(self, param):
556        """
557        Judge whether the current network overflows.
558
559        Inputs:
560            - **param** (Tensor) - Whether the overflow occurs or not.
561
562        Outputs:
563            bool, overflow value.
564            float, update ratio.
565        """
566        flag_sum = self.equal(self.base0, param)
567        if self.reducer_flag:
568            flag_reduce = self.allreduce(flag_sum)
569            overflow = self.logic_not(self.reduce_all(flag_reduce))
570        else:
571            overflow = self.logic_not(self.reduce_all(flag_sum))
572
573        if overflow:
574            update_ratio = self.reduce_ratio
575        else:
576            update_ratio = self.growth_ratio
577        return overflow, update_ratio
578
    def _enhanced_amp_process_overflow_status(self, grads):
        """
        Enhanced hybrid precision update loss scale and update weights.

        For every entry of `self.overflow_status_list` the overflow state is evaluated and the loss scale of
        that layer index is updated through the loss scaling manager. If no entry reports an overflow, the
        gradients are scaled by their loss scale values and the optimizer is run.

        Inputs:
            - **grads** (Tuple(Tensor)) - Tuple of gradients.

        Outputs:
            Tuple of (overflow_global_flag, loss_scale). `overflow_global_flag` is an int32 Tensor counting
            the entries that reported an overflow; `loss_scale` is the first collected loss scale value.
        """
        overflow_global_flag = Tensor(0, mstype.int32)
        layer = 0
        loss_scale_temp = ()
        for param in self.overflow_status_list:
            overflow, update_ratio = self._get_dynamic_overflow_status(param)
            if overflow:
                overflow_global_flag += 1
            new_loss_scale_value = self.loss_scaling_manager.update_loss_scale_status(layer, update_ratio)
            # Repeat the value optimizer_loss_scale[layer] times so the tuple lines up with grads
            # (presumably one entry per parameter of the group -- TODO confirm).
            loss_scale_temp += (new_loss_scale_value,) * self.optimizer_loss_scale[layer]
            layer += 1
        # self.base is 1, so this branch runs only when no entry overflowed.
        if P.Less()(overflow_global_flag, self.base):
            grads = self.hyper_map(F.partial(_grad_scale), loss_scale_temp, grads)
            overflow_global_flag = F.depend(overflow_global_flag, self.optimizer(grads))
        return overflow_global_flag, loss_scale_temp[0]
604
605    def _set_sense_scale(self, sens):
606        """
607        If the user has set the sens in the training process and wants to reassign the value, he can call
608        this function again to make modification, and sens needs to be of type Tensor.
609
610        Inputs:
611            - **sens** (Tensor) - The new sense whose shape and type are the same with original `scale_sense`.
612        """
613        if self.scale_sense and isinstance(sens, Tensor):
614            self.scale_sense.set_data(sens)
615        else:
616            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
617
    def _start_overflow_check(self, pre_cond, compute_input):
        """
        Start floating-point overflow detection. Create and clear the overflow detection state.

        Specify the argument 'pre_cond' and 'compute_input' to make sure overflow status is cleared at the right time.
        Taking this situation as an example, we need to execute state clearing after loss calculation and then detect
        overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss
        function, and compute_input should be the input of gradients-computing function.

        Inputs:
            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the executing order
              of overflow state clearing and prior processions. It makes sure that the function 'start_overflow'
              clears status after finishing the process of precondition.
            - **compute_input** (object) - The input of subsequent process. Overflow detection should be performed on a
              certain computation. Set `compute_input` as the input of the computation, to ensure overflow status is
              cleared before executing the computation.

        Outputs:
            Tuple[Tensor, object], the first value is an int32 status Tensor with 8 zero elements. On non-GPU
            targets it is made to depend on `pre_cond` and is passed to NPUClearFloatStatusV2.
            The second value is the same as the input of `compute_input`, but on non-GPU targets it contains
            some information about the execution order.
        """
        status = Tensor([0] * 8, mstype.int32)
        if not self.gpu_target:
            # Order: pre_cond -> clear status -> compute_input.
            status = F.depend(status, pre_cond)
            # clear overflow buffer
            clear_status = NPUClearFloatStatusV2()(status)
            compute_input = F.depend(compute_input, clear_status)
        return status, compute_input
648
    def _get_overflow_status(self, status, compute_output):
        """
        Get floating-point overflow status.

        Get overflow results after executing the target process for overflow detection.

        Inputs:
            - **status** (object) - A status instance used to detect the overflow.
            - **compute_output** - Overflow detection should be performed on a certain computation. Set `compute_output`
              as the output of the computation, to ensure overflow status is acquired after executing the
              computation.

        Outputs:
            Tensor, whether the overflow occurs or not.
        """
        if not self.gpu_target:
            # Read the status only after compute_output has been produced.
            status = F.depend(status, compute_output)
            get_status = NPUGetFloatStatusV2()(status)

            if self.is_distributed:
                # sum overflow flag over devices
                flag_reduce = self.allreduce(get_status)
                # get_status not equal to [0]*8 means overflow
                flag = self.equal(self.base0, flag_reduce)
                status = F.depend(status, flag)
                # distributed needs to skip allreduce to avoid its overflow affecting the next step
                clear_status = NPUClearFloatStatusV2()(status)
                flag = F.depend(flag, clear_status)
                overall_finite = self.reduce_all(flag)
            else:
                # Clear the status after it has been read, then compare the value that was read against 0.
                status = F.depend(status, get_status)
                clear_status = NPUClearFloatStatusV2()(status)
                get_status = F.depend(get_status, clear_status)
                flag = self.equal(self.base0, get_status)
                overall_finite = self.reduce_all(flag)
            overflow = self.logic_not(overall_finite)
        else:
            # GPU target: apply FloatStatus to every gradient and add the results together.
            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
            flag_sum = P.AddN()(flag_sum)
            # convert flag_sum to scalar
            flag_sum = P.Reshape()(flag_sum, (()))

            if self.is_distributed:
                # sum overflow flag over devices
                flag_reduce = self.allreduce(flag_sum)
                # self.base is 1, so a sum of at least 1 is reported as an overflow.
                overflow = self.less_equal(self.base, flag_reduce)
            else:
                overflow = self.less_equal(self.base, flag_sum)
        return overflow
698
699    def _process_loss_scale(self, overflow):
700        """
701        Calculate loss scale according to the overflow.
702
703        Inputs:
704            - **overflow** (bool) - Whether the overflow occurs or not.
705
706        Outputs:
707            bool, overflow value.
708        """
709        if self.loss_scaling_manager is not None:
710            return self.loss_scaling_manager(self.scale_sense, overflow)
711        return overflow
712
713    def _init_enhanced_amp(self):
714        """
715        Init enhanced hybrid precision.
716        """
717        self.params_len = len(self.optimizer.params)
718        self.parent = list(range(self.params_len))
719        self.layer_rank = [0 for _ in range(self.params_len)]
720        index = 0
721        loss_scale_number = len(self.loss_scale_groups)
722        for loss_scale_group in self.loss_scale_groups:
723            for i, _ in enumerate(loss_scale_group):
724                if i == 0:
725                    index += 1
726                    continue
727                self._union(index - 1, index)
728                index += 1
729        parent_set = list(set(self.parent))
730        self.optimizer_loss_scale = [self.parent.count(x) for x in parent_set]
731        self.reduce_ratio = Tensor(1.0 / (2 ** 0.5), mstype.float32)
732        self.growth_ratio = Tensor(2 ** (1.0 / 1000.0), mstype.float32)
733        self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.int32),
734                                                             name='mix_layer_status_{}'.format(x), requires_grad=False)
735                                                   for x in range(loss_scale_number))
736        self.loss_scaling_manager.set_loss_scale_status(loss_scale_number, self.loss_scaling_manager.get_loss_scale())
737
738    def _get_root(self, i):
739        """
740        Get parent id.
741
742        Args:
743            i (int): the current parameters's id.
744
745        Returns:
746            Number, the parent id.
747        """
748        if self.parent[i] != self.parent[self.parent[i]]:
749            self.parent[i] = self.get_root(self.parent[i])
750        return self.parent[i]
751
752    def _union(self, i, j):
753        """
754        Aggregate parameters of the same category.
755
756        Args:
757            i (int): the last parameters's id.
758            j (int): the current parameters's id.
759        """
760        i_root = self._get_root(i)
761        j_root = self._get_root(j)
762
763        if self.layer_rank[i_root] == self.layer_rank[j_root]:
764            self.parent[j_root] = i_root
765            self.layer_rank[i_root] += 1
766        elif self.layer_rank[i_root] > self.layer_rank[j_root]:
767            self.parent[j_root] = i_root
768        else:
769            self.parent[i_root] = j_root
770
771    def _do_keep_mix_fp32(self, network):
772        """
773        Keep enhanced amp cell of type float32.
774
775        Args:
776            network (Cell): The training network.
777        """
778        cells = network.name_cells()
779        change = False
780        for name in cells:
781            subcell = cells[name]
782            if subcell == network:
783                continue
784            if "GroupLossScaleManager" in subcell.cls_name:
785                network._cells[name] = _OutputToFloat16(subcell.to_float(mstype.float32))  # pylint: disable=W0212
786                change = True
787            else:
788                self._do_keep_mix_fp32(subcell)
789        if isinstance(network, SequentialCell) and change:
790            network.cell_list = list(network.cells())
791