.. _torch.compiler_overview:

torch.compiler
==============

``torch.compiler`` is a namespace through which some of the internal compiler
methods are surfaced for user consumption. The main function and feature in
this namespace is ``torch.compile``.

``torch.compile`` is a PyTorch function introduced in PyTorch 2.x. It aims to
solve the problem of accurate graph capture in PyTorch and ultimately to enable
software engineers to run their PyTorch programs faster. ``torch.compile`` is
written in Python, and it marks the transition of PyTorch from C++ to Python.

``torch.compile`` leverages the following underlying technologies:

* **TorchDynamo** (``torch._dynamo``) is an internal API that uses a CPython
  feature called the Frame Evaluation API to safely capture PyTorch graphs.
  Methods that are available externally for PyTorch users are surfaced
  through the ``torch.compiler`` namespace.

* **TorchInductor** is the default ``torch.compile`` deep learning compiler.
  It generates fast code for multiple accelerators and backends. A backend
  compiler is what makes speedups through ``torch.compile`` possible. For
  NVIDIA, AMD and Intel GPUs, TorchInductor uses OpenAI Triton as its key
  building block.

* **AOT Autograd** captures both the user-level code and backpropagation.
  As a result, the backward pass is captured "ahead-of-time". This lets
  TorchInductor accelerate both the forward and backward passes.

.. note:: In some cases, the terms ``torch.compile``, TorchDynamo, and ``torch.compiler``
   may be used interchangeably in this documentation.

As mentioned above, ``torch.compile`` needs a backend to make your workflows
faster. Through TorchDynamo, the backend converts the captured graphs into fast
machine code. Different backends can yield different optimization gains.
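
The following is a minimal sketch of how the backend is chosen. It assumes a PyTorch 2.x install and compiles a hypothetical function, ``scaled_gelu``, twice:

* once with the default backend;
* once with the ``"eager"`` debugging backend, which runs the captured graph without optimizing it.

Only the ``"eager"`` version is called, so the sketch runs without the C++/Triton toolchain that TorchInductor needs for code generation.

.. code-block:: python

   import torch


   def scaled_gelu(x: torch.Tensor) -> torch.Tensor:
       # Hypothetical example function; TorchDynamo captures these tensor ops into a graph.
       return torch.nn.functional.gelu(x) * 2.0


   # Default backend (TorchInductor). Nothing is compiled until the first call;
   # calling compiled_default(x) would trigger TorchInductor code generation.
   compiled_default = torch.compile(scaled_gelu)

   # Explicitly named backend: "eager" runs the captured graph as-is, without
   # optimizing it, which is useful for checking that graph capture works.
   compiled_eager = torch.compile(scaled_gelu, backend="eager")

   x = torch.randn(8)
   print(torch.allclose(compiled_eager(x), scaled_gelu(x)))  # True

   # Registered backend names; debug and experimental backends are hidden by default.
   print(torch.compiler.list_backends())

To try a different backend, change the ``backend=`` string. The first call to a compiled function triggers graph capture and compilation, so it is noticeably slower than later calls.
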
The default backend is called TorchInductor, also known as *inductor*.
TorchDynamo has a list of supported backends developed by our partners.
You can see them by running ``torch.compiler.list_backends()``. Each backend
comes with its own optional dependencies.

Some of the most commonly used backends include:

**Training & inference backends**

.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Backend
     - Description
   * - ``torch.compile(m, backend="inductor")``
     - Uses the TorchInductor backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
   * - ``torch.compile(m, backend="cudagraphs")``
     - CUDA graphs with AOT Autograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
   * - ``torch.compile(m, backend="ipex")``
     - Uses IPEX on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
   * - ``torch.compile(m, backend="onnxrt")``
     - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more <onnx_dynamo_onnxruntime_backend>`

**Inference-only backends**

.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Backend
     - Description
   * - ``torch.compile(m, backend="tensorrt")``
     - Uses Torch-TensorRT for inference optimizations. Requires ``import torch_tensorrt`` in the calling script to register the backend. `Read more <https://github.com/pytorch/TensorRT>`__
   * - ``torch.compile(m, backend="ipex")``
     - Uses IPEX for inference on CPU. `Read more <https://github.com/intel/intel-extension-for-pytorch>`__
   * - ``torch.compile(m, backend="tvm")``
     - Uses Apache TVM for inference optimizations. `Read more <https://tvm.apache.org/>`__
   * - ``torch.compile(m, backend="openvino")``
     - Uses OpenVINO for inference optimizations. `Read more <https://docs.openvino.ai/torchcompile>`__

Read More
~~~~~~~~~

.. toctree::
   :caption: Getting Started for PyTorch Users
   :maxdepth: 1

   torch.compiler_get_started
   torch.compiler_api
   torch.compiler_fine_grain_apis
   torch.compiler_aot_inductor
   torch.compiler_inductor_profiling
   torch.compiler_profiling_torch_compile
   torch.compiler_faq
   torch.compiler_troubleshooting
   torch.compiler_performance_dashboard

..
   _If you want to contribute a developer-level topic
   that provides an in-depth overview of a torch._dynamo feature,
   add it to the toctree below.

.. toctree::
   :caption: Deep Dive for PyTorch Developers
   :maxdepth: 1

   torch.compiler_dynamo_overview
   torch.compiler_dynamo_deepdive
   torch.compiler_dynamic_shapes
   torch.compiler_nn_module
   torch.compiler_best_practices_for_backends
   torch.compiler_cudagraph_trees
   torch.compiler_fake_tensor

.. toctree::
   :caption: HowTo for PyTorch Backend Vendors
   :maxdepth: 1

   torch.compiler_custom_backends
   torch.compiler_transformations
   torch.compiler_ir