• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Main abstraction controlling the tflite interpreter.
16 // See context.h for the API for defining operations (TfLiteRegistration).
17 #ifndef TENSORFLOW_LITE_INTERPRETER_H_
18 #define TENSORFLOW_LITE_INTERPRETER_H_
19 
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/stderr_reporter.h"
32 
33 namespace tflite {
34 
35 // Map statically from a c++ type to a TfLiteType (used below for safe casts).
36 template <class T>
typeToTfLiteType()37 constexpr TfLiteType typeToTfLiteType() {
38   return kTfLiteNoType;
39 }
40 template <>
41 constexpr TfLiteType typeToTfLiteType<int>() {
42   return kTfLiteInt32;
43 }
44 template <>
45 constexpr TfLiteType typeToTfLiteType<int16_t>() {
46   return kTfLiteInt16;
47 }
48 template <>
49 constexpr TfLiteType typeToTfLiteType<int64_t>() {
50   return kTfLiteInt64;
51 }
52 template <>
53 constexpr TfLiteType typeToTfLiteType<float>() {
54   return kTfLiteFloat32;
55 }
56 template <>
57 constexpr TfLiteType typeToTfLiteType<unsigned char>() {
58   return kTfLiteUInt8;
59 }
60 template <>
61 constexpr TfLiteType typeToTfLiteType<int8_t>() {
62   return kTfLiteInt8;
63 }
64 template <>
65 constexpr TfLiteType typeToTfLiteType<bool>() {
66   return kTfLiteBool;
67 }
68 template <>
69 constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
70   return kTfLiteComplex64;
71 }
72 template <>
73 constexpr TfLiteType typeToTfLiteType<string>() {
74   return kTfLiteString;
75 }
76 
77 // An interpreter for a graph of nodes that input and output from tensors.
78 // Each node of the graph processes a set of input tensors and produces a
79 // set of output Tensors. All inputs/output tensors are referenced by index.
80 //
81 // Usage:
82 //
83 // -- Create basic model
84 // Interpreter foo(2, 1);
85 // foo.SetTensorParametersReadWrite(0, ...);
86 // foo.SetTensorParametersReadOnly(1, ...);
87 // foo.SetNodeParameters(0, ...)
88 //
89 // -- Resize input array to 1 length.
90 // foo.ResizeInputTensor(0, 1);
91 // foo.AllocateTensors();
92 // -- Install array data
93 // foo.typed_tensor<float>(0)[0] = 3;
94 // foo.Invoke();
95 // foo.typed_tensor<float>(0)[0] = 4;
96 // foo.Invoke();
97 // -- Resize input array and set data.
98 // foo.ResizeInputTensor(0, 2);
99 // foo.AllocateTensors();
100 // foo.typed_tensor<float>(0)[0] = 4;
101 // foo.typed_tensor<float>(0)[1] = 8;
102 // foo.Invoke();
103 //
104 
105 class Interpreter {
106  public:
107   // Instantiate an interpreter. All errors associated with reading and
108   // processing this model will be forwarded to the error_reporter object.
109   //
110   // Note, if error_reporter is nullptr, then a default StderrReporter is
111   // used. Ownership of 'error_reporter' remains with the caller.
112   explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
113 
114   ~Interpreter();
115 
116   // Interpreters are not copyable as they have non-trivial memory semantics.
117   Interpreter(const Interpreter&) = delete;
118   Interpreter& operator=(const Interpreter&) = delete;
119 
120   // Functions to build interpreter
121 
122   // Provide a list of tensor indexes that are inputs to the model.
123   // Each index is bound check and this modifies the consistent_ flag of the
124   // interpreter.
125   TfLiteStatus SetInputs(std::vector<int> inputs);
126 
127   // Provide a list of tensor indexes that are outputs to the model
128   // Each index is bound check and this modifies the consistent_ flag of the
129   // interpreter.
130   TfLiteStatus SetOutputs(std::vector<int> outputs);
131 
132   // Provide a list of tensor indexes that are variable tensors.
133   // Each index is bound check and this modifies the consistent_ flag of the
134   // interpreter.
135   TfLiteStatus SetVariables(std::vector<int> variables);
136 
137   // Ensure the internal node storage memory allocates at least `count`
138   // spots for node. NOTE, this doesn't actually add operators. This is an
139   // efficiency optimization that is subject to change.
140   void ReserveNodes(int count);
141 
142   // Adds a node with the given parameters and returns the index of the new
143   // node in `node_index` (optionally). Interpreter will take ownership of
144   // `builtin_data` and destroy it with `free`. Ownership of 'init_data'
145   // remains with the caller.
146   TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
147                                      const std::vector<int>& outputs,
148                                      const char* init_data,
149                                      size_t init_data_size, void* builtin_data,
150                                      const TfLiteRegistration* registration,
151                                      int* node_index = nullptr);
152 
153   // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
154   // The value pointed to by `first_new_tensor_index` will be set to the
155   // index of the first new tensor if `first_new_tensor_index` is non-null.
156   TfLiteStatus AddTensors(int tensors_to_add,
157                           int* first_new_tensor_index = nullptr);
158 
159   // Set description of inputs/outputs/data/fptrs for node `node_index`.
160   // This variant assumes an external buffer has been allocated of size
161   // bytes. The lifetime of buffer must be ensured to be greater or equal
162   // to Interpreter.
163   TfLiteStatus SetTensorParametersReadOnly(
164       int tensor_index, TfLiteType type, const char* name,
165       const std::vector<int>& dims, TfLiteQuantization quantization,
166       const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
167 
168   // Legacy. Deprecated in favor of above.
169   inline TfLiteStatus SetTensorParametersReadOnly(
170       int tensor_index, TfLiteType type, const char* name,
171       const std::vector<int>& dims, TfLiteQuantizationParams quantization,
172       const char* buffer, size_t bytes,
173       const Allocation* allocation = nullptr) {
174     return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
175                                        dims.data(), quantization, buffer, bytes,
176                                        allocation);
177   }
178 
179   TfLiteStatus SetTensorParametersReadOnly(
180       int tensor_index, TfLiteType type, const char* name, const size_t rank,
181       const int* dims, TfLiteQuantizationParams quantization,
182       const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
183 
184   // Set description of inputs/outputs/data/fptrs for node `node_index`.
185   // This variant assumes an external buffer has been allocated of size
186   // bytes. The lifetime of buffer must be ensured to be greater or equal
187   // to Interpreter.
188   TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
189                                             const char* name,
190                                             const std::vector<int>& dims,
191                                             TfLiteQuantization quantization,
192                                             bool is_variable = false);
193 
194   // Legacy. Deprecated in favor of above.
195   inline TfLiteStatus SetTensorParametersReadWrite(
196       int tensor_index, TfLiteType type, const char* name,
197       const std::vector<int>& dims, TfLiteQuantizationParams quantization,
198       bool is_variable = false) {
199     return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
200                                         dims.data(), quantization, is_variable);
201   }
202   TfLiteStatus SetTensorParametersReadWrite(
203       int tensor_index, TfLiteType type, const char* name, const size_t rank,
204       const int* dims, TfLiteQuantizationParams quantization,
205       bool is_variable = false);
206 
207   // Functions to access tensor data
208 
209   // Read only access to list of inputs.
inputs()210   const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }
211 
212   // Return the name of a given input. The given index must be between 0 and
213   // inputs().size().
GetInputName(int index)214   const char* GetInputName(int index) const {
215     return context_->tensors[inputs()[index]].name;
216   }
217 
218   // Read only access to list of outputs.
outputs()219   const std::vector<int>& outputs() const {
220     return primary_subgraph().outputs();
221   }
222 
223   // Read only access to list of variable tensors.
variables()224   const std::vector<int>& variables() const {
225     return primary_subgraph().variables();
226   }
227 
228   // Return the name of a given output. The given index must be between 0 and
229   // outputs().size().
GetOutputName(int index)230   const char* GetOutputName(int index) const {
231     return context_->tensors[outputs()[index]].name;
232   }
233 
234   // Return the number of tensors in the model.
tensors_size()235   size_t tensors_size() const { return context_->tensors_size; }
236 
237   // Return the number of ops in the model.
nodes_size()238   size_t nodes_size() const { return primary_subgraph().nodes_size(); }
239 
240   // WARNING: Experimental interface, subject to change
execution_plan()241   const std::vector<int>& execution_plan() const {
242     return primary_subgraph().execution_plan();
243   }
244 
245   // WARNING: Experimental interface, subject to change
246   // Overrides execution plan. This bounds checks indices sent in.
247   TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
248 
249   // Get a mutable tensor data structure.
250   // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
251   // read/write access to structure
tensor(int tensor_index)252   TfLiteTensor* tensor(int tensor_index) {
253     return primary_subgraph().tensor(tensor_index);
254   }
255 
256   // Get an immutable tensor data structure.
tensor(int tensor_index)257   const TfLiteTensor* tensor(int tensor_index) const {
258     return primary_subgraph().tensor(tensor_index);
259   }
260 
261   // Get a pointer to an operation and registration data structure if in bounds.
node_and_registration(int node_index)262   const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
263       int node_index) const {
264     return primary_subgraph().node_and_registration(node_index);
265   }
266 
267   // Perform a checked cast to the appropriate tensor type (mutable pointer
268   // version).
269   template <class T>
typed_tensor(int tensor_index)270   T* typed_tensor(int tensor_index) {
271     if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
272       if (tensor_ptr->type == typeToTfLiteType<T>()) {
273         return reinterpret_cast<T*>(tensor_ptr->data.raw);
274       }
275     }
276     return nullptr;
277   }
278 
279   // Perform a checked cast to the appropriate tensor type (immutable pointer
280   // version).
281   template <class T>
typed_tensor(int tensor_index)282   const T* typed_tensor(int tensor_index) const {
283     if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
284       if (tensor_ptr->type == typeToTfLiteType<T>()) {
285         return reinterpret_cast<const T*>(tensor_ptr->data.raw);
286       }
287     }
288     return nullptr;
289   }
290 
291   // Return a mutable pointer into the data of a given input tensor. The given
292   // index must be between 0 and inputs().size().
293   template <class T>
typed_input_tensor(int index)294   T* typed_input_tensor(int index) {
295     return typed_tensor<T>(inputs()[index]);
296   }
297 
298   // Return an immutable pointer into the data of a given input tensor. The
299   // given index must be between 0 and inputs().size().
300   template <class T>
typed_input_tensor(int index)301   const T* typed_input_tensor(int index) const {
302     return typed_tensor<T>(inputs()[index]);
303   }
304 
305   // Return a mutable pointer into the data of a given output tensor. The given
306   // index must be between 0 and outputs().size().
307   template <class T>
typed_output_tensor(int index)308   T* typed_output_tensor(int index) {
309     return typed_tensor<T>(outputs()[index]);
310   }
311 
312   // Return an immutable pointer into the data of a given output tensor. The
313   // given index must be between 0 and outputs().size().
314   template <class T>
typed_output_tensor(int index)315   const T* typed_output_tensor(int index) const {
316     return typed_tensor<T>(outputs()[index]);
317   }
318 
319   // Change the dimensionality of a given tensor. Note, this is only acceptable
320   // for tensor indices that are inputs.
321   // Returns status of failure or success.
322   // TODO(aselle): Consider implementing ArraySlice equivalent to make this
323   //   more adept at accepting data without an extra copy. Use absl::ArraySlice
324   //   if our partners determine that dependency is acceptable.
325   TfLiteStatus ResizeInputTensor(int tensor_index,
326                                  const std::vector<int>& dims);
327 
328   // Update allocations for all tensors. This will redim dependent tensors using
329   // the input tensor dimensionality as given. This is relatively expensive.
330   // If you know that your sizes are not changing, you need not call this.
331   // Returns status of success or failure.
332   TfLiteStatus AllocateTensors();
333 
334   // Invoke the interpreter (run the whole graph in dependency order).
335   //
336   // NOTE: It is possible that the interpreter is not in a ready state
337   // to evaluate (i.e. if a ResizeTensor() has been performed without an
338   // AllocateTensors().
339   // Returns status of success or failure.
340   TfLiteStatus Invoke();
341 
342   // Enable or disable the NN API (true to enable)
343   void UseNNAPI(bool enable);
344 
345   // Set the number of threads available to the interpreter.
346   void SetNumThreads(int num_threads);
347 
348   // Allow float16 precision for FP32 calculation when possible.
349   // default: not allow.
350   // WARNING: This is an experimental API and subject to change.
351   void SetAllowFp16PrecisionForFp32(bool allow);
352 
353   // Get the half precision flag.
354   // WARNING: This is an experimental API and subject to change.
GetAllowFp16PrecisionForFp32()355   bool GetAllowFp16PrecisionForFp32() const {
356     return context_->allow_fp32_relax_to_fp16;
357   }
358 
359   // Sets the cancellation function pointer in order to cancel a request in the
360   // middle of a call to Invoke(). The interpreter queries this function during
361   // inference, between op invocations; when it returns true, the interpreter
362   // will abort execution and return `kTfLiteError`. The `data` parameter
363   // contains any data used by the cancellation function, and if non-null,
364   // remains owned by the caller.
365   // WARNING: This is an experimental API and subject to change.
366   void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
367 
368   // Owning handle to a TfLiteDelegate instance.
369   using TfLiteDelegatePtr =
370       std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
371 
372   // Allow a delegate to look at the graph and modify the graph to handle
373   // parts of the graph themselves. After this is called, the graph may
374   // contain new nodes that replace 1 more nodes.
375   // WARNING: This is an experimental API and subject to change.
376   TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
377 
378   // Ensure the data in `tensor.data` is readable. In case delegate is used,
379   // it might require to copy the data from delegate buffer to raw memory.
380   // WARNING: This is an experimental API and subject to change.
EnsureTensorDataIsReadable(int tensor_index)381   TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
382     return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
383   }
384 
385   // Set the delegate buffer handle to a tensor. It can be called in the
386   // following cases:
387   // 1. Set the buffer handle to a tensor that's not being written by a
388   //    delegate. For example, feeding an OpenGL texture as the input of the
389   //    inference graph.
390   // 2. Set the buffer handle to a tensor that uses the same delegate.
391   //    For example, set an OpenGL texture as the output of inference, while
392   //    the node which produces output is an OpenGL delegate node.
393   // WARNING: This is an experimental API and subject to change.
394   TfLiteStatus SetBufferHandle(int tensor_index,
395                                TfLiteBufferHandle buffer_handle,
396                                TfLiteDelegate* delegate);
397 
398   // Get the delegate buffer handle, and the delegate which can process the
399   // buffer handle.
400   // WARNING: This is an experimental API and subject to change.
401   TfLiteStatus GetBufferHandle(int tensor_index,
402                                TfLiteBufferHandle* buffer_handle,
403                                TfLiteDelegate** delegate);
404 
405   void SetProfiler(profiling::Profiler* profiler);
406 
407   profiling::Profiler* GetProfiler();
408 
409   // The default capacity of `tensors_` vector.
410   static constexpr int kTensorsReservedCapacity = 128;
411   // The capacity headroom of `tensors_` vector before calling ops'
412   // `prepare` and `invoke` function. In these functions, it's guaranteed
413   // allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate
414   // pointers to existing tensors.
415   static constexpr int kTensorsCapacityHeadroom = 16;
416 
417   // Set if buffer handle output is allowed.
418   //
419   // When using hardware delegation, Interpreter will make the data of output
420   // tensors available in `tensor->data` by default. If the application can
421   // consume the buffer handle directly (e.g. reading output from OpenGL
422   // texture), it can set this flag to false, so Interpreter won't copy the data
423   // from buffer handle to CPU memory.
424   // WARNING: This is an experimental API and subject to change.
SetAllowBufferHandleOutput(bool allow_buffer_handle_output)425   void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
426     allow_buffer_handle_output_ = allow_buffer_handle_output;
427   }
428 
429   // Reset all variable tensors to the default value.
430   // If a variable tensor doesn't have a buffer, reset it to zero.
431   // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
432   // to the value of the buffer.
433   // WARNING: This is an experimental API and subject to change.
434   TfLiteStatus ResetVariableTensors();
435 
436   // Retrieve an operator's description of its work, for profiling purposes.
OpProfilingString(const TfLiteRegistration & op_reg,const TfLiteNode * node)437   const char* OpProfilingString(const TfLiteRegistration& op_reg,
438                                 const TfLiteNode* node) const {
439     if (op_reg.profiling_string == nullptr) return nullptr;
440     return op_reg.profiling_string(context_, node);
441   }
442 
443   // Set the value of an external context.
444   void SetExternalContext(TfLiteExternalContextType type,
445                           TfLiteExternalContext* ctx);
446 
447   // Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
448   // entries. The value pointed to by `first_new_subgraph_index` will be set to
449   // the index of the first new subgraph if `first_new_subgraph_index` is
450   // non-null.
451   // WARNING: This is an experimental API and subject to change.
452   void AddSubgraphs(int subgraphs_to_add,
453                     int* first_new_subgraph_index = nullptr);
454 
455   // Return the number of subgraphs in the model.
456   // WARNING: This is an experimental API and subject to change.
subgraphs_size()457   size_t subgraphs_size() const { return subgraphs_.size(); }
458 
459   // Get a pointer to a subgraph if in bounds.
460   // WARNING: This is an experimental API and subject to change.
subgraph(int subgraph_index)461   Subgraph* subgraph(int subgraph_index) {
462     if (subgraph_index < 0 ||
463         static_cast<size_t>(subgraph_index) >= subgraphs_size())
464       return nullptr;
465     return &*subgraphs_[subgraph_index];
466   }
467 
468   // WARNING: Experimental interface, subject to change
primary_subgraph()469   Subgraph& primary_subgraph() {
470     return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
471   }
472 
473   // WARNING: Experimental interface, subject to change
primary_subgraph()474   const Subgraph& primary_subgraph() const {
475     return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
476   }
477 
478  private:
479   friend class InterpreterBuilder;
480   friend class InterpreterTest;
481 
482   // Set the value of an external context.
483   static void SetExternalContext(struct TfLiteContext* context,
484                                  TfLiteExternalContextType type,
485                                  TfLiteExternalContext* ctx);
486 
487   // Variant of the public ModifyGraphWithDelegate method that additionally
488   // Assumes ownership of the provided delegate.
489   // WARNING: This is an experimental API and subject to change.
ModifyGraphWithDelegate(TfLiteDelegatePtr delegate)490   TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) {
491     // Note that we retain ownership of the delegate even if graph modification
492     // fails, as delegate use will be in an indeterminate state at that point.
493     owned_delegates_.push_back(std::move(delegate));
494     return ModifyGraphWithDelegate(owned_delegates_.back().get());
495   }
496 
497   // A pure C data structure used to communicate with the pure C plugin
498   // interface. To avoid copying tensor metadata, this is also the definitive
499   // structure to store tensors.
500   // This is the primary subgraph context.
501   TfLiteContext* context_;
502 
503   // The error reporter delegate that tflite will forward queries errors to.
504   ErrorReporter* error_reporter_;
505 
506   // List of delegates that have been installed and are owned by this
507   // interpreter instance. Useful if client delegate ownership is burdensome.
508   // WARNING: This is an experimental API and subject to change.
509   // TODO(b/116667551): Use TfLiteExternalContext for storing state.
510   std::vector<TfLiteDelegatePtr> owned_delegates_;
511 
512   bool allow_buffer_handle_output_ = false;
513 
514   // List of active external contexts.
515   TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];
516 
517   // Subgraphs
518   std::vector<std::unique_ptr<Subgraph>> subgraphs_;
519 };
520 
521 }  // namespace tflite
522 #endif  // TENSORFLOW_LITE_INTERPRETER_H_
523