1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <algorithm>
12 #include <functional>
13 #include <memory>
14 #include <vector>
15
16 #include <executorch/runtime/core/error.h>
17 #include <executorch/runtime/core/exec_aten/exec_aten.h>
18 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
19
20 namespace executorch {
21 namespace extension {
22
23 /**
24 * A smart pointer type for managing the lifecycle of a Tensor.
25 */
26 using TensorPtr = std::shared_ptr<executorch::aten::Tensor>;
27
28 /**
29 * Creates a TensorPtr that manages a Tensor with the specified properties.
30 *
31 * @param sizes A vector specifying the size of each dimension.
32 * @param data A pointer to the data buffer.
33 * @param dim_order A vector specifying the order of dimensions.
34 * @param strides A vector specifying the strides of the tensor.
35 * @param type The scalar type of the tensor elements.
36 * @param dynamism Specifies the mutability of the tensor's shape.
37 * @param deleter A custom deleter function for managing the lifetime of the
38 * data buffer. If provided, this deleter will be called when the managed Tensor
39 * object is destroyed.
40 * @return A TensorPtr that manages the newly created Tensor.
41 */
42 TensorPtr make_tensor_ptr(
43 std::vector<executorch::aten::SizesType> sizes,
44 void* data,
45 std::vector<executorch::aten::DimOrderType> dim_order,
46 std::vector<executorch::aten::StridesType> strides,
47 const executorch::aten::ScalarType type =
48 executorch::aten::ScalarType::Float,
49 const executorch::aten::TensorShapeDynamism dynamism =
50 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
51 std::function<void(void*)> deleter = nullptr);
52
53 /**
54 * Creates a TensorPtr that manages a Tensor with the specified properties.
55 *
56 * @param sizes A vector specifying the size of each dimension.
57 * @param data A pointer to the data buffer.
58 * @param type The scalar type of the tensor elements.
59 * @param dynamism Specifies the mutability of the tensor's shape.
60 * @param deleter A custom deleter function for managing the lifetime of the
61 * data buffer. If provided, this deleter will be called when the managed Tensor
62 * object is destroyed.
63 * @return A TensorPtr that manages the newly created Tensor.
64 */
65 inline TensorPtr make_tensor_ptr(
66 std::vector<executorch::aten::SizesType> sizes,
67 void* data,
68 const executorch::aten::ScalarType type =
69 executorch::aten::ScalarType::Float,
70 const executorch::aten::TensorShapeDynamism dynamism =
71 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
72 std::function<void(void*)> deleter = nullptr) {
73 return make_tensor_ptr(
74 std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
75 }
76
77 /**
78 * Creates a TensorPtr that manages a Tensor with the specified properties.
79 *
80 * This template overload is specialized for cases where the tensor data is
81 * provided as a vector. The scalar type is automatically deduced from the
82 * vector's data type. If the specified `type` differs from the deduced type of
83 * the vector's elements, and casting is allowed, the data will be cast to the
84 * specified `type`. This allows for flexible creation of tensors with data
85 * vectors of one type and a different scalar type.
86 *
87 * @tparam T The C++ type of the tensor elements, deduced from the vector.
88 * @param sizes A vector specifying the size of each dimension.
89 * @param data A vector containing the tensor's data.
90 * @param dim_order A vector specifying the order of dimensions.
91 * @param strides A vector specifying the strides of each dimension.
92 * @param type The scalar type of the tensor elements. If it differs from the
93 * deduced type, the data will be cast to this type if allowed.
94 * @param dynamism Specifies the mutability of the tensor's shape.
95 * @return A TensorPtr that manages the newly created TensorImpl.
96 */
97 template <
98 typename T = float,
99 executorch::aten::ScalarType deduced_type =
100 runtime::CppTypeToScalarType<T>::value>
101 inline TensorPtr make_tensor_ptr(
102 std::vector<executorch::aten::SizesType> sizes,
103 std::vector<T> data,
104 std::vector<executorch::aten::DimOrderType> dim_order = {},
105 std::vector<executorch::aten::StridesType> strides = {},
106 executorch::aten::ScalarType type = deduced_type,
107 executorch::aten::TensorShapeDynamism dynamism =
108 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
109 if (type != deduced_type) {
110 ET_CHECK_MSG(
111 runtime::canCast(deduced_type, type),
112 "Cannot cast deduced type to specified type.");
113 std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
114 ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "make_tensor_ptr", CTYPE, [&] {
115 std::transform(
116 data.begin(),
117 data.end(),
118 reinterpret_cast<CTYPE*>(casted_data.data()),
119 [](const T& val) { return static_cast<CTYPE>(val); });
120 });
121 const auto raw_data_ptr = casted_data.data();
122 auto data_ptr =
123 std::make_shared<std::vector<uint8_t>>(std::move(casted_data));
124 return make_tensor_ptr(
125 std::move(sizes),
126 raw_data_ptr,
127 std::move(dim_order),
128 std::move(strides),
129 type,
130 dynamism,
131 [data_ptr = std::move(data_ptr)](void*) {});
132 }
133 const auto raw_data_ptr = data.data();
134 auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
135 return make_tensor_ptr(
136 std::move(sizes),
137 raw_data_ptr,
138 std::move(dim_order),
139 std::move(strides),
140 type,
141 dynamism,
142 [data_ptr = std::move(data_ptr)](void*) {});
143 }
144
145 /**
146 * Creates a TensorPtr that manages a Tensor with the specified properties.
147 *
148 * This template overload is specialized for cases where the tensor data is
149 * provided as a vector. The scalar type is automatically deduced from the
150 * vector's data type. If the specified `type` differs from the deduced type of
151 * the vector's elements, and casting is allowed, the data will be cast to the
152 * specified `type`. This allows for flexible creation of tensors with data
153 * vectors of one type and a different scalar type.
154 *
155 * @tparam T The C++ type of the tensor elements, deduced from the vector.
156 * @param data A vector containing the tensor's data.
157 * @param type The scalar type of the tensor elements. If it differs from the
158 * deduced type, the data will be cast to this type if allowed.
159 * @param dynamism Specifies the mutability of the tensor's shape.
160 * @return A TensorPtr that manages the newly created TensorImpl.
161 */
162 template <
163 typename T = float,
164 executorch::aten::ScalarType deduced_type =
165 runtime::CppTypeToScalarType<T>::value>
166 inline TensorPtr make_tensor_ptr(
167 std::vector<T> data,
168 executorch::aten::ScalarType type = deduced_type,
169 executorch::aten::TensorShapeDynamism dynamism =
170 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
171 std::vector<executorch::aten::SizesType> sizes{
172 executorch::aten::SizesType(data.size())};
173 return make_tensor_ptr(
174 std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
175 }
176
177 /**
178 * Creates a TensorPtr that manages a Tensor with the specified properties.
179 *
180 * This template overload is specialized for cases where the tensor data is
181 * provided as an initializer list. The scalar type is automatically deduced
182 * from the initializer list's data type. If the specified `type` differs from
183 * the deduced type of the initializer list's elements, and casting is allowed,
184 * the data will be cast to the specified `type`. This allows for flexible
185 * creation of tensors with data vectors of one type and a different scalar
186 * type.
187 *
188 * @tparam T The C++ type of the tensor elements, deduced from the initializer
189 * list.
190 * @param sizes A vector specifying the size of each dimension.
191 * @param list An initializer list containing the tensor's data.
192 * @param dim_order A vector specifying the order of dimensions.
193 * @param strides A vector specifying the strides of each dimension.
194 * @param type The scalar type of the tensor elements. If it differs from the
195 * deduced type, the data will be cast to this type if allowed.
196 * @param dynamism Specifies the mutability of the tensor's shape.
197 * @return A TensorPtr that manages the newly created TensorImpl.
198 */
199 template <
200 typename T = float,
201 executorch::aten::ScalarType deduced_type =
202 runtime::CppTypeToScalarType<T>::value>
203 inline TensorPtr make_tensor_ptr(
204 std::vector<executorch::aten::SizesType> sizes,
205 std::initializer_list<T> list,
206 std::vector<executorch::aten::DimOrderType> dim_order = {},
207 std::vector<executorch::aten::StridesType> strides = {},
208 executorch::aten::ScalarType type = deduced_type,
209 executorch::aten::TensorShapeDynamism dynamism =
210 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
211 return make_tensor_ptr(
212 std::move(sizes),
213 std::vector<T>(std::move(list)),
214 std::move(dim_order),
215 std::move(strides),
216 type,
217 dynamism);
218 }
219
220 /**
221 * Creates a TensorPtr that manages a Tensor with the specified properties.
222 *
223 * This template overload allows creating a Tensor from an initializer list
224 * of data. The scalar type is automatically deduced from the type of the
225 * initializer list's elements. If the specified `type` differs from
226 * the deduced type of the initializer list's elements, and casting is allowed,
227 * the data will be cast to the specified `type`. This allows for flexible
228 * creation of tensors with data vectors of one type and a different scalar
229 * type.
230 *
231 * @tparam T The C++ type of the tensor elements, deduced from the initializer
232 * list.
233 * @param list An initializer list containing the tensor's data.
234 * @param type The scalar type of the tensor elements. If it differs from the
235 * deduced type, the data will be cast to this type if allowed.
236 * @param dynamism Specifies the mutability of the tensor's shape.
237 * @return A TensorPtr that manages the newly created TensorImpl.
238 */
239 template <
240 typename T = float,
241 executorch::aten::ScalarType deduced_type =
242 runtime::CppTypeToScalarType<T>::value>
243 inline TensorPtr make_tensor_ptr(
244 std::initializer_list<T> list,
245 executorch::aten::ScalarType type = deduced_type,
246 executorch::aten::TensorShapeDynamism dynamism =
247 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
248 std::vector<executorch::aten::SizesType> sizes{
249 executorch::aten::SizesType(list.size())};
250 return make_tensor_ptr(
251 std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
252 }
253
254 /**
255 * Creates a TensorPtr that manages a Tensor with a single scalar value.
256 *
257 * @tparam T The C++ type of the scalar value.
258 * @param value The scalar value to be used for the Tensor.
259 * @return A TensorPtr that manages the newly created TensorImpl.
260 */
261 template <typename T>
make_tensor_ptr(T value)262 inline TensorPtr make_tensor_ptr(T value) {
263 return make_tensor_ptr({}, std::vector<T>{value});
264 }
265
266 /**
267 * Creates a TensorPtr that manages a Tensor with the specified properties.
268 *
269 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
270 * and a scalar type to interpret the data. The vector is managed, and the
271 * memory's lifetime is tied to the TensorImpl.
272 *
273 * @param sizes A vector specifying the size of each dimension.
274 * @param data A vector containing the raw memory for the tensor's data.
275 * @param dim_order A vector specifying the order of dimensions.
276 * @param strides A vector specifying the strides of each dimension.
277 * @param type The scalar type of the tensor elements.
278 * @param dynamism Specifies the mutability of the tensor's shape.
279 * @return A TensorPtr managing the newly created Tensor.
280 */
281 TensorPtr make_tensor_ptr(
282 std::vector<executorch::aten::SizesType> sizes,
283 std::vector<uint8_t> data,
284 std::vector<executorch::aten::DimOrderType> dim_order,
285 std::vector<executorch::aten::StridesType> strides,
286 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
287 executorch::aten::TensorShapeDynamism dynamism =
288 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
289
290 /**
291 * Creates a TensorPtr that manages a Tensor with the specified properties.
292 *
293 * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
294 * and a scalar type to interpret the data. The vector is managed, and the
295 * memory's lifetime is tied to the TensorImpl.
296 *
297 * @param sizes A vector specifying the size of each dimension.
298 * @param data A vector containing the raw memory for the tensor's data.
299 * @param type The scalar type of the tensor elements.
300 * @param dynamism Specifies the mutability of the tensor's shape.
301 * @return A TensorPtr managing the newly created Tensor.
302 */
303 inline TensorPtr make_tensor_ptr(
304 std::vector<executorch::aten::SizesType> sizes,
305 std::vector<uint8_t> data,
306 executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
307 executorch::aten::TensorShapeDynamism dynamism =
308 executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
309 return make_tensor_ptr(
310 std::move(sizes), std::move(data), {}, {}, type, dynamism);
311 }
312
313 /**
314 * Creates a TensorPtr to manage a new Tensor with the same properties
315 * as the given Tensor, sharing the same data without owning it.
316 *
317 * @param tensor The Tensor whose properties are used to create a new TensorPtr.
318 * @return A new TensorPtr managing a Tensor with the same properties as the
319 * original.
320 */
make_tensor_ptr(const executorch::aten::Tensor & tensor)321 inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
322 return make_tensor_ptr(
323 std::vector<executorch::aten::SizesType>(
324 tensor.sizes().begin(), tensor.sizes().end()),
325 tensor.mutable_data_ptr(),
326 #ifndef USE_ATEN_LIB
327 std::vector<executorch::aten::DimOrderType>(
328 tensor.dim_order().begin(), tensor.dim_order().end()),
329 std::vector<executorch::aten::StridesType>(
330 tensor.strides().begin(), tensor.strides().end()),
331 tensor.scalar_type(),
332 tensor.shape_dynamism()
333 #else // USE_ATEN_LIB
334 {},
335 std::vector<executorch::aten::StridesType>(
336 tensor.strides().begin(), tensor.strides().end()),
337 tensor.scalar_type()
338 #endif // USE_ATEN_LIB
339 );
340 }
341
342 /**
343 * Creates a TensorPtr that manages a new Tensor with the same properties
344 * as the given Tensor, but with a copy of the data owned by the returned
345 * TensorPtr, or nullptr if the original data is null.
346 *
347 * @param tensor The Tensor to clone.
348 * @return A new TensorPtr that manages a Tensor with the same properties as the
349 * original but with copied data.
350 */
351 TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor);
352
353 /**
354 * Creates a new TensorPtr by cloning the given TensorPtr, copying the
355 * underlying data.
356 *
357 * @param tensor The TensorPtr to clone.
358 * @return A new TensorPtr that manages a Tensor with the same properties as the
359 * original but with copied data.
360 */
clone_tensor_ptr(const TensorPtr & tensor)361 inline TensorPtr clone_tensor_ptr(const TensorPtr& tensor) {
362 return clone_tensor_ptr(*tensor);
363 }
364
365 /**
366 * Resizes the Tensor managed by the provided TensorPtr to the new sizes.
367 *
368 * @param tensor A TensorPtr managing the Tensor to resize.
369 * @param sizes A vector representing the new sizes for each dimension.
370 * @return Error::Ok on success, or an appropriate error code on failure.
371 */
372 ET_NODISCARD
373 runtime::Error resize_tensor_ptr(
374 TensorPtr& tensor,
375 const std::vector<executorch::aten::SizesType>& sizes);
376
377 } // namespace extension
378 } // namespace executorch