1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_ 16 #define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_ 17 18 #include <cstddef> 19 #include <cstdint> 20 21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers 22 #include "tensorflow/lite/c/common.h" 23 #include "tensorflow/lite/core/api/error_reporter.h" 24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 25 #include "tensorflow/lite/micro/micro_allocator.h" 26 #include "tensorflow/lite/micro/micro_op_resolver.h" 27 #include "tensorflow/lite/micro/micro_profiler.h" 28 #include "tensorflow/lite/portable_type_to_tflitetype.h" 29 #include "tensorflow/lite/schema/schema_generated.h" 30 31 // Copied from tensorflow/lite/version.h to avoid a dependency chain into 32 // tensorflow/core. 33 #define TFLITE_SCHEMA_VERSION (3) 34 35 namespace tflite { 36 37 namespace internal { 38 39 // A helper class to encapsulate the implementation of APIs in Context. 40 // context->impl_ points to an instance of this class. 41 // Check tensorflow/lite/c/common.h for detailed descriptions. 42 // TODO(b/16157777): Consider rolling this class into MicroInterpreter. 43 class ContextHelper { 44 public: 45 explicit ContextHelper(ErrorReporter* error_reporter, 46 MicroAllocator* allocator, const Model* model); 47 48 // Functions that will be assigned to function pointers on TfLiteContext: 49 static void* AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes); 50 static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, 51 size_t bytes, 52 int* buffer_idx); 53 static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); 54 static void ReportOpError(struct TfLiteContext* context, const char* format, 55 ...); 56 static TfLiteTensor* GetTensor(const struct TfLiteContext* context, 57 int tensor_idx); 58 static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context, 59 int tensor_idx); 60 61 // Sets the pointer to a list of TfLiteEvalTensor instances. 62 void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors); 63 64 // Sets the pointer to a list of ScratchBufferHandle instances. 65 void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles); 66 67 private: 68 MicroAllocator* allocator_ = nullptr; 69 ErrorReporter* error_reporter_ = nullptr; 70 const Model* model_ = nullptr; 71 TfLiteEvalTensor* eval_tensors_ = nullptr; 72 ScratchBufferHandle* scratch_buffer_handles_ = nullptr; 73 }; 74 75 } // namespace internal 76 77 class MicroInterpreter { 78 public: 79 // The lifetime of the model, op resolver, tensor arena, error reporter and 80 // profiler must be at least as long as that of the interpreter object, since 81 // the interpreter may need to access them at any time. This means that you 82 // should usually create them with the same scope as each other, for example 83 // having them all allocated on the stack as local variables through a 84 // top-level function. The interpreter doesn't do any deallocation of any of 85 // the pointed-to objects, ownership remains with the caller. 86 MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver, 87 uint8_t* tensor_arena, size_t tensor_arena_size, 88 ErrorReporter* error_reporter, 89 MicroProfiler* profiler = nullptr); 90 91 // Create an interpreter instance using an existing MicroAllocator instance. 92 // This constructor should be used when creating an allocator that needs to 93 // have allocation handled in more than one interpreter or for recording 94 // allocations inside the interpreter. The lifetime of the allocator must be 95 // as long as that of the interpreter object. 96 MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver, 97 MicroAllocator* allocator, ErrorReporter* error_reporter, 98 MicroProfiler* profiler = nullptr); 99 100 ~MicroInterpreter(); 101 102 // Runs through the model and allocates all necessary input, output and 103 // intermediate tensors. 104 TfLiteStatus AllocateTensors(); 105 106 // In order to support partial graph runs for strided models, this can return 107 // values other than kTfLiteOk and kTfLiteError. 108 // TODO(b/149795762): Add this to the TfLiteStatus enum. 109 TfLiteStatus Invoke(); 110 tensors_size()111 size_t tensors_size() const { return context_.tensors_size; } 112 TfLiteTensor* tensor(size_t tensor_index); 113 template <class T> typed_tensor(int tensor_index)114 T* typed_tensor(int tensor_index) { 115 if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { 116 if (tensor_ptr->type == typeToTfLiteType<T>()) { 117 return GetTensorData<T>(tensor_ptr); 118 } 119 } 120 return nullptr; 121 } 122 123 TfLiteTensor* input(size_t index); inputs_size()124 size_t inputs_size() const { return subgraph_->inputs()->Length(); } inputs()125 const flatbuffers::Vector<int32_t>& inputs() const { 126 return *subgraph_->inputs(); 127 } input_tensor(size_t index)128 TfLiteTensor* input_tensor(size_t index) { return input(index); } 129 template <class T> typed_input_tensor(int tensor_index)130 T* typed_input_tensor(int tensor_index) { 131 if (TfLiteTensor* tensor_ptr = input_tensor(tensor_index)) { 132 if (tensor_ptr->type == typeToTfLiteType<T>()) { 133 return GetTensorData<T>(tensor_ptr); 134 } 135 } 136 return nullptr; 137 } 138 139 TfLiteTensor* output(size_t index); outputs_size()140 size_t outputs_size() const { return subgraph_->outputs()->Length(); } outputs()141 const flatbuffers::Vector<int32_t>& outputs() const { 142 return *subgraph_->outputs(); 143 } output_tensor(size_t index)144 TfLiteTensor* output_tensor(size_t index) { return output(index); } 145 template <class T> typed_output_tensor(int tensor_index)146 T* typed_output_tensor(int tensor_index) { 147 if (TfLiteTensor* tensor_ptr = output_tensor(tensor_index)) { 148 if (tensor_ptr->type == typeToTfLiteType<T>()) { 149 return GetTensorData<T>(tensor_ptr); 150 } 151 } 152 return nullptr; 153 } 154 155 // Reset all variable tensors to the default value. 156 TfLiteStatus ResetVariableTensors(); 157 initialization_status()158 TfLiteStatus initialization_status() const { return initialization_status_; } 159 operators_size()160 size_t operators_size() const { return subgraph_->operators()->size(); } 161 162 // For debugging only. node_and_registration(int node_index)163 const NodeAndRegistration node_and_registration(int node_index) const { 164 return node_and_registrations_[node_index]; 165 } 166 167 // For debugging only. 168 // Returns the actual used arena in bytes. This method gives the optimal arena 169 // size. It's only available after `AllocateTensors` has been called. 170 // Note that normally `tensor_arena` requires 16 bytes alignment to fully 171 // utilize the space. If it's not the case, the optimial arena size would be 172 // arena_used_bytes() + 16. arena_used_bytes()173 size_t arena_used_bytes() const { return allocator_.used_bytes(); } 174 175 protected: allocator()176 const MicroAllocator& allocator() const { return allocator_; } context()177 const TfLiteContext& context() const { return context_; } 178 179 private: 180 // TODO(b/158263161): Consider switching to Create() function to enable better 181 // error reporting during initialization. 182 void Init(MicroProfiler* profiler); 183 184 NodeAndRegistration* node_and_registrations_ = nullptr; 185 186 const Model* model_; 187 const MicroOpResolver& op_resolver_; 188 ErrorReporter* error_reporter_; 189 TfLiteContext context_ = {}; 190 MicroAllocator& allocator_; 191 bool tensors_allocated_; 192 193 TfLiteStatus initialization_status_; 194 195 const SubGraph* subgraph_ = nullptr; 196 TfLiteEvalTensor* eval_tensors_ = nullptr; 197 ScratchBufferHandle* scratch_buffer_handles_ = nullptr; 198 199 // TODO(b/16157777): Drop this reference: 200 internal::ContextHelper context_helper_; 201 202 // TODO(b/162311891): Clean these pointers up when this class supports buffers 203 // from TfLiteEvalTensor. 204 TfLiteTensor** input_tensors_; 205 TfLiteTensor** output_tensors_; 206 }; 207 208 } // namespace tflite 209 210 #endif // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_ 211