#pragma once

#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>

// Standard headers for the std::copy, std::unique_ptr, std::runtime_error,
// fixed-width integer, and type-trait usages below.
#include <algorithm>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <type_traits>

namespace torch::aot_inductor {

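// Caches an output tensor in storage owned by this object so the data can be
// handed back to callers across invocations. The specializations for
// RAIIAtenTensorHandle, AtenTensorHandle, and ConstantHandle below only need
// to compile; they throw if ever exercised. Sketch of the intended usage
// (assumed; the actual call sites live in AOTInductor-generated code):
//
//   thread_local ThreadLocalCachedOutputTensor<ArrayRefTensor<float>>
//       cached(output);
//   cached.copy_data_from(output);
//   AtenTensorHandle handle = cached.tensor();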
template <typename T>
struct ThreadLocalCachedOutputTensor;

template <>
struct ThreadLocalCachedOutputTensor<RAIIAtenTensorHandle> {
  explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {}
  void copy_data_from(const RAIIAtenTensorHandle& handle) {
    throw std::runtime_error("can't happen");
  }

  AtenTensorHandle tensor() const {
    throw std::runtime_error("can't happen");
  }
};

template <>
struct ThreadLocalCachedOutputTensor<AtenTensorHandle> {
  explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {}
  void copy_data_from(const AtenTensorHandle& handle) {
    throw std::runtime_error("can't happen");
  }

  AtenTensorHandle tensor() const {
    throw std::runtime_error("can't happen");
  }
};

template <>
struct ThreadLocalCachedOutputTensor<ConstantHandle> {
  explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {}
  void copy_data_from(const ConstantHandle& handle) {
    throw std::runtime_error("can't happen");
  }

  AtenTensorHandle tensor() const {
    throw std::runtime_error("can't happen");
  }
};

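// Specialization that actually caches: owns a heap buffer sized to the source
// tensor and exposes it through an AtenTensorHandle created with
// aoti_torch_create_tensor_from_blob. The buffer is reallocated only when the
// incoming tensor has more elements than the current capacity.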
template <typename T>
struct ThreadLocalCachedOutputTensor<ArrayRefTensor<T>> {
  explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor<T>& t) {
    realloc(t);
  }

  void copy_data_from(const ArrayRefTensor<T>& t) {
    if (t.numel() > capacity_) {
      realloc(t);
    }
    std::copy(t.data(), t.data() + t.numel(), storage_.get());
  }

  AtenTensorHandle tensor() const {
    return tensor_.get();
  }

 private:
  void realloc(const ArrayRefTensor<T>& t) {
    capacity_ = t.numel();
    storage_ = std::make_unique<T[]>(t.numel());
    AtenTensorHandle handle = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
        storage_.get(),
        t.sizes().size(),
        t.sizes().data(),
        t.strides().data(),
        0,
        aoti_torch_dtype<std::remove_const_t<T>>(),
        t.device_type(),
        t.device_idx(),
        &handle));
    tensor_ = handle;
  }

  std::unique_ptr<T[]> storage_;
  int64_t capacity_ = 0;
  RAIIAtenTensorHandle tensor_;
};

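// Analogous cache for outputs returned as ArrayRefTensor rather than as an
// AtenTensorHandle. Only the ArrayRefTensor<T> specialization does real work;
// the handle specializations below exist so code that instantiates this
// template for every output type still compiles.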
template <typename T>
struct ThreadLocalCachedOutputArray;

// Just needs to compile, doesn't need to do anything.
template <>
struct ThreadLocalCachedOutputArray<RAIIAtenTensorHandle> {
  explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) {
    throw std::runtime_error("can't happen");
  }

  // Not supported yet! We would need to put contiguous() or
  // expect_contiguous() into the ABI.
  void copy_data_from(const RAIIAtenTensorHandle&) {
    throw std::runtime_error("can't happen");
  }

  template <typename U>
  ArrayRefTensor<U> arrayref_tensor() const {
    throw std::runtime_error("can't happen");
  }
};

// Just needs to compile, doesn't need to do anything.
template <>
struct ThreadLocalCachedOutputArray<ConstantHandle> {
  explicit ThreadLocalCachedOutputArray(const ConstantHandle&) {
    throw std::runtime_error("can't happen");
  }

  // Not supported yet! We would need to put contiguous() or
  // expect_contiguous() into the ABI.
  void copy_data_from(const ConstantHandle&) {
    throw std::runtime_error("can't happen");
  }

  template <typename U>
  ArrayRefTensor<U> arrayref_tensor() const {
    throw std::runtime_error("can't happen");
  }
};

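// Specialization that actually caches: copies the source data into owned
// storage and re-points the cached ArrayRefTensor at that storage via
// set_arrayref(), so the returned view does not depend on the lifetime of the
// source tensor's data.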
template <typename T>
struct ThreadLocalCachedOutputArray<ArrayRefTensor<T>> {
  explicit ThreadLocalCachedOutputArray(const ArrayRefTensor<T>& t) {}

  template <
      typename U,
      std::enable_if_t<
          std::is_same_v<std::remove_const_t<T>, std::remove_const_t<U>>,
          bool> = true>
  ArrayRefTensor<T> arrayref_tensor() const {
    return tensor_;
  }

  void copy_data_from(const ArrayRefTensor<T>& t) {
    if (t.numel() > capacity_) {
      capacity_ = t.numel();
      storage_ = std::make_unique<T[]>(capacity_);
    }
    std::copy(t.data(), t.data() + t.numel(), storage_.get());
    tensor_ = t;
    tensor_.set_arrayref(MiniArrayRef<T>(storage_.get(), t.numel()));
  }

 private:
  std::unique_ptr<T[]> storage_;
  uint32_t capacity_ = 0;
  ArrayRefTensor<T> tensor_;
};

} // namespace torch::aot_inductor