1 #pragma once 2 3 #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h> 4 5 namespace torch::aot_inductor { 6 7 template <typename T> 8 struct ThreadLocalCachedOutputTensor; 9 10 template <> 11 struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> { 12 explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {} 13 void copy_data_from(const RAIIAtenTensorHandle& handle) { 14 throw std::runtime_error("can't happen"); 15 } 16 17 AtenTensorHandle tensor() const { 18 throw std::runtime_error("can't happen"); 19 } 20 }; 21 22 template <> 23 struct ThreadLocalCachedOutputTensor<AtenTensorHandle> { 24 explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {} 25 void copy_data_from(const AtenTensorHandle& handle) { 26 throw std::runtime_error("can't happen"); 27 } 28 29 AtenTensorHandle tensor() const { 30 throw std::runtime_error("can't happen"); 31 } 32 }; 33 34 template <> 35 struct ThreadLocalCachedOutputTensor<ConstantHandle> { 36 explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {} 37 void copy_data_from(const ConstantHandle& handle) { 38 throw std::runtime_error("can't happen"); 39 } 40 41 AtenTensorHandle tensor() const { 42 throw std::runtime_error("can't happen"); 43 } 44 }; 45 46 template <typename T> 47 struct ThreadLocalCachedOutputTensor<ArrayRefTensor<T>> { 48 explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor<T>& t) { 49 realloc(t); 50 } 51 52 void copy_data_from(const ArrayRefTensor<T>& t) { 53 if (t.numel() > capacity_) { 54 realloc(t); 55 } 56 std::copy(t.data(), t.data() + t.numel(), storage_.get()); 57 } 58 59 AtenTensorHandle tensor() const { 60 return tensor_.get(); 61 } 62 63 private: 64 void realloc(const ArrayRefTensor<T>& t) { 65 capacity_ = t.numel(); 66 storage_ = std::make_unique<T[]>(t.numel()); 67 AtenTensorHandle handle = nullptr; 68 AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( 69 storage_.get(), 70 t.sizes().size(), 71 t.sizes().data(), 72 t.strides().data(), 73 0, 74 aoti_torch_dtype<std::remove_const_t<T>>(), 75 t.device_type(), 76 t.device_idx(), 77 &handle)); 78 tensor_ = handle; 79 } 80 81 std::unique_ptr<T[]> storage_; 82 int64_t capacity_ = 0; 83 RAIIAtenTensorHandle tensor_; 84 }; 85 86 template <typename T> 87 struct ThreadLocalCachedOutputArray; 88 89 // Just needs to compile, doesn't need to do anything. 90 template <> 91 struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> { 92 explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) { 93 throw std::runtime_error("can't happen"); 94 } 95 96 // Not supported yet! We would need to put contiguous() or 97 // expect_contiguous() into the ABI. 98 void copy_data_from(const RAIIAtenTensorHandle&) { 99 throw std::runtime_error("can't happen"); 100 } 101 102 template <typename U> 103 ArrayRefTensor<U> arrayref_tensor() const { 104 throw std::runtime_error("can't happen"); 105 } 106 }; 107 108 // Just needs to compile, doesn't need to do anything. 109 template <> 110 struct ThreadLocalCachedOutputArray<ConstantHandle> { 111 explicit ThreadLocalCachedOutputArray(const ConstantHandle&) { 112 throw std::runtime_error("can't happen"); 113 } 114 115 // Not supported yet! We would need to put contiguous() or 116 // expect_contiguous() into the ABI. 117 void copy_data_from(const ConstantHandle&) { 118 throw std::runtime_error("can't happen"); 119 } 120 121 template <typename U> 122 ArrayRefTensor<U> arrayref_tensor() const { 123 throw std::runtime_error("can't happen"); 124 } 125 }; 126 127 template <typename T> 128 struct ThreadLocalCachedOutputArray<ArrayRefTensor<T>> { 129 explicit ThreadLocalCachedOutputArray(const ArrayRefTensor<T>& t) {} 130 131 template < 132 typename U, 133 std::enable_if_t< 134 std::is_same_v<std::remove_const_t<T>, std::remove_const_t<U>>, 135 bool> = true> 136 ArrayRefTensor<T> arrayref_tensor() const { 137 return tensor_; 138 } 139 140 void copy_data_from(const ArrayRefTensor<T>& t) { 141 if (t.numel() > capacity_) { 142 capacity_ = t.numel(); 143 storage_ = std::make_unique<T[]>(capacity_); 144 } 145 std::copy(t.data(), t.data() + t.numel(), storage_.get()); 146 tensor_ = t; 147 tensor_.set_arrayref(MiniArrayRef<T>(storage_.get(), t.numel())); 148 } 149 150 private: 151 std::unique_ptr<T[]> storage_; 152 uint32_t capacity_ = 0; 153 ArrayRefTensor<T> tensor_; 154 }; 155 156 } // namespace torch::aot_inductor 157