# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ms function for mixed precision."""
from __future__ import absolute_import

import os
from abc import ABC, abstractmethod
from mindspore.common import mutable
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2
from mindspore.ops.operations.nn_ops import AllFinite
from mindspore import _checkparam as validator
from mindspore._c_expression import MSContext
from .common import dtype as mstype
from . import context
from . import ops
from .ops import constexpr
from .common.api import jit_class, jit
from .common.parameter import Parameter
from .common.tensor import Tensor
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
from .train.amp import build_train_network, auto_mixed_precision, custom_mixed_precision, \
    get_white_list, get_black_list


_hypermap = ops.HyperMap()
_partial = ops.Partial()


@constexpr
def _ascend_target():
    return context.get_context("device_target") == "Ascend"


@constexpr
def _ascend_910a_target():
    return MSContext.get_instance().get_ascend_soc_version() == "ascend910"


@constexpr
def _ascend_910bc_target():
    return MSContext.get_instance().get_ascend_soc_version() in ["ascend910b", "ascend910c"]


@constexpr
def _gpu_target():
    return context.get_context("device_target") == "GPU"


@constexpr
def _enable_all_finite():
    """Check whether the AllFinite kernel is enabled."""
    runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
    global_jit_config = context.get_jit_config()
    # The environment variable takes precedence over the global jit config.
    if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
        return True

    if runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
        return False

    if global_jit_config:
        return global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
    return False


def _grad_unscale(scale, grad):
    return grad * ops.Reciprocal()(scale).astype(grad.dtype)


def _grad_scale(scale, grad):
    return grad * scale.astype(grad.dtype)


@jit
def _grad_scale_map(scale_value, inputs):
    return _hypermap(_partial(_grad_scale, scale_value), inputs)


@jit
def _grad_unscale_map(scale_value, inputs):
    return _hypermap(_partial(_grad_unscale, scale_value), inputs)


def _overflow(inputs):
    if _gpu_target():
        return ops.FloatStatus()(inputs)
    status = ops.isfinite(inputs)
    return 1 - status.all()


@jit
def _all_finite(inputs, check_overflow_mode, enable_allfinite):
    """all finite check"""
    if _ascend_target():
        if (_ascend_910a_target()) or \
                (_ascend_910bc_target() and check_overflow_mode == "SATURATION_MODE"):
            # Read and then clear the NPU float status registers to detect overflow.
            status = Tensor([0] * 8, mstype.int32)
            status = ops.depend(status, inputs)
            get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status)
            status = ops.depend(status, get_status)
            clear_status = _get_cache_prim(NPUClearFloatStatusV2)()(status)
            get_status = ops.depend(get_status, clear_status)
            status_finite = get_status.equal(Tensor(0, mstype.int32)).all()
            return status_finite

    status_finite = False
    if enable_allfinite:
        status_finite = ~AllFinite()(inputs)  # pylint: disable=invalid-unary-operand-type
    else:
        outputs = _hypermap(_partial(_overflow), inputs)
        flag_sum = ops.addn(outputs).reshape(())
        status_finite = ops.less(flag_sum, 1)
    return status_finite


def all_finite(inputs):
    r"""
    Returns a scalar Tensor indicating whether the inputs are finite.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    This interface must be used in a whole-network training scenario to detect
    whether the grads are finite, and the results may differ on different
    device targets.

    Args:
        inputs (Union(tuple(Tensor), list(Tensor))): an iterable of Tensors.

    Returns:
        Tensor, a scalar Tensor with bool dtype.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])))
        >>> output = amp.all_finite(x)

    Tutorial Examples:
        - `Automatic Mix Precision - Loss Scaling
          <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
    """
    inputs = mutable(inputs)
    _check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
    return _all_finite(inputs, _check_overflow_mode, _enable_all_finite())


@jit_class
class LossScaler(ABC):
    r"""
    Loss scaler abstract class when using mixed precision.

    A derived class needs to implement all of its methods. During training, `scale` and `unscale` are used
    to scale and unscale the loss value and gradients to avoid overflow, and `adjust` is used to update the
    loss scale value.

    For more information, refer to the `tutorials <https://mindspore.cn/tutorials/en/master/advanced/
    mixed_precision.html#loss-scaling>`_.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Examples:
        >>> from mindspore.amp import LossScaler, _grad_scale_map, _grad_unscale_map
        >>> from mindspore import ops, Parameter, Tensor, mutable
        >>> from mindspore.common import dtype as mstype
        >>>
        >>> class MyLossScaler(LossScaler):
        ...     def __init__(self, scale_value, scale_factor=2):
        ...         self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        ...         self.scale_factor = scale_factor
        ...
        ...     def scale(self, inputs):
        ...         inputs = mutable(inputs)
        ...         return _grad_scale_map(self.scale_value, inputs)
        ...
        ...     def unscale(self, inputs):
        ...         inputs = mutable(inputs)
        ...         return _grad_unscale_map(self.scale_value, inputs)
        ...
        ...     def adjust(self, grads_finite):
        ...         scale_mul_factor = self.scale_value * self.scale_factor
        ...         scale_value = ops.select(grads_finite, scale_mul_factor, self.scale_value)
        ...         ops.assign(self.scale_value, scale_value)
        ...         return True
        >>>
        >>> loss_scaler = MyLossScaler(1024)
    """
    @abstractmethod
    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.
        """
        raise NotImplementedError

    @abstractmethod
    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` depending on whether the grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
        """
        raise NotImplementedError


class StaticLossScaler(LossScaler):
    r"""
    Static loss scale class.

    Scales and unscales loss or gradients by a fixed constant.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> loss_scaler = amp.StaticLossScaler(scale_value=2**10)
        >>> loss_value = Tensor([1.], mindspore.float32)
        >>> scaled_loss_value = loss_scaler.scale(loss_value)
        >>> print(scaled_loss_value)
        [1024.]
        >>> grads = (Tensor(np.array([1.5, 1.0]), mindspore.float16),
        ...          Tensor(np.array([1.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> print(unscaled_grads)
        (Tensor(shape=[2], dtype=Float16, value= [ 1.4648e-03, 9.7656e-04]),
        Tensor(shape=[1], dtype=Float16, value= [ 1.1721e-03]))
    """
    def __init__(self, scale_value):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be >= 1, but got {}".format(scale_value))
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.
        """
        inputs = mutable(inputs)
        return _grad_scale_map(self.scale_value, inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.
        """
        inputs = mutable(inputs)
        return _grad_unscale_map(self.scale_value, inputs)

    def adjust(self, grads_finite):
        """
        Adjust `scale_value` in `LossScaler`. Since `scale_value` is fixed in `StaticLossScaler`, this method
        returns False directly.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.
        """
        return False


class DynamicLossScaler(LossScaler):
    r"""
    Dynamic loss scale class.

    Dynamic loss scaling tries to determine the largest loss scale value that
    will keep gradients finite. It does this by increasing the loss scale every
    `scale_window` steps by `scale_factor` if the grads remain finite; otherwise it reduces
    the loss scale by `1 / scale_factor` and resets the counter.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        scale_value (Union(float, int)): The initial loss scale value.
        scale_factor (int): The factor by which the loss scale is increased or decreased.
        scale_window (int): The number of consecutive overflow-free training steps after which
            the loss scale is increased.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import amp, Tensor
        >>> import numpy as np
        >>> loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=1)
        >>> grads = (Tensor(np.array([np.log(-1), 1.0]), mindspore.float16),
        ...          Tensor(np.array([0.2]), mindspore.float16))
        >>> unscaled_grads = loss_scaler.unscale(grads)
        >>> grads_finite = amp.all_finite(unscaled_grads)
        >>> loss_scaler.adjust(grads_finite)
        True
        >>> print(loss_scaler.scale_value.asnumpy())
        512.0
    """
    def __init__(self, scale_value, scale_factor, scale_window):
        scale_value = validator.check_value_type("scale_value", scale_value, [float, int])
        if scale_value < 1.0:
            raise ValueError("The argument 'scale_value' must be >= 1, but got {}".format(scale_value))
        self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
        self.scale_window = validator.check_positive_int(scale_window, "scale_window")
        self.scale_factor = validator.check_positive_int(scale_factor, "scale_factor")
        self.counter = Parameter(Tensor(0, dtype=mstype.int32), name="counter")

    def scale(self, inputs):
        """
        Scaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the scaled value.

        Tutorial Examples:
            - `Automatic Mix Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
        inputs = mutable(inputs)
        return _grad_scale_map(self.scale_value, inputs)

    def unscale(self, inputs):
        """
        Unscaling inputs by `scale_value`.

        Args:
            inputs (Union(Tensor, tuple(Tensor))): the input loss value or gradients.

        Returns:
            Union(Tensor, tuple(Tensor)), the unscaled value.

        Tutorial Examples:
            - `Automatic Mix Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
        inputs = mutable(inputs)
        return _grad_unscale_map(self.scale_value, inputs)

    def adjust(self, grads_finite):
        """
        Adjust the `scale_value` depending on whether the grads are finite.

        Args:
            grads_finite (Tensor): a scalar bool Tensor indicating whether the grads are finite.

        Tutorial Examples:
            - `Automatic Mix Precision - Loss Scaling
              <https://mindspore.cn/tutorials/en/master/advanced/mixed_precision.html#loss-scaling>`_
        """
        one = ops.ones((), self.scale_value.dtype)
        scale_mul_factor = self.scale_value * self.scale_factor
        scale_value = ops.select(
            grads_finite,
            ops.select(
                self.counter == (self.scale_window - 1),
                ops.select(ops.isfinite(scale_mul_factor),
                           scale_mul_factor,
                           self.scale_value),
                self.scale_value),
            ops.maximum(one, self.scale_value / self.scale_factor))
        ops.assign(self.scale_value, scale_value)

        counter = ((self.counter + 1) % self.scale_window) * grads_finite
        ops.assign(self.counter, counter)
        return True


__all__ = [
    "DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
    "build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
    "auto_mixed_precision", "all_finite", "custom_mixed_precision",
    "get_white_list", "get_black_list"
]
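
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as a comment so nothing runs on import):
# a typical custom training step combines a loss scaler with `all_finite`:
# scale the loss before differentiation, unscale the gradients, skip the
# parameter update when an overflow is detected, then adjust the scale.
# `net`, `loss_fn` and `optimizer` are assumed to be user-defined objects.
#
#     import mindspore as ms
#     from mindspore import amp
#
#     loss_scaler = amp.DynamicLossScaler(scale_value=2**16, scale_factor=2, scale_window=50)
#
#     def forward_fn(data, label):
#         loss = loss_fn(net(data), label)
#         # Scale the loss so small fp16 gradients do not underflow.
#         return loss_scaler.scale(loss)
#
#     grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)
#
#     def train_step(data, label):
#         loss, grads = grad_fn(data, label)
#         loss = loss_scaler.unscale(loss)
#         grads = loss_scaler.unscale(grads)
#         is_finite = amp.all_finite(grads)
#         if is_finite:
#             optimizer(grads)
#         # Grow or shrink the scale depending on whether the grads overflowed.
#         loss_scaler.adjust(is_finite)
#         return loss
# ----------------------------------------------------------------------------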