/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_

#include <cstddef>
#include <cstdint>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Copied from tensorflow/lite/version.h to avoid a dependency chain into
// tensorflow/core.
#define TFLITE_SCHEMA_VERSION (3)

namespace tflite {

namespace internal {
// A helper class to encapsulate the implementation of APIs in Context.
// context->impl_ points to an instance of this class.
// See tensorflow/lite/c/common.h for detailed descriptions.
// TODO(b/16157777): Consider rolling this class into MicroInterpreter.
class ContextHelper {
 public:
  explicit ContextHelper(ErrorReporter* error_reporter,
                         MicroAllocator* allocator, const Model* model);

  // Functions that will be assigned to function pointers on TfLiteContext:
  static void* AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes);
  static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
                                                  size_t bytes,
                                                  int* buffer_idx);
  static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
  static void ReportOpError(struct TfLiteContext* context, const char* format,
                            ...);
  static TfLiteTensor* GetTensor(const struct TfLiteContext* context,
                                 int tensor_idx);
  static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
                                         int tensor_idx);
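
  // The interpreter is expected to wire these functions into its
  // TfLiteContext, roughly as follows (a sketch of the mechanism, not the
  // exact implementation):
  //
  //   context_.impl_ = static_cast<void*>(&context_helper_);
  //   context_.ReportError = internal::ContextHelper::ReportOpError;
  //   context_.GetTensor = internal::ContextHelper::GetTensor;
  //   context_.GetEvalTensor = internal::ContextHelper::GetEvalTensor;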

  // Sets the pointer to a list of TfLiteEvalTensor instances.
  void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors);

  // Sets the pointer to a list of ScratchBufferHandle instances.
  void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles);

 private:
  MicroAllocator* allocator_ = nullptr;
  ErrorReporter* error_reporter_ = nullptr;
  const Model* model_ = nullptr;
  TfLiteEvalTensor* eval_tensors_ = nullptr;
  ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
};

}  // namespace internal

class MicroInterpreter {
 public:
  // The lifetime of the model, op resolver, tensor arena, error reporter, and
  // profiler must be at least as long as that of the interpreter object, since
  // the interpreter may need to access them at any time. This means that you
  // should usually create them with the same scope as each other, for example
  // by allocating them all on the stack as local variables in one top-level
  // function. The interpreter does not deallocate any of the pointed-to
  // objects; ownership remains with the caller.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   uint8_t* tensor_arena, size_t tensor_arena_size,
                   ErrorReporter* error_reporter,
                   MicroProfiler* profiler = nullptr);
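
  // Example usage (a sketch; `g_model_data`, `resolver`, `kArenaSize`, and
  // `error_reporter` are placeholder names, not part of this API):
  //
  //   const Model* model = ::tflite::GetModel(g_model_data);
  //   static uint8_t tensor_arena[kArenaSize];
  //   MicroInterpreter interpreter(model, resolver, tensor_arena, kArenaSize,
  //                                &error_reporter);
  //   if (interpreter.AllocateTensors() != kTfLiteOk) return;
  //   // ... fill interpreter.input(0) with input data ...
  //   if (interpreter.Invoke() != kTfLiteOk) return;
  //   TfLiteTensor* output = interpreter.output(0);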

  // Creates an interpreter instance using an existing MicroAllocator.
  // Use this constructor when a single allocator needs to handle allocations
  // for more than one interpreter, or when recording allocations made inside
  // the interpreter. The lifetime of the allocator must be at least as long
  // as that of the interpreter object.
  MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
                   MicroAllocator* allocator, ErrorReporter* error_reporter,
                   MicroProfiler* profiler = nullptr);
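
  // For example, two interpreters can share one allocator (a sketch; the
  // MicroAllocator::Create() factory is assumed here, see micro_allocator.h
  // for its exact signature):
  //
  //   MicroAllocator* allocator =
  //       MicroAllocator::Create(tensor_arena, kArenaSize, &error_reporter);
  //   MicroInterpreter interpreter1(model1, resolver, allocator,
  //                                 &error_reporter);
  //   MicroInterpreter interpreter2(model2, resolver, allocator,
  //                                 &error_reporter);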

  ~MicroInterpreter();

  // Runs through the model and allocates all necessary input, output and
  // intermediate tensors.
  TfLiteStatus AllocateTensors();

  // To support partial graph runs for strided models, this can return values
  // other than kTfLiteOk and kTfLiteError.
  // TODO(b/149795762): Add this to the TfLiteStatus enum.
  TfLiteStatus Invoke();

  size_t tensors_size() const { return context_.tensors_size; }
  TfLiteTensor* tensor(size_t tensor_index);
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }

  TfLiteTensor* input(size_t index);
  size_t inputs_size() const { return subgraph_->inputs()->Length(); }
  const flatbuffers::Vector<int32_t>& inputs() const {
    return *subgraph_->inputs();
  }
  TfLiteTensor* input_tensor(size_t index) { return input(index); }
  template <class T>
  T* typed_input_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = input_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }

  TfLiteTensor* output(size_t index);
  size_t outputs_size() const { return subgraph_->outputs()->Length(); }
  const flatbuffers::Vector<int32_t>& outputs() const {
    return *subgraph_->outputs();
  }
  TfLiteTensor* output_tensor(size_t index) { return output(index); }
  template <class T>
  T* typed_output_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = output_tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return GetTensorData<T>(tensor_ptr);
      }
    }
    return nullptr;
  }
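
  // The typed accessors above return a pointer to the tensor's data only when
  // the requested type matches the tensor's type, and nullptr otherwise. For
  // example (a sketch; assumes the model's first input and output are float):
  //
  //   float* in = interpreter.typed_input_tensor<float>(0);
  //   float* out = interpreter.typed_output_tensor<float>(0);
  //   if (in != nullptr && out != nullptr) {
  //     in[0] = 1.0f;
  //     interpreter.Invoke();
  //     float result = out[0];
  //   }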

  // Resets all variable tensors to their default values.
  TfLiteStatus ResetVariableTensors();

  TfLiteStatus initialization_status() const { return initialization_status_; }

  size_t operators_size() const { return subgraph_->operators()->size(); }

  // For debugging only.
  const NodeAndRegistration node_and_registration(int node_index) const {
    return node_and_registrations_[node_index];
  }
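
  // For example, a debug build might walk the graph like this (a sketch):
  //
  //   const int num_ops = static_cast<int>(interpreter.operators_size());
  //   for (int i = 0; i < num_ops; ++i) {
  //     NodeAndRegistration nr = interpreter.node_and_registration(i);
  //     // nr.registration->builtin_code identifies the operator.
  //   }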

  // For debugging only.
  // Returns the number of arena bytes actually used, which is the optimal
  // arena size. Only valid after `AllocateTensors` has been called.
  // Note that `tensor_arena` normally requires 16-byte alignment to fully
  // utilize the space. If that is not the case, the optimal arena size is
  // arena_used_bytes() + 16.
  size_t arena_used_bytes() const { return allocator_.used_bytes(); }
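
  // This can be used to right-size the arena: run once with a generous arena,
  // then shrink it (a sketch; the extra 16 bytes cover worst-case alignment):
  //
  //   interpreter.AllocateTensors();
  //   size_t optimal_size = interpreter.arena_used_bytes() + 16;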

 protected:
  const MicroAllocator& allocator() const { return allocator_; }
  const TfLiteContext& context() const { return context_; }

 private:
  // TODO(b/158263161): Consider switching to a Create() function to enable
  // better error reporting during initialization.
  void Init(MicroProfiler* profiler);

  NodeAndRegistration* node_and_registrations_ = nullptr;

  const Model* model_;
  const MicroOpResolver& op_resolver_;
  ErrorReporter* error_reporter_;
  TfLiteContext context_ = {};
  MicroAllocator& allocator_;
  bool tensors_allocated_;

  TfLiteStatus initialization_status_;

  const SubGraph* subgraph_ = nullptr;
  TfLiteEvalTensor* eval_tensors_ = nullptr;
  ScratchBufferHandle* scratch_buffer_handles_ = nullptr;

  // TODO(b/16157777): Drop this reference:
  internal::ContextHelper context_helper_;

  // TODO(b/162311891): Clean these pointers up when this class supports
  // buffers from TfLiteEvalTensor.
  TfLiteTensor** input_tensors_;
  TfLiteTensor** output_tensors_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_