.. role:: hidden
    :class: hidden-section

Automatic Mixed Precision package - torch.amp
=============================================

.. Both modules below are missing doc entry. Adding them here for now.
.. This does not add anything to the rendered page
.. py:module:: torch.cpu.amp
.. py:module:: torch.cuda.amp

.. automodule:: torch.amp
.. currentmodule:: torch.amp

:mod:`torch.amp` provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use a lower precision floating point datatype (``lower_precision_fp``): ``torch.float16`` (``half``) or ``torch.bfloat16``.
Some ops, like linear layers and convolutions, are much faster in ``lower_precision_fp``. Other ops, like reductions,
often require the dynamic range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.

Ordinarily, "automatic mixed precision training" with a datatype of ``torch.float16`` uses :class:`torch.autocast` and
:class:`torch.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`
and `Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
However, :class:`torch.autocast` and :class:`torch.amp.GradScaler` are modular, and may be used separately if desired.
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
a datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.

.. warning::
    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` or ``torch.amp.GradScaler("cpu", args...)`` instead.
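
The following is a minimal sketch of the pattern described above. It assumes a recent PyTorch release (2.3 or later), which accepts a device string in :class:`torch.amp.GradScaler`. The model, data, learning rate, and step count are placeholders for illustration only. On CUDA, the sketch uses ``float16`` together with a gradient scaler. On CPU, it uses ``bfloat16`` and disables the scaler, which matches the CPU usage described above.

.. code-block:: python

    import torch
    import torch.nn as nn

    # Choose float16 + gradient scaling on CUDA, or bfloat16 without scaling on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

    torch.manual_seed(0)
    model = nn.Linear(16, 4).to(device)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # The scaler is only needed for float16; when disabled, its methods pass through.
    scaler = torch.amp.GradScaler(device, enabled=(amp_dtype == torch.float16))

    inputs = torch.randn(8, 16, device=device)  # placeholder data
    targets = torch.randn(8, 4, device=device)

    for _ in range(3):
        optimizer.zero_grad()
        # Only the forward pass and the loss computation run under autocast.
        with torch.autocast(device_type=device, dtype=amp_dtype):
            outputs = model(inputs)                          # linear -> amp_dtype
            loss = nn.functional.mse_loss(outputs, targets)  # mse_loss -> float32
        # Backward on the (possibly scaled) loss, outside the autocast region.
        scaler.scale(loss).backward()
        # When enabled: unscale gradients and skip the step on infs/NaNs.
        # When disabled: simply calls optimizer.step().
        scaler.step(optimizer)
        # Adjusts the scale factor for the next iteration (no-op when disabled).
        scaler.update()

Only the forward pass and the loss computation are wrapped in :class:`torch.autocast`. The backward pass runs outside the autocast region.
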

:class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version 1.10.

.. contents:: :local:

.. _autocasting:

Autocasting
^^^^^^^^^^^
.. currentmodule:: torch.amp.autocast_mode

.. autofunction:: is_autocast_available

.. currentmodule:: torch

.. autoclass:: autocast
    :members:

.. currentmodule:: torch.amp

.. autofunction:: custom_fwd

.. autofunction:: custom_bwd

.. currentmodule:: torch.cuda.amp

.. autoclass:: autocast
    :members:

.. autofunction:: custom_fwd

.. autofunction:: custom_bwd

.. currentmodule:: torch.cpu.amp

.. autoclass:: autocast
    :members:

.. _gradient-scaling:

Gradient Scaling
^^^^^^^^^^^^^^^^

If the forward pass for a particular op has ``float16`` inputs, the backward pass for
that op will produce ``float16`` gradients.
Gradient values with small magnitudes may not be representable in ``float16``.
These values will flush to zero ("underflow"), so the update for the corresponding parameters will be lost.

To prevent underflow, "gradient scaling" multiplies the network's loss(es) by a scale factor and
invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are
then scaled by the same factor. In other words, gradient values have a larger magnitude,
so they don't flush to zero.

Each parameter's gradient (``.grad`` attribute) should be unscaled before the optimizer
updates the parameters, so the scale factor does not interfere with the learning rate.

.. note::

    AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in
    the fp16 numerical range (maximum 65504) and will cause gradients to overflow instead of underflow.
    In this case, the scale factor may drop below 1 in an attempt to bring gradients back into the
    representable fp16 dynamic range.
    While one may expect the scale to always be above 1, our
    GradScaler does NOT guarantee this, in order to maintain performance. If you encounter NaNs in your loss
    or gradients when running with AMP/fp16, verify that your model is compatible.

.. currentmodule:: torch.cuda.amp

.. autoclass:: GradScaler
    :members:

.. currentmodule:: torch.cpu.amp

.. autoclass:: GradScaler
    :members:

.. _autocast-op-reference:

Autocast Op Reference
^^^^^^^^^^^^^^^^^^^^^

.. _autocast-eligibility:

Op Eligibility
--------------
Ops that run in ``float64`` or non-floating-point dtypes are not eligible, and will
run in these types whether or not autocast is enabled.

Only out-of-place ops and Tensor methods are eligible.
In-place variants and calls that explicitly supply an ``out=...`` Tensor
are allowed in autocast-enabled regions, but won't go through autocasting.
For example, in an autocast-enabled region ``a.addmm(b, c)`` can autocast,
but ``a.addmm_(b, c)`` and ``torch.addmm(a, b, c, out=d)`` cannot.
For best performance and stability, prefer out-of-place ops in autocast-enabled
regions.

Ops called with an explicit ``dtype=...`` argument are not eligible,
and will produce output that respects the ``dtype`` argument.

.. _autocast-cuda-op-reference:

CUDA Op-Specific Behavior
-------------------------
The following lists describe the behavior of eligible ops in autocast-enabled regions.
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces,
they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type
defined by their inputs. However, autocasting may still change the type
in which unlisted ops run if they're downstream from autocasted ops.
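
The eligibility rules and the downstream effect described above apply on every device. The sketch below checks the resulting dtypes. It relies on three facts from this reference: ``addmm`` and ``mm`` appear on the lower-precision list for both CUDA and CPU, and ``relu`` is unlisted. It uses CUDA with ``float16`` when available and otherwise falls back to CPU with ``bfloat16``. The tensors are arbitrary placeholders.

.. code-block:: python

    import torch

    # Run on CUDA with float16 when available, otherwise on CPU with bfloat16.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    low = torch.float16 if device == "cuda" else torch.bfloat16

    a = torch.randn(4, 4, device=device)  # all inputs start as float32
    b = torch.randn(4, 4, device=device)
    c = torch.randn(4, 4, device=device)
    d = torch.empty(4, 4, device=device)

    with torch.autocast(device_type=device, dtype=low):
        # Eligible: out-of-place op on the lower-precision list, inputs cast to `low`.
        e = torch.addmm(a, b, c)
        # Unlisted op downstream of an autocasted op: runs in its input's dtype.
        f = e.relu()
        # Not eligible: in-place variant, so `a` keeps its float32 dtype.
        a.addmm_(b, c)
        # Not eligible: explicit out= tensor, so the result is written in float32.
        torch.mm(b, c, out=d)
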

If an op is unlisted, we assume it's numerically stable in ``float16``.
If you believe an unlisted op is numerically unstable in ``float16``,
please file an issue.

CUDA Ops that can autocast to ``float16``
"""""""""""""""""""""""""""""""""""""""""

``__matmul__``,
``addbmm``,
``addmm``,
``addmv``,
``addr``,
``baddbmm``,
``bmm``,
``chain_matmul``,
``multi_dot``,
``conv1d``,
``conv2d``,
``conv3d``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``GRUCell``,
``linear``,
``LSTMCell``,
``matmul``,
``mm``,
``mv``,
``prelu``,
``RNNCell``

CUDA Ops that can autocast to ``float32``
"""""""""""""""""""""""""""""""""""""""""

``__pow__``,
``__rdiv__``,
``__rpow__``,
``__rtruediv__``,
``acos``,
``asin``,
``binary_cross_entropy_with_logits``,
``cosh``,
``cosine_embedding_loss``,
``cdist``,
``cosine_similarity``,
``cross_entropy``,
``cumprod``,
``cumsum``,
``dist``,
``erfinv``,
``exp``,
``expm1``,
``group_norm``,
``hinge_embedding_loss``,
``kl_div``,
``l1_loss``,
``layer_norm``,
``log``,
``log_softmax``,
``log10``,
``log1p``,
``log2``,
``margin_ranking_loss``,
``mse_loss``,
``multilabel_margin_loss``,
``multi_margin_loss``,
``nll_loss``,
``norm``,
``normalize``,
``pdist``,
``poisson_nll_loss``,
``pow``,
``prod``,
``reciprocal``,
``rsqrt``,
``sinh``,
``smooth_l1_loss``,
``soft_margin_loss``,
``softmax``,
``softmin``,
``softplus``,
``sum``,
``renorm``,
``tan``,
``triplet_margin_loss``

CUDA Ops that promote to the widest input type
""""""""""""""""""""""""""""""""""""""""""""""
These ops don't require a particular dtype for stability, but take multiple inputs
and require that the inputs' dtypes match. If all of the inputs are
``float16``, the op runs in ``float16``.
If any of the inputs is ``float32``,
autocast casts all inputs to ``float32`` and runs the op in ``float32``.

``addcdiv``,
``addcmul``,
``atan2``,
``bilinear``,
``cross``,
``dot``,
``grid_sample``,
``index_put``,
``scatter_add``,
``tensordot``

Some ops not listed here (e.g., binary ops like ``add``) natively promote
inputs without autocasting's intervention. If inputs are a mixture of ``float16``
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
regardless of whether autocast is enabled.

Prefer ``binary_cross_entropy_with_logits`` over ``binary_cross_entropy``
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
The backward passes of :func:`torch.nn.functional.binary_cross_entropy` (and :class:`torch.nn.BCELoss`, which wraps it)
can produce gradients that aren't representable in ``float16``. In autocast-enabled regions, the forward input
may be ``float16``, which means the backward gradient must be representable in ``float16`` (autocasting ``float16``
forward inputs to ``float32`` doesn't help, because that cast must be reversed in backward).
Therefore, ``binary_cross_entropy`` and ``BCELoss`` raise an error in autocast-enabled regions.

Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using :func:`torch.nn.functional.binary_cross_entropy_with_logits`
or :class:`torch.nn.BCEWithLogitsLoss`. ``binary_cross_entropy_with_logits`` and ``BCEWithLogitsLoss``
are safe to autocast.

.. _autocast-xpu-op-reference:

XPU Op-Specific Behavior (Experimental)
---------------------------------------
The following lists describe the behavior of eligible ops in autocast-enabled regions.
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
as a function, or as a :class:`torch.Tensor` method.
If functions are exposed in multiple namespaces,
they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type
defined by their inputs. However, autocasting may still change the type
in which unlisted ops run if they're downstream from autocasted ops.

If an op is unlisted, we assume it's numerically stable in ``float16``.
If you believe an unlisted op is numerically unstable in ``float16``,
please file an issue.

XPU Ops that can autocast to ``float16``
""""""""""""""""""""""""""""""""""""""""

``addbmm``,
``addmm``,
``addmv``,
``addr``,
``baddbmm``,
``bmm``,
``chain_matmul``,
``multi_dot``,
``conv1d``,
``conv2d``,
``conv3d``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``GRUCell``,
``linear``,
``LSTMCell``,
``matmul``,
``mm``,
``mv``,
``RNNCell``

XPU Ops that can autocast to ``float32``
""""""""""""""""""""""""""""""""""""""""

``__pow__``,
``__rdiv__``,
``__rpow__``,
``__rtruediv__``,
``binary_cross_entropy_with_logits``,
``cosine_embedding_loss``,
``cosine_similarity``,
``cumsum``,
``dist``,
``exp``,
``group_norm``,
``hinge_embedding_loss``,
``kl_div``,
``l1_loss``,
``layer_norm``,
``log``,
``log_softmax``,
``margin_ranking_loss``,
``nll_loss``,
``normalize``,
``poisson_nll_loss``,
``pow``,
``reciprocal``,
``rsqrt``,
``soft_margin_loss``,
``softmax``,
``softmin``,
``sum``,
``triplet_margin_loss``

XPU Ops that promote to the widest input type
"""""""""""""""""""""""""""""""""""""""""""""
These ops don't require a particular dtype for stability, but take multiple inputs
and require that the inputs' dtypes match. If all of the inputs are
``float16``, the op runs in ``float16``.
If any of the inputs is ``float32``,
autocast casts all inputs to ``float32`` and runs the op in ``float32``.

``bilinear``,
``cross``,
``grid_sample``,
``index_put``,
``scatter_add``,
``tensordot``

Some ops not listed here (e.g., binary ops like ``add``) natively promote
inputs without autocasting's intervention. If inputs are a mixture of ``float16``
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
regardless of whether autocast is enabled.

.. _autocast-cpu-op-reference:

CPU Op-Specific Behavior
------------------------
The following lists describe the behavior of eligible ops in autocast-enabled regions.
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces,
they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type
defined by their inputs. However, autocasting may still change the type
in which unlisted ops run if they're downstream from autocasted ops.

If an op is unlisted, we assume it's numerically stable in ``bfloat16``.
If you believe an unlisted op is numerically unstable in ``bfloat16``,
please file an issue. When autocasting to ``float16`` on CPU, the same op lists apply as for ``bfloat16``.
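
The sketch below illustrates the three CPU categories listed in this section. It calls ``linear``, which is on the ``bfloat16`` list, and ``mse_loss``, which is on the ``float32`` list. It also calls ``cat``, which promotes to the widest input type, and an unlisted ``+`` that relies on native type promotion. Tensor shapes and values are arbitrary placeholders.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 16)      # float32 inputs (placeholder shapes)
    w = torch.randn(4, 16)
    target = torch.randn(8, 4)
    extra = torch.randn(8, 4)

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y = F.linear(x, w)              # bfloat16 list: runs in bfloat16
        loss = F.mse_loss(y, target)    # float32 list: inputs cast to float32
        joined = torch.cat([y, extra])  # widest-type list: bfloat16 + float32 -> float32
        summed = y + extra              # unlisted: native type promotion -> float32

Outside the autocast region, the same ``F.linear`` call would run in ``float32``, because nothing casts its inputs.
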

CPU Ops that can autocast to ``bfloat16``
"""""""""""""""""""""""""""""""""""""""""

``conv1d``,
``conv2d``,
``conv3d``,
``bmm``,
``mm``,
``linalg_vecdot``,
``baddbmm``,
``addmm``,
``addbmm``,
``linear``,
``matmul``,
``_convolution``,
``conv_tbc``,
``mkldnn_rnn_layer``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``prelu``,
``scaled_dot_product_attention``,
``_native_multi_head_attention``

CPU Ops that can autocast to ``float32``
""""""""""""""""""""""""""""""""""""""""

``avg_pool3d``,
``binary_cross_entropy``,
``grid_sampler``,
``grid_sampler_2d``,
``_grid_sampler_2d_cpu_fallback``,
``grid_sampler_3d``,
``polar``,
``prod``,
``quantile``,
``nanquantile``,
``stft``,
``cdist``,
``trace``,
``view_as_complex``,
``cholesky``,
``cholesky_inverse``,
``cholesky_solve``,
``inverse``,
``lu_solve``,
``orgqr``,
``ormqr``,
``pinverse``,
``max_pool3d``,
``max_unpool2d``,
``max_unpool3d``,
``adaptive_avg_pool3d``,
``reflection_pad1d``,
``reflection_pad2d``,
``replication_pad1d``,
``replication_pad2d``,
``replication_pad3d``,
``mse_loss``,
``cosine_embedding_loss``,
``nll_loss``,
``nll_loss2d``,
``hinge_embedding_loss``,
``poisson_nll_loss``,
``cross_entropy_loss``,
``l1_loss``,
``huber_loss``,
``margin_ranking_loss``,
``soft_margin_loss``,
``triplet_margin_loss``,
``multi_margin_loss``,
``ctc_loss``,
``kl_div``,
``multilabel_margin_loss``,
``binary_cross_entropy_with_logits``,
``fft_fft``,
``fft_ifft``,
``fft_fft2``,
``fft_ifft2``,
``fft_fftn``,
``fft_ifftn``,
``fft_rfft``,
``fft_irfft``,
``fft_rfft2``,
``fft_irfft2``,
``fft_rfftn``,
``fft_irfftn``,
``fft_hfft``,
``fft_ihfft``,
``linalg_cond``,
``linalg_matrix_rank``,
``linalg_solve``,
``linalg_cholesky``,
``linalg_svdvals``,
``linalg_eigvals``,
``linalg_eigvalsh``,
``linalg_inv``,
``linalg_householder_product``,
``linalg_tensorinv``,
``linalg_tensorsolve``,
``fake_quantize_per_tensor_affine``,
``geqrf``,
``_lu_with_info``,
``qr``,
``svd``,
``triangular_solve``,
``fractional_max_pool2d``,
``fractional_max_pool3d``,
``adaptive_max_pool3d``,
``multilabel_margin_loss_forward``,
``linalg_qr``,
``linalg_cholesky_ex``,
``linalg_svd``,
``linalg_eig``,
``linalg_eigh``,
``linalg_lstsq``,
``linalg_inv_ex``

CPU Ops that promote to the widest input type
"""""""""""""""""""""""""""""""""""""""""""""
These ops don't require a particular dtype for stability, but take multiple inputs
and require that the inputs' dtypes match. If all of the inputs are
``bfloat16``, the op runs in ``bfloat16``. If any of the inputs is ``float32``,
autocast casts all inputs to ``float32`` and runs the op in ``float32``.

``cat``,
``stack``,
``index_copy``

Some ops not listed here (e.g., binary ops like ``add``) natively promote
inputs without autocasting's intervention. If inputs are a mixture of ``bfloat16``
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
regardless of whether autocast is enabled.


.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.amp.autocast_mode
.. py:module:: torch.cpu.amp.autocast_mode
.. py:module:: torch.cuda.amp.autocast_mode
.. py:module:: torch.cuda.amp.common
.. py:module:: torch.amp.grad_scaler
.. py:module:: torch.cpu.amp.grad_scaler
.. py:module:: torch.cuda.amp.grad_scaler