AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models
==================================================================

.. warning::

    AOTInductor and its related features are in prototype status and are
    subject to backwards compatibility breaking changes.

AOTInductor is a specialized version of
`TorchInductor <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__,
designed to process exported PyTorch models, optimize them, and produce shared libraries as well
as other relevant artifacts.
These compiled artifacts are specifically crafted for deployment in non-Python environments,
which are frequently used for server-side inference deployments.

In this tutorial, you will learn how to take a PyTorch model, export it, compile it into a
shared library, and run model predictions using C++.


Model Compilation
---------------------------

Using AOTInductor, you can still author the model in Python. The following
example demonstrates how to invoke ``aot_compile`` to transform the model into a
shared library.

This API uses ``torch.export`` to capture the model into a computational graph,
and then uses TorchInductor to generate a ``.so`` file which can be run in a non-Python
environment. For comprehensive details on the ``torch._export.aot_compile``
API, you can refer to the code
`here <https://github.com/pytorch/pytorch/blob/92cc52ab0e48a27d77becd37f1683fd442992120/torch/_export/__init__.py#L891-L900C9>`__.
For more details on ``torch.export``, you can refer to the :ref:`torch.export docs <torch.export>`.


.. note::

   If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support,
   the following code will compile the model into a shared library for CUDA execution.
   Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference,
   it is suggested to enable freezing by setting ``export TORCHINDUCTOR_FREEZING=1``
   before running the Python script below.

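
For example, assuming the script below is saved as ``model.py`` (the file name used by the CMake
setup later in this tutorial), freezing can be enabled for a CPU build by running:

.. code-block:: shell

    $ TORCHINDUCTOR_FREEZING=1 python model.py
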

.. code-block:: python

    import os
    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(16, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.sigmoid(x)
            return x

    with torch.no_grad():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Model().to(device=device)
        example_inputs = (torch.randn(8, 10, device=device),)
        batch_dim = torch.export.Dim("batch", min=1, max=1024)
        so_path = torch._export.aot_compile(
            model,
            example_inputs,
            # Specify the first dimension of the input x as dynamic
            dynamic_shapes={"x": {0: batch_dim}},
            # Specify the generated shared library path
            options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
        )

In this example, the ``Dim`` parameter is used to designate the first dimension of
the input variable "x" as dynamic. The ``aot_inductor.output_path`` option specifies where the
compiled shared library is written; here it is saved as ``model.so`` in the current working
directory so that the C++ code in the next section can load it by name. If this option is left
unspecified, the shared library is instead stored in a temporary directory.

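
Optionally, you can sanity-check the compiled library from Python before moving on to C++. The
following is a minimal sketch that assumes a PyTorch build where the prototype
``torch._export.aot_load`` API is available; it loads the ``model.so`` produced above and runs a
single batch.

.. code-block:: python

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the AOT-compiled shared library back into Python (prototype API).
    compiled = torch._export.aot_load("model.so", device)

    with torch.no_grad():
        output = compiled(torch.randn(4, 10, device=device))
    print(output.shape)  # expected: torch.Size([4, 1])
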

Inference in C++
---------------------------

Next, we use the following C++ file ``inference.cpp`` to load the shared library generated in the
previous step, enabling us to conduct model predictions directly within a C++ environment.

.. note::

    The following code snippet assumes your system has a CUDA-enabled device and your model was
    compiled to run on CUDA as shown previously.
    In the absence of a GPU, make the following adjustments to run it on CPU (a CPU variant is
    sketched after the CUDA example below):

    1. Change ``model_container_runner_cuda.h`` to ``model_container_runner_cpu.h``
    2. Change ``AOTIModelContainerRunnerCuda`` to ``AOTIModelContainerRunnerCpu``
    3. Change ``at::kCUDA`` to ``at::kCPU``


.. code-block:: cpp

    #include <iostream>
    #include <vector>

    #include <torch/torch.h>
    #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

    int main() {
        c10::InferenceMode mode;

        torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");
        std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
        std::vector<torch::Tensor> outputs = runner.run(inputs);
        std::cout << "Result from the first inference:" << std::endl;
        std::cout << outputs[0] << std::endl;

        // The second inference uses a different batch size and it works because we
        // specified that dimension as dynamic when compiling model.so.
        std::cout << "Result from the second inference:" << std::endl;
        std::vector<torch::Tensor> inputs2 = {torch::randn({2, 10}, at::kCUDA)};
        std::cout << runner.run(inputs2)[0] << std::endl;

        return 0;
    }

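
If you compiled the model for CPU instead, a minimal CPU variant of ``inference.cpp``, obtained by
applying the substitutions listed in the note above, might look like the following sketch:

.. code-block:: cpp

    #include <iostream>
    #include <vector>

    #include <torch/torch.h>
    #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>

    int main() {
        c10::InferenceMode mode;

        // Load the CPU-compiled model.so and run a single batch on CPU tensors.
        torch::inductor::AOTIModelContainerRunnerCpu runner("model.so");
        std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCPU)};
        std::vector<torch::Tensor> outputs = runner.run(inputs);
        std::cout << outputs[0] << std::endl;

        return 0;
    }
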

For building the C++ file, you can make use of the provided ``CMakeLists.txt`` file, which
automates the process of invoking ``python model.py`` for AOT compilation of the model and compiling
``inference.cpp`` into an executable binary named ``aoti_example``.


.. code-block:: cmake

    cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
    project(aoti_example)

    find_package(Torch REQUIRED)

    add_executable(aoti_example inference.cpp model.so)

    add_custom_command(
        OUTPUT model.so
        COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
        DEPENDS model.py
    )

    target_link_libraries(aoti_example "${TORCH_LIBRARIES}")
    set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17)


Assuming your directory structure resembles the following, you can run the commands below to build
the binary. Note that the ``CMAKE_PREFIX_PATH`` variable is required for CMake to locate the
LibTorch library, and it must be set to an absolute path; your path may differ from the one shown
in this example.


.. code-block:: shell

    aoti_example/
        CMakeLists.txt
        inference.cpp
        model.py

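
If you are unsure which value to use for ``CMAKE_PREFIX_PATH``, you can query it from the Python
environment in which PyTorch is installed; ``torch.utils.cmake_prefix_path`` points to the CMake
configuration files bundled with the PyTorch installation:

.. code-block:: shell

    $ python -c "import torch; print(torch.utils.cmake_prefix_path)"
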

.. code-block:: shell

    $ mkdir build
    $ cd build
    $ CMAKE_PREFIX_PATH=/path/to/python/install/site-packages/torch/share/cmake cmake ..
    $ cmake --build . --config Release

After the ``aoti_example`` binary has been generated in the ``build`` directory, executing it will
display results akin to the following:

.. code-block:: shell

    $ ./aoti_example
    Result from the first inference:
    0.4866
    0.5184
    0.4462
    0.4611
    0.4744
    0.4811
    0.4938
    0.4193
    [ CUDAFloatType{8,1} ]
    Result from the second inference:
    0.4883
    0.4703
    [ CUDAFloatType{2,1} ]