/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Main abstraction controlling the tflite interpreter.
// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

/// An interpreter for a graph of nodes that input and output from tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create basic model
/// Interpreter foo;
/// foo.AddTensors(2);
/// foo.SetInputs({0});
/// foo.SetOutputs({1});
/// foo.SetTensorParametersReadWrite(0, ...);
/// foo.SetTensorParametersReadOnly(1, ...);
/// foo.AddNodeWithParameters(...);
/// // Resize input array to length 1.
/// foo.ResizeInputTensor(0, {1});
/// foo.AllocateTensors();
/// // Install array data
/// foo.typed_tensor<float>(0)[0] = 3;
/// foo.Invoke();
/// foo.typed_tensor<float>(0)[0] = 4;
/// foo.Invoke();
/// // Resize input array and set data.
/// foo.ResizeInputTensor(0, {2});
/// foo.AllocateTensors();
/// foo.typed_tensor<float>(0)[0] = 4;
/// foo.typed_tensor<float>(0)[1] = 8;
/// foo.Invoke();
/// </code></pre>
///

class Interpreter {
 public:
  /// Instantiate an interpreter. All errors associated with reading and
  /// processing this model will be forwarded to the error_reporter object.
  ///
  /// Note, if error_reporter is nullptr, then a default StderrReporter is
  /// used. Ownership of 'error_reporter' remains with the caller.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Ensure the internal node storage memory allocates at least `count`
  /// slots for nodes. NOTE, this doesn't actually add operators. This is an
  /// efficiency optimization that is subject to change.
  void ReserveNodes(int count);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). Interpreter will take ownership of
  /// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);
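
  // Example: a sketch of wiring one node by hand, assuming `interpreter` is
  // an Interpreter and `reg` is a TfLiteRegistration obtained elsewhere
  // (e.g. from an op resolver); the tensor indices are illustrative.
  //
  //   const TfLiteRegistration* reg = ...;
  //   int node_index = -1;
  //   interpreter.AddNodeWithParameters(
  //       /*inputs=*/{0, 1}, /*outputs=*/{2},
  //       /*init_data=*/nullptr, /*init_data_size=*/0,
  //       /*builtin_data=*/nullptr, reg, &node_index);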

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set the type/name/dims/quantization of the tensor at `tensor_index`.
  /// This variant assumes an external buffer of size `bytes` has been
  /// allocated. The lifetime of `buffer` must be at least as long as the
  /// Interpreter's.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set the type/name/dims/quantization of the tensor at `tensor_index`,
  /// with storage allocated and managed by the Interpreter.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false) {
    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                        dims.data(), quantization, is_variable);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false);
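
  // Example: a sketch declaring tensor 0 as a float32 tensor of shape {1, 4}
  // via the legacy overload; quantization params are ignored for float
  // tensors, so zero-initialized values suffice.
  //
  //   TfLiteQuantizationParams quant = {};
  //   interpreter.SetTensorParametersReadWrite(
  //       /*tensor_index=*/0, kTfLiteFloat32, "input", {1, 4}, quant);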
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data

  /// Read-only access to the list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read-only access to the list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read-only access to the list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

#ifndef DOXYGEN_SKIP
  /// WARNING: Experimental interface, subject to change
  /// Overrides execution plan. This bounds-checks the indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
#endif  // DOXYGEN_SKIP

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }
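
  // Example: typed_tensor() returns nullptr when the index is out of range or
  // T doesn't match the tensor's registered type, so check before use:
  //
  //   if (float* data = interpreter.typed_tensor<float>(0)) {
  //     data[0] = 1.0f;  // tensor 0 exists and is kTfLiteFloat32
  //   }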

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The given
  /// index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index must
  /// be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Change the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns status of failure or success.
  /// TODO(aselle): Consider implementing ArraySlice equivalent to make this
  ///   more adept at accepting data without an extra copy. Use absl::ArraySlice
  ///   if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);
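
  // Example: a sketch resizing the first model input (the shape is
  // illustrative); AllocateTensors() must run again before the next Invoke():
  //
  //   interpreter.ResizeInputTensor(interpreter.inputs()[0], {2, 224, 224, 3});
  //   interpreter.AllocateTensors();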

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before the
  /// next invocation.
  /// WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. If you know that your sizes are not changing, you need not call
  /// this. Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if a ResizeInputTensor() has been performed without an
  /// AllocateTensors()).
  /// Returns status of success or failure.
  TfLiteStatus Invoke();
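
  // Example: the typical fill / run / read cycle, assuming a model whose
  // first input and output are float tensors:
  //
  //   if (interpreter.AllocateTensors() != kTfLiteOk) { /* handle error */ }
  //   float* in = interpreter.typed_input_tensor<float>(0);
  //   in[0] = 3.f;
  //   if (interpreter.Invoke() != kTfLiteOk) { /* handle error */ }
  //   const float* out = interpreter.typed_output_tensor<float>(0);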

  /// Enable or disable the NN API (true to enable).
  void UseNNAPI(bool enable);

  /// Set the number of threads available to the interpreter.
  void SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
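
  // Example: a sketch of cancellation from another thread via an atomic flag
  // (<atomic> assumed); the capture-free lambda converts to the required
  // function pointer:
  //
  //   static std::atomic<bool> cancelled(false);
  //   interpreter.SetCancellationFunction(&cancelled, [](void* data) {
  //     return static_cast<std::atomic<bool>*>(data)->load();
  //   });
  //   // Elsewhere: cancelled = true;  // Invoke() aborts at the next op.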

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph themselves. After this is called, the graph may
  /// contain new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate. Be sure to construct the unique_ptr
  /// with a suitable destruction function.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate);
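
  // Example: a sketch of transferring delegate ownership to the interpreter;
  // CreateMyDelegate() and DeleteMyDelegate() are hypothetical stand-ins for
  // a concrete delegate's factory and destructor:
  //
  //   Interpreter::TfLiteDelegatePtr delegate(
  //       CreateMyDelegate(), [](TfLiteDelegate* d) { DeleteMyDelegate(d); });
  //   interpreter.ModifyGraphWithDelegate(std::move(delegate));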

  /// Ensure the data in `tensor.data` is readable. If a delegate is used,
  /// it may be necessary to copy the data from the delegate buffer to raw
  /// memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces the output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  // The default capacity of `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of `tensors_` vector before calling ops'
  /// `prepare` and `invoke` function. In these functions, it's guaranteed
  /// allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate
  /// pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, Interpreter will make the data of output
  /// tensors available in `tensor->data` by default. If the application can
  /// consume the buffer handle directly (e.g. reading output from an OpenGL
  /// texture), it can set this flag to true, so the Interpreter won't copy the
  /// data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }
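
  // Example: a sketch keeping delegate outputs in their buffer handles and
  // forcing a CPU-readable copy only when needed:
  //
  //   interpreter.SetAllowBufferHandleOutput(true);  // skip automatic copies
  //   interpreter.Invoke();
  //   // Later, if CPU access to output 0 is required:
  //   interpreter.EnsureTensorDataIsReadable(interpreter.outputs()[0]);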

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
  /// to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }
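
  // Example: combining node_and_registration() with OpProfilingString() to
  // label a node while profiling (node index 0 is illustrative):
  //
  //   if (const auto* nr = interpreter.node_and_registration(0)) {
  //     const char* desc = interpreter.OpProfilingString(nr->second, &nr->first);
  //     if (desc != nullptr) printf("node 0: %s\n", desc);
  //   }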

  // Set the value of an external context. The TFLite interpreter doesn't take
  // ownership of this external context 'ctx', and the context should outlive
  // the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

#ifndef DOXYGEN_SKIP
  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set to
  /// the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  /// WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size())
      return nullptr;
    return &*subgraphs_[subgraph_index];
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }
#endif  // DOXYGEN_SKIP

 private:
  friend class InterpreterBuilder;
  friend class InterpreterTest;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<TfLiteDelegatePtr> owned_delegates_;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to point
  // to this object. However, if this element value is overwritten via calling
  // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset this to
  // nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by interpreter and shared by multiple subgraphs.
  resource::ResourceMap resources_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_