/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Main abstraction controlling the tflite interpreter.
/// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

class InterpreterTest;  // Class for friend declarations.

namespace delegates {
class InterpreterUtils;  // Class for friend declarations.

namespace test_utils {
class TestDelegation;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

namespace interpreter_wrapper {
class InterpreterWrapper;  // Class for friend declarations.
}  // namespace interpreter_wrapper

/// An interpreter for a graph of nodes whose inputs and outputs are tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
///   // Return error.
/// }
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
///   // Return failure.
/// }
/// if (interpreter->AllocateTensors() != kTfLiteOk) {
///   // Return failure.
/// }
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
///   input[i] = ...;
/// }
/// interpreter->Invoke();
/// </code></pre>
///
/// Note: For nearly all practical use cases, one should not directly construct
/// an Interpreter object, but rather use the InterpreterBuilder.
///
/// WARNING: This class is *not* thread-safe. The client is responsible for
/// ensuring serialized interaction to avoid data races and undefined behavior.
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  // WARNING: Use of this constructor outside of an InterpreterBuilder is not
  // recommended.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). Interpreter will take ownership of
  /// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set description of inputs/outputs/data/fptrs for the tensor at
  /// `tensor_index`. This variant assumes an external buffer of size `bytes`
  /// has already been allocated. The lifetime of `buffer` must be at least as
  /// long as that of the Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set description of inputs/outputs/data/fptrs for the tensor at
  /// `tensor_index`. In this variant the interpreter allocates and manages
  /// the tensor's buffer itself.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data

  /// Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds of the primary subgraph (subgraphs_[0]).
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int subgraph_index, int node_index) const {
    return subgraph(subgraph_index)->node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }
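
  // Example of checked access via typed_tensor<T>() (an illustrative sketch;
  // the tensor index 0 and the float element type are assumptions):
  //
  //   if (float* data = interpreter->typed_tensor<float>(0)) {
  //     data[0] = 1.0f;  // Type matched; safe to use.
  //   } else {
  //     // nullptr: the index was out of bounds or the tensor is not float.
  //   }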

  /// WARNING: Experimental interface, subject to change
  /// Returns a list of all keys of different method signatures defined in the
  /// model.
  /// Note, the returned pointers have the same lifetime as the Interpreter
  /// object.
  std::vector<const std::string*> signature_keys() const {
    std::vector<const std::string*> signature_keys;
    signature_keys.reserve(signature_defs_.size());
    for (const auto& sig_def : signature_defs_) {
      signature_keys.emplace_back(&sig_def.signature_key);
    }
    return signature_keys;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns a pointer to the SignatureRunner instance to run the part of the
  /// graph identified by a SignatureDef. Returns nullptr if the given
  /// signature key is not valid.
  /// If you need to specify delegates, you have to do that before calling this
  /// function. This function will additionally apply default delegates. Thus,
  /// applying delegates after that might lead to undesirable behaviors.
  /// Note, the pointed instance has the same lifetime as the Interpreter
  /// object and the SignatureRunner class is *not* thread-safe.
  SignatureRunner* GetSignatureRunner(const char* signature_key);
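
  // Example use of GetSignatureRunner (an illustrative sketch; the signature
  // key "serving_default" and the tensor names "input"/"output" are assumed
  // placeholders, not part of this API):
  //
  //   tflite::SignatureRunner* runner =
  //       interpreter->GetSignatureRunner("serving_default");
  //   if (runner == nullptr) { /* Unknown signature key. */ }
  //   if (runner->AllocateTensors() != kTfLiteOk) { /* Handle failure. */ }
  //   TfLiteTensor* input = runner->input_tensor("input");
  //   // ... fill input->data ...
  //   if (runner->Invoke() != kTfLiteOk) { /* Handle failure. */ }
  //   const TfLiteTensor* output = runner->output_tensor("output");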

  /// WARNING: Experimental interface, subject to change
  // Return the subgraph index that corresponds to a SignatureDef, defined by
  // 'signature_key'.
  // If an invalid name is passed, -1 will be returned.
  int GetSubgraphIndexFromSignature(const char* signature_key) const {
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key == signature_key) {
        return signature.subgraph_index;
      }
    }
    return -1;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of inputs to tensor index in the signature
  /// specified through 'signature_key'.
  /// If an invalid name is passed, an empty list will be returned.
  const std::map<std::string, uint32_t>& signature_inputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.inputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of outputs to tensor index in the signature
  /// specified through 'signature_key'.
  /// If an invalid name is passed, an empty list will be returned.
  const std::map<std::string, uint32_t>& signature_outputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.outputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }
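
  // Example of walking a signature's input map (an illustrative sketch;
  // "serving_default" is an assumed placeholder key):
  //
  //   const auto& sig_inputs =
  //       interpreter->signature_inputs("serving_default");
  //   for (const auto& entry : sig_inputs) {
  //     // entry.first is the input name in the SignatureDef; entry.second is
  //     // a tensor index usable with tensor() or typed_tensor<T>().
  //   }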

  /// WARNING: Experimental interface, subject to change
  /// Returns the input tensor identified by 'signature_input_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  TfLiteTensor* input_tensor_by_signature(const char* signature_input_name,
                                          const char* signature_key) {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_input_name, signature_key, /*is_input=*/true);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the output tensor identified by 'signature_output_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  const TfLiteTensor* output_tensor_by_signature(
      const char* signature_output_name, const char* signature_key) const {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_output_name, signature_key, /*is_input=*/false);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }
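
  // Example of accessing tensors by signature (an illustrative sketch; the
  // signature key and tensor names are assumed placeholders):
  //
  //   TfLiteTensor* in = interpreter->input_tensor_by_signature(
  //       "input", "serving_default");
  //   const TfLiteTensor* out = interpreter->output_tensor_by_signature(
  //       "output", "serving_default");
  //   if (in == nullptr || out == nullptr) { /* Name or key not found. */ }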

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The given
  /// index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index must
  /// be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }
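
  // Example of a full inference pass using the typed accessors (an
  // illustrative sketch; assumes float inputs/outputs and a known element
  // count `n`):
  //
  //   float* in = interpreter->typed_input_tensor<float>(0);
  //   for (int i = 0; i < n; ++i) in[i] = 0.f;  // Fill with real data.
  //   if (interpreter->Invoke() != kTfLiteOk) { /* Handle failure. */ }
  //   const float* out = interpreter->typed_output_tensor<float>(0);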

  /// Change the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  /// Change the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1`
  /// in the `dims_signature` attribute of a `TfLiteTensor`. Returns status of
  /// failure or success. Note that this doesn't actually resize any existing
  /// buffers. A call to AllocateTensors() is required to change the tensor
  /// input buffer.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);
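
  // Example of resizing an input before invocation (an illustrative sketch;
  // the shape {4, 224, 224, 3} is an assumed placeholder):
  //
  //   const int input_index = interpreter->inputs()[0];
  //   if (interpreter->ResizeInputTensor(input_index, {4, 224, 224, 3}) !=
  //       kTfLiteOk) { /* Handle failure. */ }
  //   // Re-plan allocations before the next Invoke().
  //   if (interpreter->AllocateTensors() != kTfLiteOk) { /* Handle failure. */ }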

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before the
  /// next invocation.
  /// WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. This *must be* called after the interpreter has been created
  /// and before running inference (and accessing tensor buffers), and *must be*
  /// called again if (and only if) an input tensor is resized. Returns status
  /// of success or failure. Will fail if any of the ops in the model (other
  /// than those which were rewritten by delegates, if any) are not supported
  /// by the Interpreter's OpResolver.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if a ResizeTensor() has been performed without an
  /// AllocateTensors()).
  /// Returns status of success or failure.
  TfLiteStatus Invoke();

  /// Set the number of threads available to the interpreter.
  ///
  /// NOTE: num_threads should be >= -1. Setting num_threads to 0 has the
  /// effect of disabling multithreading, which is equivalent to setting
  /// num_threads to 1. If set to the value -1, the number of threads used will
  /// be implementation-defined and platform-dependent.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  ///
  /// WARNING: This API is deprecated: prefer controlling this via delegate
  /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
  /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
  /// This method will be removed in a future release.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
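
  // Example of cooperative cancellation (an illustrative sketch; the atomic
  // flag is caller-owned, must outlive any Invoke() that may query it, and
  // requires <atomic>):
  //
  //   static std::atomic<bool> cancelled{false};
  //   interpreter->SetCancellationFunction(&cancelled, [](void* data) {
  //     return static_cast<std::atomic<bool>*>(data)->load();
  //   });
  //   // From another thread: cancelled = true;
  //   // A running Invoke() will then abort and return kTfLiteError.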

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph itself. After this is called, the graph may contain
  /// new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// Returns one of the following four status codes:
  /// 1. kTfLiteOk: Success.
  /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
  /// delegate, or the delegate parameter was null. The Interpreter has been
  /// restored to its pre-delegation state.
  /// NOTE: This undoes all delegates previously applied to the Interpreter.
  /// 3. kTfLiteApplicationError: Delegation failed to be applied due to
  /// incompatibility with the TfLite runtime, e.g., the model graph is already
  /// immutable when applying the delegate. However, the interpreter could
  /// still be invoked.
  /// 4. kTfLiteError: Unexpected/runtime failure.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
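
  // Example of applying a delegate with CPU fallback (an illustrative sketch;
  // `delegate` stands for any TfLiteDelegate*, e.g. one obtained from a
  // hardware delegate's factory function):
  //
  //   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
  //     // Delegation failed. On kTfLiteDelegateError the interpreter has
  //     // been restored to its pre-delegation state, so CPU execution can
  //     // still proceed via Invoke().
  //   }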

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate.
  /// WARNING: This is an experimental API and subject to change.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }

  /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has
  /// no virtual destructor. The default deleter of the unique_ptr does not
  /// know how to delete C++ objects deriving from TfLiteDelegate.
  TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<TfLiteDelegate> delegate) = delete;

  /// Ensure the data in `tensor.data` is readable. If a delegate is used, it
  /// might be necessary to copy the data from the delegate's buffer to raw
  /// memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Same as SetProfiler except this interpreter takes ownership
  /// of the provided profiler.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(std::unique_ptr<Profiler> profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. In these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, Interpreter will make the data of output
  /// tensors available in `tensor->data` by default. If the application can
  /// consume the buffer handle directly (e.g. reading output from an OpenGL
  /// texture), it can set this flag to false, so Interpreter won't copy the
  /// data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter doesn't take
  // ownership of the external context 'ctx', and the context must outlive the
  // TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  // Assigns (or reassigns) a custom memory allocation for the given tensor.
  // `flags` is a bitmask; see TfLiteCustomAllocationFlags.
  // The runtime does NOT take ownership of the underlying memory.
  //
  // NOTE: The user needs to call AllocateTensors() after this. In case of
  // input resizing, buffers will be checked for the required data size during
  // AllocateTensors().
  //
  // Parameters should satisfy the following conditions:
  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  //    In general, this is true for I/O tensors & variable tensors.
  // 2. allocation->data has the appropriate permissions for runtime access
  //    (read-only for inputs, read-write for others), and outlives the
  //    Interpreter.
  // 3. allocation->bytes >= tensor->bytes.
  //    This condition is checked again if any tensors are resized.
  // 4. allocation->data should be aligned to kDefaultTensorAlignment
  //    defined in lite/util.h. (Currently 64 bytes.)
  //    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  //    set through `flags`.
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);
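
  // Example of providing a caller-owned buffer for an input tensor (an
  // illustrative sketch; note that aligned_alloc requires the size to be a
  // multiple of the alignment, so real code may need to round `bytes` up):
  //
  //   const int input_index = interpreter->inputs()[0];
  //   const TfLiteTensor* t = interpreter->tensor(input_index);
  //   void* buf = aligned_alloc(/*alignment=*/64, /*size=*/t->bytes);
  //   TfLiteCustomAllocation allocation{buf, t->bytes};
  //   if (interpreter->SetCustomAllocationForTensor(input_index, allocation)
  //       != kTfLiteOk) { /* Handle failure. */ }
  //   if (interpreter->AllocateTensors() != kTfLiteOk) { /* Handle failure. */ }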

#ifndef DOXYGEN_SKIP
  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set to
  /// the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  /// WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  const Subgraph* subgraph(int subgraph_index) const {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size()) {
      return nullptr;
    }
    return subgraphs_[subgraph_index].get();
  }

  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    return const_cast<Subgraph*>(
        static_cast<const Interpreter*>(this)->subgraph(subgraph_index));
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  // Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

#endif  // DOXYGEN_SKIP

 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Helper method that returns the tensor index corresponding to a name in a
  // SignatureDef, defined by 'signature_key' and 'signature_tensor_name'.
  // If 'is_input' is true the name is looked up among the input tensors;
  // otherwise it is looked up among the output tensors.
  // Returns -1 if the tensor is not found.
  int GetTensorIndexFromSignature(const char* signature_tensor_name,
                                  const char* signature_key,
                                  bool is_input) const {
    // Iterate directly and don't use other methods to avoid extra allocation.
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key != signature_key) continue;
      auto& signature_list = (is_input ? signature.inputs : signature.outputs);
      auto tensor_iter = signature_list.find(signature_tensor_name);
      if (tensor_iter == signature_list.end()) return -1;
      return tensor_iter->second;
    }
    return -1;
  }

  // Applies TFLite default delegates.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Overrides the execution plan. This bounds checks indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler on all subgraphs.
  void SetSubgraphProfiler();

  // Removes delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Sets the list of signature defs in the model.
  void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }

  // Enables preserving intermediates for debugging. Should only be set by
  // InterpreterBuilder before allocating any tensors.
  TfLiteStatus PreserveAllTensorsExperimental();

  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder; should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph's context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // Profiler that has been installed and is owned by this interpreter
  // instance. Useful if client profiler ownership is burdensome.
  std::unique_ptr<Profiler> owned_profiler_;

  // Points to the installed Profiler instance.
  Profiler* installed_profiler_ = nullptr;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to point
  // to this object. However, if this element value is overwritten via calling
  // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset this to
  // nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceMap resources_;

  // A map of resource IDs. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses that indicate whether each initialization
  // subgraph invocation is done or not. Owned by the interpreter and shared by
  // multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Delegates that the TFLite interpreter will apply by default.
  // An empty list means either there are no delegates to be applied by default
  // or the delegates have already been applied and don't need to be applied
  // again.
  std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is basically a wrapper of the Subgraph corresponding to
  // its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as a mapping of name (key) to buffer (value).
  // Data is mapped from the Metadata in the TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_