AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models
=================================================================

.. warning::

    AOTInductor and its related features are in prototype status and are
    subject to backwards compatibility breaking changes.

AOTInductor is a specialized version of
`TorchInductor <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__,
designed to process exported PyTorch models, optimize them, and produce shared libraries as well
as other relevant artifacts.
These compiled artifacts are specifically designed for deployment in non-Python environments,
which are commonly used for server-side inference.

In this tutorial, you will learn how to take a PyTorch model, export it, compile it into a
shared library, and run model predictions from C++.


Model Compilation
---------------------------

Using AOTInductor, you can still author the model in Python. The following
example demonstrates how to invoke ``aot_compile`` to transform the model into a
shared library.

This API uses ``torch.export`` to capture the model into a computational graph,
and then uses TorchInductor to generate a ``.so`` file which can be run in a non-Python
environment. For comprehensive details on the ``torch._export.aot_compile``
API, you can refer to the code
`here <https://github.com/pytorch/pytorch/blob/92cc52ab0e48a27d77becd37f1683fd442992120/torch/_export/__init__.py#L891-L900C9>`__.
For more details on ``torch.export``, you can refer to the :ref:`torch.export docs <torch.export>`.

.. note::

    If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support,
    the following code will compile the model into a shared library for CUDA execution.
    Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference,
    it is suggested to enable freezing by setting ``export TORCHINDUCTOR_FREEZING=1``
    before running the Python script below.

.. code-block:: python

    import os
    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(16, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.sigmoid(x)
            return x

    with torch.no_grad():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Model().to(device=device)
        example_inputs = (torch.randn(8, 10, device=device),)
        batch_dim = torch.export.Dim("batch", min=1, max=1024)
        so_path = torch._export.aot_compile(
            model,
            example_inputs,
            # Specify the first dimension of the input x as dynamic
            dynamic_shapes={"x": {0: batch_dim}},
            # Specify the generated shared library path
            options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
        )

In this example, the ``Dim`` parameter designates the first dimension of the input variable
``x`` as dynamic, so the compiled model can accept varying batch sizes at inference time.
The path and name of the compiled library are specified through the ``aot_inductor.output_path``
option, so the shared library is written to ``model.so`` in the current working directory, where
the C++ code in the next section expects to find it. If this option is omitted, the shared library
is stored in a temporary directory instead, and its path (returned as ``so_path``) would need to be
passed to the C++ side, for example by saving it to a file.
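
Before moving on to C++, you can optionally sanity-check the compiled artifact from Python.
The snippet below is a minimal sketch; it assumes ``torch._export.aot_load`` is available in
your PyTorch build, which, like ``aot_compile``, is a private prototype API and may change.

.. code-block:: python

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the compiled shared library back into Python.
    # NOTE: torch._export.aot_load is a private prototype API and may change;
    # this is only a quick sanity check, not part of the C++ deployment flow.
    compiled_model = torch._export.aot_load("model.so", device)

    with torch.no_grad():
        # A batch size different from the example inputs (8) also works here,
        # because the batch dimension was marked as dynamic at compile time.
        output = compiled_model(torch.randn(2, 10, device=device))
        print(output.shape)  # expected: torch.Size([2, 1])

If this runs and prints the expected shape, the shared library is ready to be consumed by the
C++ runner in the next section.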


Inference in C++
---------------------------

Next, we use the following C++ file ``inference.cpp`` to load the shared library generated in the
previous step, so that we can run model predictions directly from C++.

.. note::

    The following code snippet assumes your system has a CUDA-enabled device and your model was
    compiled to run on CUDA as shown previously.
    In the absence of a GPU, make these adjustments to run the example on CPU:

    1. Change ``model_container_runner_cuda.h`` to ``model_container_runner_cpu.h``
    2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunnerCpu``
    3. Change ``at::kCUDA`` to ``at::kCPU``

.. code-block:: cpp

    #include <iostream>
    #include <vector>

    #include <torch/torch.h>
    #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

    int main() {
        c10::InferenceMode mode;

        torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");
        std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
        std::vector<torch::Tensor> outputs = runner.run(inputs);
        std::cout << "Result from the first inference:" << std::endl;
        std::cout << outputs[0] << std::endl;

        // The second inference uses a different batch size and it works because we
        // specified that dimension as dynamic when compiling model.so.
        std::cout << "Result from the second inference:" << std::endl;
        std::vector<torch::Tensor> inputs2 = {torch::randn({2, 10}, at::kCUDA)};
        std::cout << runner.run(inputs2)[0] << std::endl;

        return 0;
    }

To build the C++ file, you can use the provided ``CMakeLists.txt``, which automates the process
of invoking ``python model.py`` for AOT compilation of the model and compiling ``inference.cpp``
into an executable binary named ``aoti_example``.

.. code-block:: cmake

    cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
    project(aoti_example)

    find_package(Torch REQUIRED)

    add_executable(aoti_example inference.cpp model.so)

    add_custom_command(
        OUTPUT model.so
        COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
        DEPENDS model.py
    )

    target_link_libraries(aoti_example "${TORCH_LIBRARIES}")
    set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17)


Assuming your directory structure looks like the following, you can run the commands below to
build the binary. Note that the ``CMAKE_PREFIX_PATH`` variable is required for CMake to locate
the LibTorch library, and it must be set to an absolute path. Your path may differ from the one
shown in this example.

.. code-block:: shell

    aoti_example/
        CMakeLists.txt
        inference.cpp
        model.py


.. code-block:: shell

    $ mkdir build
    $ cd build
    $ CMAKE_PREFIX_PATH=/path/to/python/install/site-packages/torch/share/cmake cmake ..
    $ cmake --build . --config Release

After the ``aoti_example`` binary has been generated in the ``build`` directory, running it
produces output similar to the following:

.. code-block:: shell

    $ ./aoti_example
    Result from the first inference:
    0.4866
    0.5184
    0.4462
    0.4611
    0.4744
    0.4811
    0.4938
    0.4193
    [ CUDAFloatType{8,1} ]
    Result from the second inference:
    0.4883
    0.4703
    [ CUDAFloatType{2,1} ]