/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Main abstraction controlling the tflite interpreter.
/// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

class InterpreterTest;  // Class for friend declarations.

namespace delegates {
class InterpreterUtils;  // Class for friend declarations.

namespace test_utils {
class TestDelegate;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

/// An interpreter for a graph of nodes that input and output from tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
///   // Return error.
/// }
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
///   // Return failure.
/// }
/// interpreter->AllocateTensors();
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
///   input[i] = ...;
/// }
/// interpreter->Invoke();
/// </code></pre>
///
/// Note: for nearly all practical use cases, one should not directly construct
/// an Interpreter object, but rather use the InterpreterBuilder.

class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  // WARNING: Use of this constructor outside of an InterpreterBuilder is not
  // recommended.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Ensure the internal node storage memory allocates at least `count`
  /// spots for nodes. NOTE: this doesn't actually add operators. This is an
  /// efficiency optimization that is subject to change.
  void ReserveNodes(int count);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). Interpreter will take ownership of
  /// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set description of inputs/outputs/data/fptrs for tensor `tensor_index`.
  /// This variant assumes an external buffer of size `bytes` has been
  /// allocated. The buffer's lifetime must be at least as long as the
  /// Interpreter's.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set description of inputs/outputs/data/fptrs for tensor `tensor_index`.
  /// This variant lets the interpreter allocate and manage the tensor's
  /// buffer (in contrast to the read-only variant above, which wraps an
  /// external buffer).
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }

  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data

  /// Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }
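
  // Example (illustrative sketch, not part of the API): enumerating the
  // model's input/output names with the accessors above. `interpreter` is
  // assumed to be a std::unique_ptr<Interpreter> built via InterpreterBuilder.
  //
  //   for (size_t i = 0; i < interpreter->inputs().size(); ++i) {
  //     printf("input %zu: %s\n", i, interpreter->GetInputName(i));
  //   }
  //   for (size_t i = 0; i < interpreter->outputs().size(); ++i) {
  //     printf("output %zu: %s\n", i, interpreter->GetOutputName(i));
  //   }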

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

#ifndef DOXYGEN_SKIP
  /// WARNING: Experimental interface, subject to change
  /// Overrides execution plan. This bounds checks indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
#endif  // DOXYGEN_SKIP

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }
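
  // Example (illustrative sketch): typed_tensor<T>() returns nullptr when T
  // does not match the tensor's registered type, so callers can use it as a
  // combined bounds-and-type check. `interpreter` is a placeholder instance.
  //
  //   if (float* data = interpreter->typed_tensor<float>(tensor_index)) {
  //     data[0] = 1.0f;  // Safe: tensor exists and holds kTfLiteFloat32.
  //   } else {
  //     // Wrong type (or out-of-range index): handle the error.
  //   }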

  /// WARNING: Experimental interface, subject to change
  /// Returns the list of names of all method signatures defined in the model.
  /// Note: the returned pointers have the same lifetime as the Interpreter
  /// object.
  std::vector<const std::string*> signature_def_names() const {
    std::vector<const std::string*> method_names;
    method_names.reserve(signature_defs_.size());
    for (const auto& sig_def : signature_defs_) {
      method_names.emplace_back(&sig_def.method_name);
    }
    return method_names;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of input names to tensor indices in the signature
  /// specified through 'method_name'.
  /// If an invalid name is passed, an empty map is returned.
  const std::map<std::string, uint32_t>& signature_inputs(
      const char* method_name) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.method_name == method_name) return sig_def.inputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of output names to tensor indices in the signature
  /// specified through 'method_name'.
  /// If an invalid name is passed, an empty map is returned.
  const std::map<std::string, uint32_t>& signature_outputs(
      const char* method_name) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.method_name == method_name) return sig_def.outputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the input tensor identified by 'signature_input_name' in the
  /// signature identified by 'signature_method_name'.
  /// Returns nullptr if not found.
  TfLiteTensor* input_tensor_by_signature_name(
      const char* signature_input_name, const char* signature_method_name) {
    const int tensor_index = GetTensorIndexFromSignatureDefName(
        signature_input_name, signature_method_name, /*is_input=*/true);
    return tensor_index == -1 ? nullptr : tensor(tensor_index);
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the output tensor identified by 'signature_output_name' in the
  /// signature identified by 'signature_method_name'.
  /// Returns nullptr if not found.
  const TfLiteTensor* output_tensor_by_signature_name(
      const char* signature_output_name,
      const char* signature_method_name) const {
    const int tensor_index = GetTensorIndexFromSignatureDefName(
        signature_output_name, signature_method_name, /*is_input=*/false);
    return tensor_index == -1 ? nullptr : tensor(tensor_index);
  }
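
  // Example (illustrative sketch): looking up a tensor through the signature
  // APIs above. The signature name "serving_default" and input name "x" are
  // placeholders; real names come from the model's SignatureDefs.
  //
  //   for (const std::string* name : interpreter->signature_def_names()) {
  //     // Inspect available signatures.
  //   }
  //   TfLiteTensor* x = interpreter->input_tensor_by_signature_name(
  //       "x", "serving_default");
  //   if (x == nullptr) { /* No such signature/input: handle the error. */ }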

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The given
  /// index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index must
  /// be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }
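
  // Example (illustrative sketch): a typical inference loop using the typed
  // accessors above. The element count is a placeholder.
  //
  //   float* in = interpreter->typed_input_tensor<float>(0);
  //   for (int i = 0; i < input_element_count; ++i) in[i] = 0.f;
  //   if (interpreter->Invoke() != kTfLiteOk) { /* handle the error */ }
  //   const float* out = interpreter->typed_output_tensor<float>(0);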

  /// Change the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);
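
  // Example (illustrative sketch): resizing the first input to a new shape
  // (the shape is a placeholder) and re-planning memory before inference.
  //
  //   interpreter->ResizeInputTensor(interpreter->inputs()[0],
  //                                  {2, 224, 224, 3});
  //   if (interpreter->AllocateTensors() != kTfLiteOk) {
  //     // Resize was incompatible with the graph: handle the error.
  //   }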

  /// WARNING: Experimental interface, subject to change
  /// Change the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1`
  /// in the `dims_signature` attribute of a `TfLiteTensor`. Returns status of
  /// failure or success. Note that this doesn't actually resize any existing
  /// buffers. A call to AllocateTensors() is required to change the tensor
  /// input buffer.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before
  /// next invocation.
  /// WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. This *must be* called after the interpreter has been created
  /// and before running inference (and accessing tensor buffers), and *must
  /// be* called again if (and only if) an input tensor is resized. Returns
  /// status of success or failure.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if a ResizeTensor() has been performed without an
  /// AllocateTensors()).
  /// Returns status of success or failure.
  TfLiteStatus Invoke();

  /// Set the number of threads available to the interpreter.
  ///
  /// NOTE: num_threads should be >= -1.
  /// Users may pass -1 to let the TFLite interpreter choose the number of
  /// threads available to itself.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  ///
  /// WARNING: This API is deprecated: prefer controlling this via delegate
  /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
  /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
  /// This method will be removed in a future release.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
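
  // Example (illustrative sketch): cancelling Invoke() from another thread
  // via a caller-owned atomic flag (requires <atomic>).
  //
  //   static std::atomic<bool> cancelled{false};
  //   interpreter->SetCancellationFunction(&cancelled, [](void* data) {
  //     return static_cast<std::atomic<bool>*>(data)->load();
  //   });
  //   // Elsewhere: cancelled = true;  // Invoke() then returns kTfLiteError
  //   //                              // at the next op boundary.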

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph itself. After this is called, the graph may contain
  /// new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// Returns one of the following four status codes:
  /// 1. kTfLiteOk: Success.
  /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
  /// delegate, or the delegate parameter was null. The Interpreter has been
  /// restored to its pre-delegation state.
  /// NOTE: This undoes all delegates previously applied to the Interpreter.
  /// 3. kTfLiteApplicationError: Delegation failed to be applied due to
  /// incompatibility with the TfLite runtime, e.g., the model graph is already
  /// immutable when applying the delegate. However, the interpreter could
  /// still be invoked.
  /// 4. kTfLiteError: Unexpected/runtime failure.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
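
  // Example (illustrative sketch): applying a delegate and falling back to
  // CPU on failure. `CreateMyDelegate()` is a hypothetical factory for some
  // TfLiteDelegate implementation; it must outlive the interpreter.
  //
  //   TfLiteDelegate* delegate = CreateMyDelegate();
  //   switch (interpreter->ModifyGraphWithDelegate(delegate)) {
  //     case kTfLiteOk:               break;  // Delegate applied.
  //     case kTfLiteDelegateError:    break;  // Restored; CPU path usable.
  //     case kTfLiteApplicationError: break;  // Not applied; still invokable.
  //     default:                      break;  // Unexpected failure.
  //   }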

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate.
  /// WARNING: This is an experimental API and subject to change.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }

  /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has
  /// no virtual destructor. The default deleter of the unique_ptr does not
  /// know how to delete C++ objects deriving from TfLiteDelegate.
  TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<TfLiteDelegate> delegate) = delete;

  /// Ensure the data in `tensor.data` is readable. When a delegate is used,
  /// it may be necessary to copy the data from the delegate buffer to raw
  /// memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Same as SetProfiler except this interpreter takes ownership
  /// of the provided profiler.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(std::unique_ptr<Profiler> profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  /// The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. In these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, Interpreter will make the data of output
  /// tensors available in `tensor->data` by default. If the application can
  /// consume the buffer handle directly (e.g. reading output from OpenGL
  /// texture), it can set this flag to false, so Interpreter won't copy the
  /// data from buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter does not
  // take ownership of the external context 'ctx', and the context must
  // outlive the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  // Assigns (or reassigns) a custom memory allocation for the given tensor.
  // If AllocateTensors() is called after this, the runtime does not consider
  // the tensor during internal memory planning and will continue using the
  // provided allocation for the tensor (assuming it satisfies the expected
  // tensor byte length).
  // The runtime does NOT take ownership of the underlying memory.
  // Note that while this function can be called again to set a new allocation
  // for the tensor, it can no longer be reset to the TFLite arena memory.
  //
  // Parameters should satisfy the following conditions:
  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  //    In general, this is true for I/O tensors & variable tensors.
  // 2. allocation->data has the appropriate permissions for runtime access
  //    (Read-only for inputs, Read-Write for others), and outlives Interpreter.
  // 3. allocation->bytes >= tensor->bytes.
  //    This condition is checked again if any tensors are resized.
  // 4. allocation->data should be aligned to kDefaultTensorAlignment
  //    defined in lite/util.h. (Currently 64 bytes)
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation);
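
  // Example (illustrative sketch): backing an input tensor with caller-owned,
  // 64-byte-aligned memory on a POSIX system. The caller must keep `buf`
  // alive for the interpreter's lifetime and free it afterwards.
  //
  //   TfLiteTensor* t = interpreter->input_tensor(0);
  //   void* buf = nullptr;
  //   posix_memalign(&buf, /*alignment=*/64, t->bytes);  // Check the result.
  //   TfLiteCustomAllocation alloc{buf, t->bytes};
  //   interpreter->SetCustomAllocationForTensor(interpreter->inputs()[0],
  //                                             alloc);
  //   interpreter->AllocateTensors();  // Planning now skips this tensor.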

#ifndef DOXYGEN_SKIP
  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set
  /// to the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  /// WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size())
      return nullptr;
    return &*subgraphs_[subgraph_index];
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  /// Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

#endif  // DOXYGEN_SKIP

 private:
  // Structure representing SignatureDef inputs/outputs.
  struct SignatureDef {
    // Maps name in signature def as key to index of the tensor in the model.
    std::map<std::string, uint32_t> inputs;
    // Maps name in signature def as key to index of the tensor in the model.
    std::map<std::string, uint32_t> outputs;
    // The method name for this signature.
    std::string method_name;
    // The key of this SignatureDef in the SavedModel signature def map.
    std::string signature_def_key;
  };
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegate;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Helper method that returns the tensor index corresponding to a name in a
  // SignatureDef, identified by 'signature_method_name' and
  // 'signature_tensor_name'.
  // If 'is_input' is true then the tensor is looked up among the input
  // tensors, otherwise among the output tensors.
  // Returns -1 if the tensor is not found.
  int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name,
                                         const char* signature_method_name,
                                         bool is_input) const {
    // Iterate directly and don't use other methods to avoid extra allocation.
    for (const auto& signature : signature_defs_) {
      if (signature.method_name != signature_method_name) continue;
      auto& signature_list = (is_input ? signature.inputs : signature.outputs);
      auto tensor_iter = signature_list.find(signature_tensor_name);
      if (tensor_iter == signature_list.end()) return -1;
      return tensor_iter->second;
    }
    return -1;
  }

  // Sets the profiler for all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Sets the list of signature defs in the model.
  void SetSignatureDef(std::vector<SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // Profiler that has been installed and is owned by this interpreter
  // instance. Useful if client profiler ownership is burdensome.
  std::unique_ptr<Profiler> owned_profiler_;

  // Points to the installed Profiler instance.
  Profiler* installed_profiler_ = nullptr;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to
  // point to this object. However, if this element value is overwritten via
  // calling 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will
  // reset this to nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by interpreter and shared by multiple subgraphs.
  resource::ResourceMap resources_;

  // Delegates that the TFLite interpreter will apply by default.
  // An empty list means there is no delegate to be applied by default, or
  // that delegates have already been applied and don't need to be applied
  // again.
  std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;

  // List of SignatureDefs mapping inputs/outputs to tensor indices.
  // We just keep track of tensor indices.
  std::vector<SignatureDef> signature_defs_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_