.. _torchinductor-gpu-profiling:

TorchInductor GPU Profiling
===========================

This section lists useful commands and workflows that can help
you dive into a model’s performance in TorchInductor. When a model is not
running as fast as expected, you may want to check individual kernels of the
model. Usually, the kernels that take the majority of the GPU time are the
most interesting ones. After that, you may also want to run individual
kernels directly and inspect their performance. PyTorch provides tools to
cover everything mentioned above.

Relevant Environment Variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can use the following environment variables in your analysis:

- ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES``

  - By default, TorchInductor names a Triton kernel ``triton_``. When
    this environment variable is enabled, Inductor generates a more
    meaningful kernel name in the trace, for example,
    ``triton_poi_fused_cat_155``, which contains the kernel category
    (``poi`` for pointwise) and the original ATen
    operator. This config is disabled by default to improve the chance of
    a compilation cache hit.

- ``TORCHINDUCTOR_BENCHMARK_KERNEL``

  - Enabling this makes Inductor generate a harness that benchmarks
    individual Triton kernels.

- ``TORCHINDUCTOR_MAX_AUTOTUNE``

  - The Inductor autotuner benchmarks more ``triton.Config`` candidates and
    picks the one with the best performance. This increases compilation
    time in the hope of improving performance.

Break Down Model GPU Time
~~~~~~~~~~~~~~~~~~~~~~~~~

Below are the steps to break down the execution time of a model into
individual kernels. We take ``mixnet_l`` as an example.

1. Run the benchmark script for the model:

   .. code-block:: bash

      TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 \
      python -u benchmarks/dynamo/timm_models.py --backend inductor --amp \
      --performance --dashboard --only mixnet_l --disable-cudagraphs --training

   .. note:: The tool relies on kernel names to decide their category, so
      enabling ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` is crucial.

2. In the output log, look for lines like:

   .. code-block:: text

      Compiled module path: /tmp/torchinductor_shunting/qz/cqz7hvhood7y3psp7fy6msjxsxyli7qiwiybizdwtjw6ffyq5wwd.py

   There is one such line for each compiled module. If there are no extra
   graph breaks, we see two such lines in the log: one for the forward graph
   and one for the backward graph.

   For our example command, we get the following compiled modules for the
   forward and backward graphs respectively:

   - https://gist.github.com/shunting314/c2a4d8a28b00fcb5586d0e9d9bf77f9f
   - https://gist.github.com/shunting314/48efc83b12ec3ead950052e4a0220b10

3. Now we can dive into the performance of each individual compiled module.
   Let’s pick the one for the forward graph for illustration purposes.
   I’ll name it ``fwd.py`` for convenience. Run it directly with the
   ``-p`` argument:

   .. code-block:: bash

      python fwd.py -p

See the full output log in this
`example gist <https://gist.github.com/shunting314/8243734a38b5733ea78479209c0ae893>`__.

In the output, you can notice the following:

* We write a Chrome trace file for the profile, so we can load the trace and
  interact with it. In the log, look for a line like the following to find
  the path of the trace file:

  ::

     Chrome trace for the profile is written to /tmp/compiled_module_profile.json

  To load the trace, visit ``chrome://tracing`` in the Chrome browser and
  load the file as the UI suggests. The trace will look like this:

  .. image:: _static/img/inductor_profiling/trace.png

  You can zoom in and out to check the profile.

* We report the percentage of wall time during which the GPU is busy, with a
  log line like:

  ::

     Percent of time when GPU is busy: 102.88%

  Sometimes you may see a value larger than 100%. This is because the kernel
  execution time is measured with profiling enabled, while the wall time is
  measured with profiling disabled. Profiling may distort the kernel
  execution time a bit, but overall this should not be a big deal.

  If we run a model like ``densenet121`` with a small batch size, we see a
  low percentage of time when the GPU is busy:

  ::

     (Forward graph) Percent of time when GPU is busy: 32.69%

  This means the model has a lot of CPU overhead. This is consistent with
  the fact that enabling CUDA graphs improves the performance of
  ``densenet121`` a lot.

* We can break down the GPU time into different categories of kernels.
  In the ``mixnet_l`` example, we see:

  - pointwise kernels take 28.58%
  - reduction kernels take 13.85%
  - persistent reduction kernels take 3.89%
  - the rest are CUTLASS/cuDNN kernels for mm/conv, which take 56.57%

  This information can be found in the summary line (last line)
  of the report for each kernel category.

* We can also zoom into a certain category of kernels. For example,
  let’s check reduction kernels:

  .. image:: _static/img/inductor_profiling/kernel_breakdown.png

  We can see an ordered table of execution time for each individual
  reduction kernel. We also see how many times each kernel is executed. This
  is helpful for a few reasons:

  - If a kernel only takes a tiny amount of time, for example, 0.1%,
    improving it will bring at most a 0.1% overall gain. It is not
    worth spending a lot of effort on it.
  - If a kernel takes 2% of the time, improving it by 2x will bring a 1%
    overall gain, which justifies the effort.
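
The category breakdown and the "is it worth it" arithmetic above can be
sketched in plain Python. This is a minimal illustration, not the code behind
the ``-p`` report. It assumes kernel names follow the ``triton_<category>_``
pattern produced with ``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1``, treats the
``per`` prefix as persistent reduction (an assumption), and uses made-up
timings. ``kernel_category``, ``breakdown`` and ``max_overall_gain`` are
hypothetical helpers. Shares here are relative to total kernel time, which is
a simplification. The gain estimate is Amdahl's law: a kernel taking a
fraction ``f`` of the time, made ``s`` times faster, saves ``f * (1 - 1/s)``
of the total.

.. code-block:: python

   from collections import defaultdict

   # Hypothetical mapping from kernel-name prefix to category. "poi" and "red"
   # appear in the kernel names in this guide; "per" for persistent reduction
   # is an assumption.
   CATEGORY_BY_PREFIX = {
       "triton_poi_": "pointwise",
       "triton_red_": "reduction",
       "triton_per_": "persistent_reduction",
   }


   def kernel_category(name: str) -> str:
       """Guess a kernel's category from its unique name."""
       for prefix, category in CATEGORY_BY_PREFIX.items():
           if name.startswith(prefix):
               return category
       # Everything else, e.g. library kernels for mm/conv.
       return "other"


   def breakdown(kernel_times_ms: dict[str, float]) -> dict[str, float]:
       """Return each category's share of the total kernel time, in percent."""
       total = sum(kernel_times_ms.values())
       per_category: dict[str, float] = defaultdict(float)
       for name, ms in kernel_times_ms.items():
           per_category[kernel_category(name)] += ms
       return {cat: 100.0 * ms / total for cat, ms in per_category.items()}


   def max_overall_gain(fraction: float, speedup: float) -> float:
       """Fraction of total time saved when a kernel that uses `fraction`
       of the time becomes `speedup` times faster (Amdahl's law)."""
       return fraction * (1.0 - 1.0 / speedup)


   # Made-up timings in milliseconds, for illustration only.
   times = {
       "triton_poi_fused_cat_155": 3.0,
       "triton_red_fused_sum_7": 1.0,
       "conv_kernel_example": 6.0,
   }
   print(breakdown(times))  # {'pointwise': 30.0, 'reduction': 10.0, 'other': 60.0}
   print(max_overall_gain(0.02, 2.0))  # 0.01: 2% of the time, made 2x faster
   print(max_overall_gain(0.001, float("inf")))  # 0.001: the 0.1% upper bound

Passing an infinite speedup gives the upper bound from the first bullet above:
even a kernel that disappears entirely cannot save more than its own share.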

Benchmark Individual Triton Kernel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Let’s say we want to take a closer look at
``triton_red_fused__native_batch_norm_legit_functional_16``, which is the
most expensive reduction kernel and takes 2.19% of the overall wall time for
the forward graph.

We can look up the kernel name in ``fwd.py`` and find a comment like:

::

   # kernel path: /tmp/torchinductor_shunting/jk/cjk2vm3446xrk7rth7hr6pun7xxo3dnzubwcn6ydrpifal4eykrz.py

.. image:: _static/img/inductor_profiling/inductor_code.png

I’ll rename it ``k.py`` for convenience. Here is a paste of this
`file <https://gist.github.com/shunting314/96a0afef9dce53d6357bf1633094f358>`__.

``k.py`` is a standalone Python module containing the kernel code and its
benchmark.

Running ``k.py`` directly reports its execution time and bandwidth:

.. image:: _static/img/inductor_profiling/terminal_printout.png

We can check whether max-autotune helps this kernel by running:

.. code-block:: bash

   TORCHINDUCTOR_MAX_AUTOTUNE=1 python /tmp/k.py

We can also temporarily add more reduction heuristics and run the script
again to check how they affect the kernel.
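
The default and max-autotune runs of ``k.py`` can also be scripted so their
printouts sit side by side. This is a minimal sketch, assuming the kernel file
was saved as ``/tmp/k.py`` as above; ``kernel_env``, ``run_kernel_module`` and
``compare_max_autotune`` are hypothetical helpers, not part of PyTorch. It only
toggles ``TORCHINDUCTOR_MAX_AUTOTUNE`` and captures each run's standard output;
reading the timings is still up to you.

.. code-block:: python

   import os
   import subprocess
   import sys

   FLAG = "TORCHINDUCTOR_MAX_AUTOTUNE"


   def kernel_env(max_autotune: bool) -> dict[str, str]:
       """Copy the current environment, with max-autotune explicitly on or off."""
       env = dict(os.environ)
       env.pop(FLAG, None)  # make sure the baseline run does not inherit the flag
       if max_autotune:
           env[FLAG] = "1"
       return env


   def run_kernel_module(path: str, max_autotune: bool) -> str:
       """Run a standalone kernel module such as k.py and return its stdout."""
       result = subprocess.run(
           [sys.executable, path],
           env=kernel_env(max_autotune),
           capture_output=True,
           text=True,
           check=True,  # raise CalledProcessError if the script fails
       )
       return result.stdout


   def compare_max_autotune(path: str) -> None:
       """Print the kernel's benchmark output with and without max-autotune."""
       for label, flag in (("default", False), ("max-autotune", True)):
           print(f"== {label} ==")
           print(run_kernel_module(path, flag))


   if __name__ == "__main__":
       kernel_path = "/tmp/k.py"  # the renamed kernel file from this section
       if os.path.exists(kernel_path):
           compare_max_autotune(kernel_path)
       else:
           print(f"{kernel_path} not found; save the kernel file there first")

Removing the flag for the baseline run matters if it is already exported in
your shell; otherwise both runs would silently use max-autotune.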