• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/executor/tensor_parser.h>
10 
11 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
12 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13 #include <executorch/runtime/executor/memory_manager.h>
14 #include <executorch/runtime/executor/program.h>
15 #include <executorch/runtime/platform/profiler.h>
16 #include <executorch/schema/program_generated.h>
17 
18 #include <ATen/ATen.h> // @donotremove @manual=//caffe2/aten:ATen-core
19 
20 namespace executorch {
21 namespace runtime {
22 namespace deserialization {
23 
namespace {

/**
 * Deleter handed to at::from_blob() when a tensor is first created.
 *
 * It deliberately does nothing: parseTensor() calls from_blob() with a
 * nullptr blob, so there is no buffer for the tensor to release through this
 * deleter.
 */
void deleteNothing(void* /*unused*/) {
  // Intentionally empty.
}

} // namespace
30 
parseTensor(const Program * program,MemoryManager * memory_manager,const executorch_flatbuffer::Tensor * s_tensor)31 Result<at::Tensor> parseTensor(
32     const Program* program,
33     MemoryManager* memory_manager,
34     const executorch_flatbuffer::Tensor* s_tensor) {
35   EXECUTORCH_SCOPE_PROF("TensorParser::parseTensor");
36 
37   ET_CHECK_OR_RETURN_ERROR(
38       s_tensor->storage_offset() == 0,
39       NotSupported,
40       "Non-zero storage offset %" PRId32 " not supported",
41       s_tensor->storage_offset());
42 
43   // get metadata
44   at::ScalarType type = static_cast<at::ScalarType>(s_tensor->scalar_type());
45   ET_CHECK_OR_RETURN_ERROR(
46       isValid(type),
47       InvalidProgram,
48       "Invalid ScalarType %" PRId8,
49       static_cast<int8_t>(type));
50   auto options = at::CPU(type).options();
51 
52   ET_CHECK_OR_RETURN_ERROR(
53       s_tensor->sizes() != nullptr, InvalidProgram, "Missing sizes field");
54   size_t ndim = s_tensor->sizes()->size();
55 
56   ET_CHECK_OR_RETURN_ERROR(
57       s_tensor->dim_order() != nullptr,
58       InvalidProgram,
59       "Missing dim_order field");
60   ET_CHECK_OR_RETURN_ERROR(
61       s_tensor->dim_order()->size() == ndim,
62       InvalidProgram,
63       "dim_order size %" PRIu32 " != ndim %zu",
64       s_tensor->dim_order()->size(),
65       ndim);
66 
67   // convert int32 in serialization to int64 for aten
68   std::vector<int64_t> sizes(
69       s_tensor->sizes()->begin(), s_tensor->sizes()->end());
70   std::vector<int64_t> strides(ndim);
71   auto status = dim_order_to_stride(
72       s_tensor->sizes()->data(),
73       s_tensor->dim_order()->data(),
74       ndim,
75       strides.data());
76   ET_CHECK_OR_RETURN_ERROR(
77       status == Error::Ok,
78       Internal,
79       "dim_order_to_stride returned invalid status");
80 
81   // Create a tensor without data first so we can find its expected size before
82   // getting its memory.
83   at::Tensor tensor = at::from_blob(
84       /*data=*/nullptr,
85       sizes,
86       strides,
87       /*storage_offset=*/0,
88       deleteNothing,
89       options);
90 
91   if (s_tensor->shape_dynamism() ==
92       executorch_flatbuffer::TensorShapeDynamism::DYNAMIC_UNBOUND) {
93     // Provide fully dynamic tensors with an allocator so they can be resized
94     // within aten kernels.
95     auto impl = tensor.unsafeGetTensorImpl();
96     at::StorageImpl* storage = impl->unsafe_storage().unsafeGetStorageImpl();
97     storage->set_allocator(at::getCPUAllocator());
98     storage->set_resizable(true);
99     storage->set_nbytes(0);
100     impl->set_sizes_contiguous(0);
101     // Leave the data as nullptr since it will be reallocated.
102   } else {
103     // Now that we know how big the tensor is, find and assign its memory.
104     Result<void*> data_ptr = getTensorDataPtr(
105         s_tensor, program, tensor.nbytes(), memory_manager->planned_memory());
106     if (!data_ptr.ok()) {
107       ET_LOG(
108           Error,
109           "getTensorDataPtr() failed: 0x%" PRIx32,
110           static_cast<uint32_t>(data_ptr.error()));
111       return data_ptr.error();
112     }
113     tensor.unsafeGetTensorImpl()->unsafe_storage().set_data_ptr(
114         at::DataPtr(data_ptr.get(), c10::DeviceType::CPU));
115   }
116 
117   return tensor;
118 }
119 
120 } // namespace deserialization
121 } // namespace runtime
122 } // namespace executorch
123