1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/compiler.h>
18
19 namespace executorch {
20 namespace extension {
21
22 #ifndef USE_ATEN_LIB
23 /**
24 * A smart pointer for managing the lifecycle of a TensorImpl.
25 *
26 * TensorImplPtr uses a shared pointer since multiple Tensor objects may
27 * share the same underlying data and metadata. This shared ownership ensures
28 * that the TensorImpl is destroyed only when all references to it are gone,
29 * providing a safe and efficient way to manage shared tensor implementations.
30 * It serves as a safer, more convenient alternative to the original TensorImpl,
31 * which does not manage its metadata by design.
32 */
33 using TensorImplPtr = std::shared_ptr<executorch::aten::TensorImpl>;
34 #else
35 /**
36 * A smart pointer type for managing the lifecycle of a TensorImpl.
37 *
38 * TensorImplPtr uses an intrusive pointer when working with ATen, ensuring
39 * efficient reference counting and shared ownership of the underlying data and
40 * metadata.
41 */
42 using TensorImplPtr =
43 c10::intrusive_ptr<executorch::aten::TensorImpl, at::UndefinedTensorImpl>;
44 #endif // USE_ATEN_LIB
45
46 /**
47 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
48 * specified properties.
49 *
50 * @param sizes A vector specifying the size of each dimension.
51 * @param data A pointer to the data buffer.
52 * @param dim_order A vector specifying the order of dimensions.
53 * @param strides A vector specifying the strides of each dimension.
54 * @param type The scalar type of the tensor elements.
55 * @param dynamism Specifies the mutability of the tensor's shape.
56 * @param deleter A custom deleter function for managing the lifetime of the
57 * data buffer. If provided, this deleter is called when the managed TensorImpl
58 * is destroyed.
59 * @return A TensorImplPtr managing the newly created TensorImpl.
60 */
61 TensorImplPtr make_tensor_impl_ptr(
62 std::vector<executorch::aten::SizesType> sizes,
63 void* data,
64 std::vector<executorch::aten::DimOrderType> dim_order,
65 std::vector<executorch::aten::StridesType> strides,
66 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
67 executorch::aten::TensorShapeDynamism dynamism =
68 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
69 std::function<void(void*)> deleter = nullptr);
70
71 /**
72 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
73 * specified properties.
74 *
75 * @param sizes A vector specifying the size of each dimension.
76 * @param data A pointer to the data buffer.
77 * @param type The scalar type of the tensor elements.
78 * @param dynamism Specifies the mutability of the tensor's shape.
79 * @param deleter A custom deleter function for managing the lifetime of the
80 * data buffer. If provided, this deleter is called when the managed TensorImpl
81 * is destroyed.
82 * @return A TensorImplPtr managing the newly created TensorImpl.
83 */
84 inline TensorImplPtr make_tensor_impl_ptr(
85 std::vector<executorch::aten::SizesType> sizes,
86 void* data,
87 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
88 executorch::aten::TensorShapeDynamism dynamism =
89 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
90 std::function<void(void*)> deleter = nullptr) {
91 return make_tensor_impl_ptr(
92 std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
93 }
94
95 /**
96 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
97 * specified properties.
98 *
99 * This template overload is specialized for cases where tensor data is provided
100 * as a vector. If the specified `type` differs from the deduced type of the
101 * vector's elements, and casting is allowed, the data will be cast to the
102 * specified `type`. This allows for flexible creation of tensors with data
103 * vectors of one type and a different scalar type.
104 *
105 * @tparam T The C++ type of the tensor elements, deduced from the vector.
106 * @param sizes A vector specifying the size of each dimension.
107 * @param data A vector containing the tensor's data.
108 * @param dim_order A vector specifying the order of dimensions.
109 * @param strides A vector specifying the strides of each dimension.
110 * @param type The scalar type of the tensor elements. If it differs from the
111 * deduced type, the data will be cast to this type if allowed.
112 * @param dynamism Specifies the mutability of the tensor's shape.
113 * @return A TensorImplPtr that manages the newly created TensorImpl.
114 */
115 template <
116 typename T = float,
117 executorch::aten::ScalarType deduced_type =
118 runtime::CppTypeToScalarType<T>::value>
119 TensorImplPtr make_tensor_impl_ptr(
120 std::vector<executorch::aten::SizesType> sizes,
121 std::vector<T> data,
122 std::vector<executorch::aten::DimOrderType> dim_order = {},
123 std::vector<executorch::aten::StridesType> strides = {},
124 executorch::aten::ScalarType type = deduced_type,
125 executorch::aten::TensorShapeDynamism dynamism =
126 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
127 if (type != deduced_type) {
128 ET_CHECK_MSG(
129 runtime::canCast(deduced_type, type),
130 "Cannot cast deduced type to specified type.");
131 std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
132 ET_SWITCH_REALHBBF16_TYPES(
133 type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
134 std::transform(
135 data.begin(),
136 data.end(),
137 reinterpret_cast<CTYPE*>(casted_data.data()),
138 [](const T& val) { return static_cast<CTYPE>(val); });
139 });
140 const auto raw_data_ptr = casted_data.data();
141 auto data_ptr =
142 std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
143 return make_tensor_impl_ptr(
144 std::move(sizes),
145 raw_data_ptr,
146 std::move(dim_order),
147 std::move(strides),
148 type,
149 dynamism,
150 [data_ptr = std::move(data_ptr)](void*) {});
151 }
152 const auto raw_data_ptr = data.data();
153 auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
154 return make_tensor_impl_ptr(
155 std::move(sizes),
156 raw_data_ptr,
157 std::move(dim_order),
158 std::move(strides),
159 type,
160 dynamism,
161 [data_ptr = std::move(data_ptr)](void*) {});
162 }
163
164 /**
165 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
166 * specified properties.
167 *
168 * This template overload is specialized for cases where tensor data is provided
169 * as a vector. If the specified `type` differs from the deduced type of the
170 * vector's elements, and casting is allowed, the data will be cast to the
171 * specified `type`. This allows for flexible creation of tensors with data
172 * vectors of one type and a different scalar type.
173 *
174 * @tparam T The C++ type of the tensor elements, deduced from the vector.
175 * @param data A vector containing the tensor's data.
176 * @param type The scalar type of the tensor elements. If it differs from the
177 * deduced type, the data will be cast to this type if allowed.
178 * @param dynamism Specifies the mutability of the tensor's shape.
179 * @return A TensorImplPtr that manages the newly created TensorImpl.
180 */
181 template <
182 typename T = float,
183 executorch::aten::ScalarType deduced_type =
184 runtime::CppTypeToScalarType<T>::value>
185 inline TensorImplPtr make_tensor_impl_ptr(
186 std::vector<T> data,
187 executorch::aten::ScalarType type = deduced_type,
188 executorch::aten::TensorShapeDynamism dynamism =
189 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
190 std::vector<executorch::aten::SizesType> sizes{
191 executorch::aten::SizesType(data.size())};
192 return make_tensor_impl_ptr(
193 std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
194 }
195
196 /**
197 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
198 * specified properties.
199 *
200 * This template overload is specialized for cases where tensor data is provided
201 * as an initializer list. If the specified `type` differs from the deduced type
202 * of the initializer list's elements, and casting is allowed, the data will be
203 * cast to the specified `type`. This allows for flexible creation of tensors
204 * with data initializer list of one type and a different scalar type.
205 *
206 * @tparam T The C++ type of the tensor elements, deduced from the initializer
207 * list.
208 * @param sizes A vector specifying the size of each dimension.
209 * @param list An initializer list containing the tensor's data.
210 * @param dim_order A vector specifying the order of dimensions.
211 * @param strides A vector specifying the strides of each dimension.
212 * @param type The scalar type of the tensor elements. If it differs from the
213 * deduced type, the data will be cast to this type if allowed.
214 * @param dynamism Specifies the mutability of the tensor's shape.
215 * @return A TensorImplPtr that manages the newly created TensorImpl.
216 */
217 template <
218 typename T = float,
219 executorch::aten::ScalarType deduced_type =
220 runtime::CppTypeToScalarType<T>::value>
221 inline TensorImplPtr make_tensor_impl_ptr(
222 std::vector<executorch::aten::SizesType> sizes,
223 std::initializer_list<T> list,
224 std::vector<executorch::aten::DimOrderType> dim_order = {},
225 std::vector<executorch::aten::StridesType> strides = {},
226 executorch::aten::ScalarType type = deduced_type,
227 executorch::aten::TensorShapeDynamism dynamism =
228 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
229 return make_tensor_impl_ptr(
230 std::move(sizes),
231 std::vector<T>(std::move(list)),
232 std::move(dim_order),
233 std::move(strides),
234 type,
235 dynamism);
236 }
237
238 /**
239 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
240 * specified properties.
241 *
242 * This template overload is specialized for cases where tensor data is provided
243 * as an initializer list. If the specified `type` differs from the deduced type
244 * of the initializer list's elements, and casting is allowed, the data will be
245 * cast to the specified `type`. This allows for flexible creation of tensors
246 * with data initializer list of one type and a different scalar type.
247 *
248 * @tparam T The C++ type of the tensor elements, deduced from the initializer
249 * list.
250 * @param list An initializer list containing the tensor's data.
251 * @param type The scalar type of the tensor elements. If it differs from the
252 * deduced type, the data will be cast to this type if allowed.
253 * @param dynamism Specifies the mutability of the tensor's shape.
254 * @return A TensorImplPtr that manages the newly created TensorImpl.
255 */
256 template <
257 typename T = float,
258 executorch::aten::ScalarType deduced_type =
259 runtime::CppTypeToScalarType<T>::value>
260 inline TensorImplPtr make_tensor_impl_ptr(
261 std::initializer_list<T> list,
262 executorch::aten::ScalarType type = deduced_type,
263 executorch::aten::TensorShapeDynamism dynamism =
264 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
265 std::vector<executorch::aten::SizesType> sizes{
266 executorch::aten::SizesType(list.size())};
267 return make_tensor_impl_ptr(
268 std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
269 }
270
271 /**
272 * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
273 *
274 * @tparam T The C++ type of the scalar value.
275 * @param value The scalar value used for the Tensor.
276 * @return A TensorImplPtr managing the newly created TensorImpl.
277 */
278 template <typename T>
make_tensor_impl_ptr(T value)279 inline TensorImplPtr make_tensor_impl_ptr(T value) {
280 return make_tensor_impl_ptr({}, std::vector<T>{value});
281 }
282
283 /**
284 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
285 * specified properties.
286 *
287 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
288 * and a scalar type to interpret the data. The vector is managed, and its
289 * lifetime is tied to the TensorImpl.
290 *
291 * @param sizes A vector specifying the size of each dimension.
292 * @param data A vector containing the raw memory buffer for the tensor's data.
293 * @param dim_order A vector specifying the order of dimensions.
294 * @param strides A vector specifying the strides of each dimension.
295 * @param type The scalar type of the tensor elements.
296 * @param dynamism Specifies the mutability of the tensor's shape.
297 * @return A TensorImplPtr managing the newly created TensorImpl.
298 */
299 TensorImplPtr make_tensor_impl_ptr(
300 std::vector<executorch::aten::SizesType> sizes,
301 std::vector<uint8_t> data,
302 std::vector<executorch::aten::DimOrderType> dim_order,
303 std::vector<executorch::aten::StridesType> strides,
304 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
305 executorch::aten::TensorShapeDynamism dynamism =
306 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
307
308 /**
309 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
310 * specified properties.
311 *
312 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
313 * and a scalar type to interpret the data. The vector is managed, and the
314 * memory's lifetime is tied to the TensorImpl.
315 *
316 * @param sizes A vector specifying the size of each dimension.
317 * @param data A vector containing the raw memory for the tensor's data.
318 * @param type The scalar type of the tensor elements.
319 * @param dynamism Specifies the mutability of the tensor's shape.
320 * @return A TensorImplPtr managing the newly created TensorImpl.
321 */
322 inline TensorImplPtr make_tensor_impl_ptr(
323 std::vector<executorch::aten::SizesType> sizes,
324 std::vector<uint8_t> data,
325 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
326 executorch::aten::TensorShapeDynamism dynamism =
327 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
328 return make_tensor_impl_ptr(
329 std::move(sizes), std::move(data), {}, {}, type, dynamism);
330 }
331
332 } // namespace extension
333 } // namespace executorch
334