• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <iostream>
4 #include <memory>
5 #include <sstream>
6 #include <stdexcept>
7 #include <string>
8 #include <vector>
9 
10 // WARNING: Be careful when adding new includes here. This header will be used
11 // in model.so, and should not refer to any aten/c10 headers except the stable
12 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
13 // applies to other files under torch/csrc/inductor/aoti_runtime/.
14 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
15 
// AOTI_NOINLINE asks the compiler not to inline the annotated function. It is
// applied to the cold error path below so the exception-formatting code stays
// out of line at every expansion of the error-check macros.
#if defined(__GNUC__) || defined(__clang__)
#define AOTI_NOINLINE __attribute__((noinline))
#elif defined(_MSC_VER)
#define AOTI_NOINLINE __declspec(noinline)
#else
#define AOTI_NOINLINE
#endif

// Throws a std::runtime_error describing a failed API call.
//
// @param call  Stringified source text of the failing call expression.
// @param file  Source file of the call site (the macros pass __FILE__).
// @param line  Line of the call site (the macros pass __LINE__).
// @throws std::runtime_error always, with the message
//         "<call> API call failed at <file>, line <line>".
//
// Marked [[noreturn]] because it unconditionally throws: this tells the
// compiler that the failure branch of the error-check macros never falls
// through, which aids optimisation and avoids spurious "missing return"
// warnings in callers.
[[noreturn]] AOTI_NOINLINE static void throw_exception(
    const char* call,
    const char* file,
    int64_t line) {
  std::stringstream ss;
  ss << call << " API call failed at " << file << ", line " << line;
  throw std::runtime_error(ss.str());
}
32 
// Error-check helpers for calls that return a status code.
//
// Each macro evaluates `call` exactly once. If the result is not the success
// code, it throws std::runtime_error via throw_exception, passing the
// stringified call text plus the __FILE__/__LINE__ of the macro use.
//
// The expansion is `if (ok) {} else { throw }` rather than `if (!ok) { throw }`.
// Because the macro's own `if` already owns an `else`, an `else` the caller
// writes after the macro can no longer be silently captured by it (the
// dangling-else hazard of if-based macros). The expansion is still a single
// statement, so existing uses with or without a trailing semicolon keep
// compiling.
#define AOTI_TORCH_ERROR_CODE_CHECK(call)       \
  if ((call) == AOTI_TORCH_SUCCESS) {           \
  } else {                                      \
    throw_exception(#call, __FILE__, __LINE__); \
  }

// Status code type of the AOTI runtime C ABI, with its two values.
using AOTIRuntimeError = int32_t;
#define AOTI_RUNTIME_SUCCESS 0
#define AOTI_RUNTIME_FAILURE 1

#define AOTI_RUNTIME_ERROR_CODE_CHECK(call)     \
  if ((call) == AOTI_RUNTIME_SUCCESS) {         \
  } else {                                      \
    throw_exception(#call, __FILE__, __LINE__); \
  }
46 
47 namespace torch::aot_inductor {
48 
/// Plain function-pointer type for a deleter that releases a type-erased
/// (void*) pointer.
using DeleterFnPtr = void (*)(void* ptr);

/// Deleter that deliberately does nothing with its argument.
/// RAIIAtenTensorHandle installs it for its default-constructed (null) state.
inline void noop_deleter(void* /*unused*/) {}
52 
delete_tensor_object(void * ptr)53 inline void delete_tensor_object(void* ptr) {
54   AOTI_TORCH_ERROR_CODE_CHECK(
55       aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
56 }
57 
58 // RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
59 class RAIIAtenTensorHandle {
60  public:
RAIIAtenTensorHandle()61   RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {}
62   RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete;
63   RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete;
64 
65   // Steal the ownership from another RAIIAtenTensorHandle using std::move
66   RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default;
67   RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default;
68 
69   // Steal the ownership from raw AtenTensorHandle
RAIIAtenTensorHandle(AtenTensorHandle handle)70   RAIIAtenTensorHandle(AtenTensorHandle handle)
71       : handle_(handle, delete_tensor_object) {}
72 
~RAIIAtenTensorHandle()73   ~RAIIAtenTensorHandle() {
74     handle_.reset();
75   }
76 
77   // Return a raw AtenTensorHandle to be used by aoti_torch functions
78   // Note: this function does NOT transfer the ownership of the handle
AtenTensorHandle()79   operator AtenTensorHandle() const {
80     return handle_.get();
81   }
82 
release()83   AtenTensorHandle release() {
84     return handle_.release();
85   }
86 
get()87   AtenTensorHandle get() const {
88     return handle_.get();
89   }
90 
reset()91   void reset() {
92     handle_.reset();
93   }
94 
size(int64_t d)95   int64_t size(int64_t d) {
96     int64_t size = 0;
97     AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size));
98     return size;
99   }
100 
stride(int64_t d)101   int64_t stride(int64_t d) {
102     int64_t stride = 0;
103     AOTI_TORCH_ERROR_CODE_CHECK(
104         aoti_torch_get_stride(handle_.get(), d, &stride));
105     return stride;
106   }
107 
storage_offset()108   int64_t storage_offset() {
109     int64_t storage_offset = 0;
110     AOTI_TORCH_ERROR_CODE_CHECK(
111         aoti_torch_get_storage_offset(handle_.get(), &storage_offset));
112     return storage_offset;
113   }
114 
115  private:
116   std::unique_ptr<AtenTensorOpaque, DeleterFnPtr> handle_;
117 };
118 
119 // Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle
steal_from_raw_handles_to_raii_handles(AtenTensorHandle * handles,size_t size)120 inline std::vector<RAIIAtenTensorHandle> steal_from_raw_handles_to_raii_handles(
121     AtenTensorHandle* handles,
122     size_t size) {
123   std::vector<RAIIAtenTensorHandle> result;
124   result.reserve(size);
125   for (size_t i = 0; i < size; i++) {
126     result.emplace_back(handles[i]);
127     handles[i] = nullptr;
128   }
129   return result;
130 }
131 
132 class ConstantHandle {
133  public:
134   ConstantHandle() = default;
135 
ConstantHandle(AtenTensorHandle handle)136   explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) {
137     AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_));
138   }
139 
AtenTensorHandle()140   operator AtenTensorHandle() const {
141     return handle_;
142   }
143 
tensor()144   AtenTensorHandle tensor() const {
145     return handle_;
146   }
147 
get()148   AtenTensorHandle get() const {
149     return handle_;
150   }
151 
data_ptr()152   void* data_ptr() const {
153     return data_;
154   }
155 
156  private:
157   AtenTensorHandle handle_{};
158   void* data_ = nullptr;
159 };
160 
get_data_ptr_wrapper(const ConstantHandle & constant)161 inline void* get_data_ptr_wrapper(const ConstantHandle& constant) {
162   return constant.data_ptr();
163 }
164 
unwrap_raii_handle_if_needed(const ConstantHandle & handle)165 inline const ConstantHandle& unwrap_raii_handle_if_needed(
166     const ConstantHandle& handle) {
167   return handle;
168 }
169 
170 // Shouldn't be called.
171 inline AtenTensorHandle wrap_with_raii_handle_if_needed(
172     const ConstantHandle& handle) = delete;
173 
// Declare a `static` variable holding the result of the matching aoti_torch_*
// getter. The getter is evaluated once, when the static is initialized, rather
// than at every use. Each macro expands to a declaration; the caller supplies
// the trailing semicolon. For example:
//   CACHE_TORCH_DTYPE(X);
//     -> static auto cached_torch_dtype_X = aoti_torch_dtype_X();
#define CACHE_TORCH_DTYPE(dtype) \
  static auto cached_torch_dtype_##dtype = aoti_torch_dtype_##dtype()

#define CACHE_TORCH_DEVICE(device)                \
  static auto cached_torch_device_type_##device = \
      aoti_torch_device_type_##device()

#define CACHE_TORCH_LAYOUT(layout) \
  static auto cached_torch_layout_##layout = aoti_torch_layout_##layout()
183 
184 } // namespace torch::aot_inductor
185