# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""ms function for mixed precision."""
16from __future__ import absolute_import
17
18import os
19from abc import ABC, abstractmethod
20from mindspore.common import mutable
21from mindspore.ops._primitive_cache import _get_cache_prim
22from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
23from mindspore.ops.operations.nn_ops import AllFinite
24from mindspore import _checkparam as validator
25from mindspore._c_expression import MSContext
26from .common import dtype as mstype
27from . import context
28from . import ops
29from .ops import constexpr
30from .common.api import jit_class, jit
31from .common.parameter import Parameter
32from .common.tensor import Tensor
33from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
34from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision,\
35    get_white_list, get_black_list


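# `_hypermap` and `_partial` apply the per-tensor scale/unscale functions
# element-wise over (possibly nested) tuples or lists of gradient tensors.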
_hypermap = ops.HyperMap()
_partial = ops.Partial()


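# The helpers below are evaluated at compile time (@constexpr), so the device
# and SoC checks fold into constants inside jit-compiled graphs.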
@constexpr
def _ascend_target():
    return context.get_context("device_target") == "Ascend"


@constexpr
def _ascend_910a_target():
    return MSContext.get_instance().get_ascend_soc_version() == "ascend910"


@constexpr
def _ascend_910bc_target():
    return MSContext.get_instance().get_ascend_soc_version() in ["ascend910b", "ascend910c"]


@constexpr
def _gpu_target():
    return context.get_context("device_target") == "GPU"


@constexpr
def _enable_all_finite():
    """Check whether the fused AllFinite kernel should be used."""
    runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
    global_jit_config = context.get_jit_config()
    # MS_DEV_RUNTIME_CONF can explicitly force the AllFinite kernel on or off.
    if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
        return True

    if runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
        return False

    # Otherwise the kernel is only used under jit level O0 or O1.
    if global_jit_config:
        return global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
    return False


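# Per-tensor helpers: multiply a gradient by the loss scale, or by its
# reciprocal, casting the scale to the gradient's dtype first.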
def _grad_unscale(scale, grad):
    return grad * ops.Reciprocal()(scale).astype(grad.dtype)


def _grad_scale(scale, grad):
    return grad * scale.astype(grad.dtype)


@jit
def _grad_scale_map(scale_value, inputs):
    return _hypermap(_partial(_grad_scale, scale_value), inputs)


@jit
def _grad_unscale_map(scale_value, inputs):
    return _hypermap(_partial(_grad_unscale, scale_value), inputs)


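# Returns a non-zero overflow flag when `inputs` contains inf/nan: GPU uses the
# FloatStatus operator, other targets reduce over ops.isfinite.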
def _overflow(inputs):
    if _gpu_target():
        return ops.FloatStatus()(inputs)
    status = ops.isfinite(inputs)
    return 1 - status.all()


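# Overflow detection picks one of three paths:
#   * Ascend 910A, or 910B/910C in saturation mode: read and clear the NPU
#     float status registers (NPUGetFloatStatusV2 / NPUClearFloatStatusV2).
#   * AllFinite enabled (see _enable_all_finite): a single fused kernel.
#   * Otherwise: map _overflow over every gradient and sum the flags.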
@jit
def _all_finite(inputs, check_overflow_mode, enable_allfinite):
    """all finite check"""
    if _ascend_target():
        if (_ascend_910a_target()) or \
           (_ascend_910bc_target() and check_overflow_mode == "SATURATION_MODE"):
            status = Tensor([0] * 8, mstype.int32)
            status = ops.depend(status, inputs)
            get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status)
            status = ops.depend(status, get_status)
            clear_status = _get_cache_prim(NPUClearFloatStatusV2)()(status)
            get_status = ops.depend(get_status, clear_status)
            status_finite = get_status.equal(Tensor(0, mstype.int32)).all()
            return status_finite

    status_finite = False
    if enable_allfinite:
        status_finite = ~AllFinite()(inputs)  # pylint: disable=invalid-unary-operand-type
    else:
        outputs = _hypermap(_partial(_overflow), inputs)
        flag_sum = ops.addn(outputs).reshape(())
        status_finite = ops.less(flag_sum, 1)
    return status_finite


def all_finite(inputs):
    r"""
    Returns a scalar Tensor indicating whether the inputs are finite.

    .. warning::
        This is an experimental API that is subject to change or deletion.

        This interface must be used in a whole-network training scenario to detect
        whether the gradients are finite, and the results may differ across device
        targets.

    Args:
        inputs (Union(tuple(Tensor), list(Tensor))): an iterable of Tensors.

    Returns:
        Tensor, a scalar Tensor whose dtype is bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])))
        >>> output = amp.all_finite(x)

    Tutorial Examples:
        - `Automatic Mixed Precision - Loss Scaling
          <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
    """
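    # `mutable` marks the gradient container as a variable (not a constant) for
    # the jit-compiled `_all_finite` graph. MS_ASCEND_CHECK_OVERFLOW_MODE selects
    # the Ascend overflow checking mode; the NPU status-register path above is
    # only taken when it is set to "SATURATION_MODE" on 910B/910C.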
    inputs = mutable(inputs)
    _check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
    return _all_finite(inputs, _check_overflow_mode, _enable_all_finite())


@jit_class
class LossScaler(ABC):
    r"""
    Loss scaler abstract class for mixed precision.

    Derived classes need to implement all of its methods. During training, `scale` and `unscale` are used
    to scale and unscale the loss value and gradients to avoid overflow, and `adjust` is used to update the
    loss scale value.

    For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/master/advanced/
    mixed_precision.html#loss-scaling>`_.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Examples:
        >>> from mindspore.amp import LossScaler, _grad_scale_map, _grad_unscale_map
        >>> from mindspore import ops, Parameter, Tensor, mutable
        >>> from mindspore.common import dtype as mstype
        >>>
        >>> class MyLossScaler(LossScaler):
        ...     def __init__(self, scale_value):
        ...         self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        ...         self.scale_factor = 2
        ...
        ...     def scale(self, inputs):
        ...         inputs = mutable(inputs)
        ...         return _grad_scale_map(self.scale_value, inputs)
        ...
        ...     def unscale(self, inputs):
        ...         inputs = mutable(inputs)
        ...         return _grad_unscale_map(self.scale_value, inputs)
        ...
        ...     def adjust(self, grads_finite):
        ...         scale_mul_factor = self.scale_value * self.scale_factor
        ...         scale_value = ops.select(grads_finite, scale_mul_factor, self.scale_value)
        ...         ops.assign(self.scale_value, scale_value)
        ...         return True
        >>>
        >>> loss_scaler = MyLossScaler(1024)
    """
    @abstractmethod
    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` depending on whether the grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
        """
        raise NotImplementedError
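
# A typical interplay between a LossScaler and `all_finite` in one training step
# (illustrative sketch only; `grad_fn`, `optimizer` and `loss` are placeholders,
# not part of this module):
#
#     scaled_loss = loss_scaler.scale(loss)
#     grads = grad_fn(...)                      # gradients computed from the scaled loss
#     grads = loss_scaler.unscale(grads)
#     is_finite = all_finite(grads)
#     if is_finite:
#         optimizer(grads)
#     loss_scaler.adjust(is_finite)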


class StaticLossScaler(LossScaler):
    r"""
    Static loss scale class.

    Scales and unscales loss or gradients by a fixed constant.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> loss_scaler = amp.StaticLossScaler(scale_value=2**10)
        >>> loss_value = Tensor([1.], mindspore.float32)
        >>> scaled_loss_value = loss_scaler.scale(loss_value)
        >>> print(scaled_loss_value)
        [1024.]
        >>> grads = (Tensor(np.array([1.5, 1.0]), mindspore.float16),
        ...          Tensor(np.array([1.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> print(unscaled_grads)
        (Tensor(shape=[2], dtype=Float16, value= [ 1.4648e-03,  9.7656e-04]),
        Tensor(shape=[1], dtype=Float16, value= [ 1.1721e-03]))
    """
    def __init__(self, scale_value):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be no less than 1, but got {}".format(scale_value))
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.
        """
        inputs = mutable(inputs)
        return _grad_scale_map(self.scale_value, inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.
        """
        inputs = mutable(inputs)
        return _grad_unscale_map(self.scale_value, inputs)

    def adjust(self, grads_finite):
        """
        Adjust `scale_value` in `LossScaler`. `scale_value` is fixed in `StaticLossScaler`, so this method
        returns False directly.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
        """
        return False


class DynamicLossScaler(LossScaler):
    r"""
    Dynamic loss scale class.

    Dynamic loss scaling tries to determine the largest loss scale value that
    will keep gradients finite. It does this by multiplying the loss scale by
    `scale_factor` after every `scale_window` consecutive steps with finite
    gradients; when the gradients overflow, it divides the loss scale by
    `scale_factor` and resets the counter.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value.
        scale_factor (int): The factor by which the loss scale is increased or decreased.
        scale_window (int): The number of consecutive overflow-free training steps
            required before the loss scale is increased.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=1)
        >>> grads = (Tensor(np.array([np.log(-1), 1.0]), mindspore.float16),
        ...          Tensor(np.array([0.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> grads_finite = amp.all_finite(unscaled_grads)
        >>> loss_scaler.adjust(grads_finite)
        True
        >>> print(loss_scaler.scale_value.asnumpy())
        512.0
    """
    def __init__(self, scale_value, scale_factor, scale_window):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be no less than 1, but got {}".format(scale_value))
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        self.scale_window = validator.check_positive_int(scale_window, "scale_window")
        self.scale_factor = validator.check_positive_int(scale_factor, "scale_factor")
        self.counter = Parameter(Tensor(0, dtype=mstype.int32), name="counter")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.

        Tutorial Examples:
            - `Automatic Mixed Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
        inputs = mutable(inputs)
        return _grad_scale_map(self.scale_value, inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.

        Tutorial Examples:
            - `Automatic Mixed Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
        inputs = mutable(inputs)
        return _grad_unscale_map(self.scale_value, inputs)

    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` depending on whether the grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.

        Tutorial Examples:
            - `Automatic Mixed Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
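        # Update rule: if the grads are finite and the window of `scale_window`
        # overflow-free steps has just completed, multiply the scale by
        # `scale_factor` (unless that would overflow to inf); on overflow,
        # divide the scale by `scale_factor`, never going below 1.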
        one = ops.ones((), self.scale_value.dtype)
        scale_mul_factor = self.scale_value * self.scale_factor
        scale_value = ops.select(
            grads_finite,
            ops.select(
                self.counter == (self.scale_window - 1),
                ops.select(ops.isfinite(scale_mul_factor),
                           scale_mul_factor,
                           self.scale_value),
                self.scale_value),
            ops.maximum(one, self.scale_value / self.scale_factor))
        ops.assign(self.scale_value, scale_value)

        counter = ((self.counter + 1) % self.scale_window) * grads_finite
        ops.assign(self.counter, counter)
        return True


__all__ = [
    "DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
    "build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
    "auto_mixed_precision", "all_finite", "custom_mixed_precision",
    "get_white_list", "get_black_list"
]