/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <memory>
#include <vector>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/compiler.h>

namespace executorch {
namespace extension {

#ifndef USE_ATEN_LIB
/**
 * A smart pointer for managing the lifecycle of a TensorImpl.
 *
 * TensorImplPtr uses a shared pointer since multiple Tensor objects may
 * share the same underlying data and metadata. This shared ownership ensures
 * that the TensorImpl is destroyed only when all references to it are gone,
 * providing a safe and efficient way to manage shared tensor implementations.
 * It serves as a safer, more convenient alternative to the original TensorImpl,
 * which does not manage its metadata by design.
 */
using TensorImplPtr = std::shared_ptr<executorch::aten::TensorImpl>;
#else
/**
 * A smart pointer type for managing the lifecycle of a TensorImpl.
 *
 * TensorImplPtr uses an intrusive pointer when working with ATen, ensuring
 * efficient reference counting and shared ownership of the underlying data and
 * metadata.
 */
using TensorImplPtr =
    c10::intrusive_ptr<executorch::aten::TensorImpl, at::UndefinedTensorImpl>;
#endif // USE_ATEN_LIB

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * @param sizes A vector specifying the size of each dimension.
 * @param data A pointer to the data buffer.
 * @param dim_order A vector specifying the order of dimensions.
 * @param strides A vector specifying the strides of each dimension.
 * @param type The scalar type of the tensor elements.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @param deleter A custom deleter function for managing the lifetime of the
 * data buffer. If provided, this deleter is called when the managed TensorImpl
 * is destroyed.
 * @return A TensorImplPtr managing the newly created TensorImpl.
 */
TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    void* data,
    std::vector<executorch::aten::DimOrderType> dim_order,
    std::vector<executorch::aten::StridesType> strides,
    executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
    std::function<void(void*)> deleter = nullptr);
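
// Example (illustrative sketch, not part of the API): wrap a heap-allocated
// float buffer and let the custom deleter free it once the managed TensorImpl
// is destroyed. The row-major dim order {0, 1} and strides {3, 1} are chosen
// here for a 2x3 layout.
//
//   float* data = new float[6]{1, 2, 3, 4, 5, 6};
//   auto impl = make_tensor_impl_ptr(
//       {2, 3},
//       data,
//       {0, 1},
//       {3, 1},
//       executorch::aten::ScalarType::Float,
//       executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
//       [](void* ptr) { delete[] static_cast<float*>(ptr); });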

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * @param sizes A vector specifying the size of each dimension.
 * @param data A pointer to the data buffer.
 * @param type The scalar type of the tensor elements.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @param deleter A custom deleter function for managing the lifetime of the
 * data buffer. If provided, this deleter is called when the managed TensorImpl
 * is destroyed.
 * @return A TensorImplPtr managing the newly created TensorImpl.
 */
inline TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    void* data,
    executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
    std::function<void(void*)> deleter = nullptr) {
  return make_tensor_impl_ptr(
      std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
}
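
// Example (illustrative sketch): view an existing buffer as a 2x3 float
// tensor. No deleter is supplied, so the caller must keep `data` alive for as
// long as the returned TensorImplPtr (and any Tensor built on it) is in use.
//
//   float data[6] = {1, 2, 3, 4, 5, 6};
//   auto impl = make_tensor_impl_ptr({2, 3}, data);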

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This template overload is specialized for cases where tensor data is provided
 * as a vector. If the specified `type` differs from the deduced type of the
 * vector's elements, and casting is allowed, the data will be cast to the
 * specified `type`. This allows for flexible creation of tensors with data
 * vectors of one type and a different scalar type.
 *
 * @tparam T The C++ type of the tensor elements, deduced from the vector.
 * @param sizes A vector specifying the size of each dimension.
 * @param data A vector containing the tensor's data.
 * @param dim_order A vector specifying the order of dimensions.
 * @param strides A vector specifying the strides of each dimension.
 * @param type The scalar type of the tensor elements. If it differs from the
 * deduced type, the data will be cast to this type if allowed.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr that manages the newly created TensorImpl.
 */
template <
    typename T = float,
    executorch::aten::ScalarType deduced_type =
        runtime::CppTypeToScalarType<T>::value>
TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    std::vector<T> data,
    std::vector<executorch::aten::DimOrderType> dim_order = {},
    std::vector<executorch::aten::StridesType> strides = {},
    executorch::aten::ScalarType type = deduced_type,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
  if (type != deduced_type) {
    // The requested scalar type differs from the element type of `data`:
    // cast element by element into a byte buffer of the requested type.
    ET_CHECK_MSG(
        runtime::canCast(deduced_type, type),
        "Cannot cast deduced type to specified type.");
    std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
    ET_SWITCH_REALHBBF16_TYPES(
        type, nullptr, "make_tensor_impl_ptr", CTYPE, [&] {
          std::transform(
              data.begin(),
              data.end(),
              reinterpret_cast<CTYPE*>(casted_data.data()),
              [](const T& val) { return static_cast<CTYPE>(val); });
        });
    // Taking data() before the move is safe: moving the vector transfers its
    // buffer, so the raw pointer stays valid inside the shared_ptr below.
    const auto raw_data_ptr = casted_data.data();
    auto data_ptr =
        std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
    // The no-op deleter captures the shared_ptr, keeping the buffer alive for
    // the lifetime of the TensorImpl.
    return make_tensor_impl_ptr(
        std::move(sizes),
        raw_data_ptr,
        std::move(dim_order),
        std::move(strides),
        type,
        dynamism,
        [data_ptr = std::move(data_ptr)](void*) {});
  }
  const auto raw_data_ptr = data.data();
  auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
  return make_tensor_impl_ptr(
      std::move(sizes),
      raw_data_ptr,
      std::move(dim_order),
      std::move(strides),
      type,
      dynamism,
      [data_ptr = std::move(data_ptr)](void*) {});
}
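
// Example (illustrative sketch): the vector is moved in and owned by the
// returned TensorImplPtr. Passing a `type` that differs from the deduced one
// (Half instead of the deduced Float below) exercises the cast branch above,
// assuming runtime::canCast permits the conversion.
//
//   auto float_impl = make_tensor_impl_ptr(
//       {2, 2}, std::vector<float>{1.f, 2.f, 3.f, 4.f});
//   auto half_impl = make_tensor_impl_ptr(
//       {2, 2},
//       std::vector<float>{1.f, 2.f, 3.f, 4.f},
//       {},
//       {},
//       executorch::aten::ScalarType::Half);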

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This template overload is specialized for cases where tensor data is provided
 * as a vector. If the specified `type` differs from the deduced type of the
 * vector's elements, and casting is allowed, the data will be cast to the
 * specified `type`. This allows for flexible creation of tensors with data
 * vectors of one type and a different scalar type.
 *
 * @tparam T The C++ type of the tensor elements, deduced from the vector.
 * @param data A vector containing the tensor's data.
 * @param type The scalar type of the tensor elements. If it differs from the
 * deduced type, the data will be cast to this type if allowed.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr that manages the newly created TensorImpl.
 */
template <
    typename T = float,
    executorch::aten::ScalarType deduced_type =
        runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
    std::vector<T> data,
    executorch::aten::ScalarType type = deduced_type,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
  std::vector<executorch::aten::SizesType> sizes{
      executorch::aten::SizesType(data.size())};
  return make_tensor_impl_ptr(
      std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
}
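
// Example (illustrative sketch): a 1-D tensor whose single dimension is the
// vector's length; the scalar type is deduced from the element type
// (int32_t here), and the data is forwarded with dim order {0} and stride {1}.
//
//   auto impl = make_tensor_impl_ptr(std::vector<int32_t>{1, 2, 3, 4});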

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This template overload is specialized for cases where tensor data is provided
 * as an initializer list. If the specified `type` differs from the deduced type
 * of the initializer list's elements, and casting is allowed, the data will be
 * cast to the specified `type`. This allows for flexible creation of tensors
 * from an initializer list of one type with a different scalar type.
 *
 * @tparam T The C++ type of the tensor elements, deduced from the initializer
 * list.
 * @param sizes A vector specifying the size of each dimension.
 * @param list An initializer list containing the tensor's data.
 * @param dim_order A vector specifying the order of dimensions.
 * @param strides A vector specifying the strides of each dimension.
 * @param type The scalar type of the tensor elements. If it differs from the
 * deduced type, the data will be cast to this type if allowed.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr that manages the newly created TensorImpl.
 */
template <
    typename T = float,
    executorch::aten::ScalarType deduced_type =
        runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    std::initializer_list<T> list,
    std::vector<executorch::aten::DimOrderType> dim_order = {},
    std::vector<executorch::aten::StridesType> strides = {},
    executorch::aten::ScalarType type = deduced_type,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
  return make_tensor_impl_ptr(
      std::move(sizes),
      std::vector<T>(std::move(list)),
      std::move(dim_order),
      std::move(strides),
      type,
      dynamism);
}
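
// Example (illustrative sketch): the initializer list is copied into an owned
// std::vector<float> and forwarded to the vector-based overload above.
//
//   auto impl = make_tensor_impl_ptr({2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});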

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This template overload is specialized for cases where tensor data is provided
 * as an initializer list. If the specified `type` differs from the deduced type
 * of the initializer list's elements, and casting is allowed, the data will be
 * cast to the specified `type`. This allows for flexible creation of tensors
 * from an initializer list of one type with a different scalar type.
 *
 * @tparam T The C++ type of the tensor elements, deduced from the initializer
 * list.
 * @param list An initializer list containing the tensor's data.
 * @param type The scalar type of the tensor elements. If it differs from the
 * deduced type, the data will be cast to this type if allowed.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr that manages the newly created TensorImpl.
 */
template <
    typename T = float,
    executorch::aten::ScalarType deduced_type =
        runtime::CppTypeToScalarType<T>::value>
inline TensorImplPtr make_tensor_impl_ptr(
    std::initializer_list<T> list,
    executorch::aten::ScalarType type = deduced_type,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
  std::vector<executorch::aten::SizesType> sizes{
      executorch::aten::SizesType(list.size())};
  return make_tensor_impl_ptr(
      std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
}
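
// Example (illustrative sketch): a 1-D tensor of three elements; T is deduced
// as double from the list, so the scalar type defaults to Double.
//
//   auto impl = make_tensor_impl_ptr({1.0, 2.0, 3.0});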

/**
 * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
 *
 * @tparam T The C++ type of the scalar value.
 * @param value The scalar value used for the Tensor.
 * @return A TensorImplPtr managing the newly created TensorImpl.
 */
template <typename T>
inline TensorImplPtr make_tensor_impl_ptr(T value) {
  return make_tensor_impl_ptr({}, std::vector<T>{value});
}
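
// Example (illustrative sketch): a zero-dimensional tensor holding a single
// value; the scalar type is deduced from the argument (float here).
//
//   auto impl = make_tensor_impl_ptr(3.14f);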

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
 * and a scalar type to interpret the data. The vector is managed, and its
 * lifetime is tied to the TensorImpl.
 *
 * @param sizes A vector specifying the size of each dimension.
 * @param data A vector containing the raw memory buffer for the tensor's data.
 * @param dim_order A vector specifying the order of dimensions.
 * @param strides A vector specifying the strides of each dimension.
 * @param type The scalar type of the tensor elements.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr managing the newly created TensorImpl.
 */
TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    std::vector<uint8_t> data,
    std::vector<executorch::aten::DimOrderType> dim_order,
    std::vector<executorch::aten::StridesType> strides,
    executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
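
// Example (illustrative sketch): reinterpret a raw byte buffer as a 2x2 float
// tensor; the bytes are moved into and owned by the TensorImplPtr. The buffer
// is assumed to be sized consistently with the sizes, strides, and element
// size (here 4 * sizeof(float)).
//
//   std::vector<uint8_t> bytes(4 * sizeof(float));
//   auto impl = make_tensor_impl_ptr(
//       {2, 2},
//       std::move(bytes),
//       {0, 1},
//       {2, 1},
//       executorch::aten::ScalarType::Float);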

/**
 * Creates a TensorImplPtr that manages a newly created TensorImpl with the
 * specified properties.
 *
 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
 * and a scalar type to interpret the data. The vector is managed, and the
 * memory's lifetime is tied to the TensorImpl.
 *
 * @param sizes A vector specifying the size of each dimension.
 * @param data A vector containing the raw memory for the tensor's data.
 * @param type The scalar type of the tensor elements.
 * @param dynamism Specifies the mutability of the tensor's shape.
 * @return A TensorImplPtr managing the newly created TensorImpl.
 */
inline TensorImplPtr make_tensor_impl_ptr(
    std::vector<executorch::aten::SizesType> sizes,
    std::vector<uint8_t> data,
    executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
    executorch::aten::TensorShapeDynamism dynamism =
        executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
  return make_tensor_impl_ptr(
      std::move(sizes), std::move(data), {}, {}, type, dynamism);
}
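
// Example (illustrative sketch): same idea as above, with dim order and
// strides left to their defaults.
//
//   std::vector<uint8_t> bytes(6 * sizeof(float));
//   auto impl = make_tensor_impl_ptr(
//       {2, 3}, std::move(bytes), executorch::aten::ScalarType::Float);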

} // namespace extension
} // namespace executorch