# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from .cell_wrapper import TrainOneStepCell
from ..cell import Cell
from ...common import Tensor, RowTensor
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
from ...ops import operations as P
from ...common import dtype as mstype

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)

_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)


@_grad_overflow.register("RowTensor")
def _tensor_grad_overflow_row_tensor(grad):
    return grad_overflow(grad.values)

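# A minimal usage sketch (illustrative; the names `scaling_sens` and `grads` are
# placeholders): the multitype func graphs above dispatch on the gradient type, so
# dense Tensor and sparse RowTensor gradients are handled at the same call site
# inside a training cell, e.g.
#
#     hyper_map = C.HyperMap()
#     unscaled_grads = hyper_map(F.partial(_grad_scale, scaling_sens), grads)
#     overflow_flags = hyper_map(F.partial(_grad_overflow), grads)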

class DynamicLossScaleUpdateCell(Cell):
    r"""
    Dynamic loss scale update cell.

    For loss scaling training, the initial loss scaling value is set to `loss_scale_value`.
    In each training step, the loss scaling value is updated to loss scaling value / `scale_factor`
    when there is an overflow, and it is increased to loss scaling value * `scale_factor` if there has been
    no overflow for `scale_window` consecutive steps. This cell is used for Graph mode training, in which all
    logic is executed on the device side (the other training mode is the normal, non-sink mode, in which some
    logic is executed on the host).

    Args:
        loss_scale_value (float): Initial loss scale value.
        scale_factor (int): Coefficient by which the loss scale is increased or decreased.
        scale_window (int): Maximum number of consecutive training steps without overflow before the loss scale
            is increased.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Raises:
        TypeError: If dtype of `inputs` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self,
                 loss_scale_value,
                 scale_factor,
                 scale_window):
        super(DynamicLossScaleUpdateCell, self).__init__()

        self.scale_window = Tensor(scale_window, dtype=mstype.int32)
        self.scale_factor = Tensor(scale_factor, dtype=mstype.float32)
        self.loss_scale_value = loss_scale_value

        self.cur_iter = Parameter(Tensor(1, dtype=mstype.int32), name="current_iterator_step")
        self.last_overflow_iter = Parameter(Tensor(0, dtype=mstype.int32), name="last_overflow_iterator_step")
        self.select = P.Select()
        self.max = P.Maximum()
        self.minimum_loss_scale = Tensor(1.0, dtype=mstype.float32)
        self.reciprocal = P.Reciprocal()
        self.less_equal = P.LessEqual()
        self.logic_and = P.LogicalAnd()
        self.logic_not = P.LogicalNot()
        self.logic_or = P.LogicalOr()
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def get_loss_scale(self):
        """
        Get Loss Scale value.
        """
        return self.loss_scale_value

    def construct(self, loss_scale, overflow):
        overflow_cond = overflow
        loss_scale_on_overflow = self.select(overflow_cond, self.max(loss_scale * self.reciprocal(self.scale_factor),
                                                                     self.minimum_loss_scale), loss_scale)
        should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter)
        last_iter_cond = self.logic_or(overflow_cond, should_inc)
        last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter)
        last_iter = F.assign(self.last_overflow_iter, last_overflow_iter)
        update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond))
        scale_mul_res = loss_scale_on_overflow * self.scale_factor
        scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow)
        F.assign(loss_scale, scaled_loss_scale)
        inc_cur_iter = self.cur_iter + 1
        inc_cur_iter = F.depend(inc_cur_iter, last_iter)
        F.assign(self.cur_iter, inc_cur_iter)
        return overflow
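    # An imperative restatement of the graph logic above (illustrative only, never
    # executed; the Select/LessEqual operators express the same thing functionally):
    #
    #     if overflow:
    #         loss_scale = max(loss_scale / scale_factor, minimum_loss_scale)
    #     window_passed = (cur_iter - last_overflow_iter) >= scale_window
    #     if overflow or window_passed:
    #         last_overflow_iter = cur_iter
    #     if window_passed and not overflow:
    #         loss_scale *= scale_factor
    #     cur_iter += 1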


class FixedLossScaleUpdateCell(Cell):
    """
    Static loss scale update cell; the loss scaling value will not be updated.

    For usage, refer to `DynamicLossScaleUpdateCell`.

    Args:
        loss_scale_value (float): Initial loss scale value.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`, which is ignored.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self, loss_scale_value):
        super(FixedLossScaleUpdateCell, self).__init__()
        self.loss_scale_value = loss_scale_value

    def get_loss_scale(self):
        """
        Get Loss Scale value.
        """
        return self.loss_scale_value

    def construct(self, _, overflow):
        return overflow


class TrainOneStepWithLossScaleCell(TrainOneStepCell):
    r"""
    Network training with loss scaling.

    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
    Cell as arguments. The loss scale value can be updated on either the host side or the device side. The
    TrainOneStepWithLossScaleCell will be compiled into a graph which takes `*inputs` as input data.
    A Tensor type `scale_sense` acts as the loss scaling value. If you want to update the loss scaling value on
    the host side, the value must be provided. If a Tensor type `scale_sense` is not given, the loss scale update
    logic must be provided by a Cell type `scale_sense`.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it is the loss scale update logic cell. If this
                                          value is a Tensor, it must have shape :math:`()` or :math:`(1,)`.

    Inputs:
        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tuple of 3 Tensors, the loss, the overflow flag and the current loss scaling value.

        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **overflow** (Tensor) - Tensor with shape :math:`()`, whose type is bool.
        - **loss scaling value** (Tensor) - Tensor with shape :math:`()`.

    Raises:
        TypeError: If `scale_sense` is neither Cell nor Tensor.
        ValueError: If shape of `scale_sense` is neither (1,) nor ().

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>> from mindspore import dtype as mstype
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> #1) when the type of scale_sense is Cell:
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32)
        >>> labels = Tensor(np.ones([out_features,]), mstype.float32)
        >>> output = train_network(input, labels)
        >>>
        >>> #2) when the type of scale_sense is Tensor:
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
        >>> output = train_network(inputs, label)
    """
    def __init__(self, network, optimizer, scale_sense):
        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.hyper_map = C.HyperMap()
        self.base = Tensor(1, mstype.float32)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.less_equal = P.LessEqual()
        self.allreduce = P.AllReduce()
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.gpu_target = (context.get_context("device_target") == "GPU")
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
        else:
            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)

        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens
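    # A minimal driver sketch (illustrative; variable names are placeholders): callers
    # can inspect the returned overflow flag to log or react to skipped steps, e.g.
    #
    #     loss, overflow, scale = train_network(data, label)
    #     if overflow:
    #         print("optimizer update skipped, current loss scale:", scale)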

    def set_sense_scale(self, sens):
        """
        If the user has set the sens in the training process and wants to reassign the value, this function
        can be called again to modify it, and `sens` needs to be of type Tensor.

        Inputs:
            - **sens** (Tensor) - The new sense value, whose shape and type are the same as the original
              `scale_sense`.
        """
        if self.scale_sense and isinstance(sens, Tensor):
            self.scale_sense.set_data(sens)
        else:
            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))

    def start_overflow_check(self, pre_cond, compute_input):
        """
        Start floating-point overflow detection. Create and clear the overflow detection state.

        Specify the arguments `pre_cond` and `compute_input` to make sure the overflow status is cleared at the
        right time. For example, if state clearing needs to happen after the loss calculation and overflow needs
        to be detected during the gradient calculation, `pre_cond` should be the output of the loss function and
        `compute_input` should be the input of the gradient-computing function.

        Inputs:
            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the executing
              order of overflow state clearing and the prior processes. It makes sure that the function
              `start_overflow_check` clears the status after finishing the process of the precondition.
            - **compute_input** (object) - The input of the subsequent process. Overflow detection should be
              performed on a certain computation. Set `compute_input` as the input of the computation, to ensure
              the overflow status is cleared before executing the computation.

        Outputs:
            Tuple[object, object], the first value is False for the GPU backend, while it is an instance of
            NPUAllocFloatStatus for other backends. The status is used to detect overflow.
            The second value is the same as `compute_input`, but carries an additional dependency that fixes the
            execution order.
        """
        status = False
        if not self.gpu_target:
            # init overflow buffer
            status = P.NPUAllocFloatStatus()()
            status = F.depend(status, pre_cond)
            # clear overflow buffer
            clear_status = P.NPUClearFloatStatus()(status)
            compute_input = F.depend(compute_input, clear_status)
        return status, compute_input
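    # On GPU the status handle stays False: overflow is detected directly from the
    # gradients in get_overflow_status via FloatStatus. On other backends an NPU
    # float-status buffer is allocated and cleared here, then read back after the
    # gradient computation.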

    def get_overflow_status(self, status, compute_output):
        """
        Get floating-point overflow status.

        Get the overflow results after executing the target process for overflow detection.

        Inputs:
            - **status** (object) - A status instance used to detect the overflow.
            - **compute_output** - Overflow detection should be performed on a certain computation. Set
              `compute_output` as the output of the computation, to ensure the overflow status is acquired only
              after the computation has finished.

        Outputs:
            bool, whether the overflow occurs or not.
        """
        if not self.gpu_target:
            status = F.depend(status, compute_output)
            get_status = P.NPUGetFloatStatus()(status)
            status = F.depend(status, get_status)
            # sum overflow buffer elements, 0: not overflow, >0: overflow
            flag_sum = self.reduce_sum(status, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
            flag_sum = P.AddN()(flag_sum)
            # convert flag_sum to scalar
            flag_sum = P.Reshape()(flag_sum, (()))

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            overflow = self.less_equal(self.base, flag_reduce)
        else:
            overflow = self.less_equal(self.base, flag_sum)
        return overflow

    def process_loss_scale(self, overflow):
        """
        Calculate loss scale according to the overflow.

        Inputs:
            - **overflow** (bool) - Whether the overflow occurs or not.

        Outputs:
            bool, overflow value.
        """
        if self.loss_scaling_manager is not None:
            return self.loss_scaling_manager(self.scale_sense, overflow)
        return overflow


grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    accu_grad = F.depend(accu_grad, grad)
    new_grad = accu_grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
    return new_grad


@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
    new_grad = grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
    return new_grad
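
# In pipeline-parallel training, gradients are accumulated into `accu_grads` across
# micro-batches. The helpers above unscale the gradient (the accumulated value in the
# non-sharded case, the per-shard gradient when the optimizer is sharded) and then
# reset the accumulation buffer to zeros via F.assign so accumulation can restart on
# the next step.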


class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
    """
    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_sense (Cell): Cell to update the loss scale.
    """
    def __init__(self, network, optimizer, scale_sense):
        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.grad_reducer = F.identity
        self.degree = 1
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.reshape = P.Reshape()
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
        else:
            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        loss = F.depend(loss, status_clear)
        if self.opt_shard:
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
        else:
            accu_grads = self.grad_reducer(self.accu_grads)
            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(self.scale_sense, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens)
        return F.depend(ret, succ)
