/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Main abstraction controlling the tflite interpreter.
// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <memory>   // for std::unique_ptr (TfLiteDelegatePtr, subgraphs_)
#include <utility>  // for std::pair (node_and_registration)
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/stderr_reporter.h"

namespace tflite {

// Map statically from a C++ type to a TfLiteType (used below for safe casts).
template <class T>
constexpr TfLiteType typeToTfLiteType() {
  return kTfLiteNoType;
}
template <>
constexpr TfLiteType typeToTfLiteType<int>() {
  return kTfLiteInt32;
}
template <>
constexpr TfLiteType typeToTfLiteType<int16_t>() {
  return kTfLiteInt16;
}
template <>
constexpr TfLiteType typeToTfLiteType<int64_t>() {
  return kTfLiteInt64;
}
template <>
constexpr TfLiteType typeToTfLiteType<float>() {
  return kTfLiteFloat32;
}
template <>
constexpr TfLiteType typeToTfLiteType<unsigned char>() {
  return kTfLiteUInt8;
}
template <>
constexpr TfLiteType typeToTfLiteType<int8_t>() {
  return kTfLiteInt8;
}
template <>
constexpr TfLiteType typeToTfLiteType<bool>() {
  return kTfLiteBool;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
  return kTfLiteComplex64;
}
template <>
constexpr TfLiteType typeToTfLiteType<string>() {
  return kTfLiteString;
}
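
// The mapping can be verified at compile time; a minimal sketch:
//
//   static_assert(typeToTfLiteType<float>() == kTfLiteFloat32, "");
//   static_assert(typeToTfLiteType<int8_t>() == kTfLiteInt8, "");
//   // Types without a specialization fall back to kTfLiteNoType:
//   static_assert(typeToTfLiteType<double>() == kTfLiteNoType, "");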

// An interpreter for a graph of nodes that input and output from tensors.
// Each node of the graph processes a set of input tensors and produces a
// set of output tensors. All input/output tensors are referenced by index.
//
// Usage:
//
// -- Create basic model
// Interpreter foo;
// foo.AddTensors(2);
// foo.SetInputs({0});
// foo.SetTensorParametersReadWrite(0, ...);
// foo.SetTensorParametersReadOnly(1, ...);
// foo.AddNodeWithParameters(...);
//
// -- Resize input array to length 1.
// foo.ResizeInputTensor(0, {1});
// foo.AllocateTensors();
// -- Install array data
// foo.typed_tensor<float>(0)[0] = 3;
// foo.Invoke();
// foo.typed_tensor<float>(0)[0] = 4;
// foo.Invoke();
// -- Resize input array and set data.
// foo.ResizeInputTensor(0, {2});
// foo.AllocateTensors();
// foo.typed_tensor<float>(0)[0] = 4;
// foo.typed_tensor<float>(0)[1] = 8;
// foo.Invoke();
//

class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter

  // Provide a list of tensor indexes that are inputs to the model.
  // Each index is bounds checked, and this modifies the consistent_ flag of
  // the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  // Provide a list of tensor indexes that are outputs of the model.
  // Each index is bounds checked, and this modifies the consistent_ flag of
  // the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  // Provide a list of tensor indexes that are variable tensors.
  // Each index is bounds checked, and this modifies the consistent_ flag of
  // the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);
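
  // A minimal sketch of wiring up model I/O (the indices are hypothetical
  // and assume the corresponding tensors were created with AddTensors()):
  //
  //   TfLiteStatus s = interpreter.SetInputs({0, 1});
  //   if (s == kTfLiteOk) s = interpreter.SetOutputs({2});
  //   if (s == kTfLiteOk) s = interpreter.SetVariables({3});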

  // Ensure the internal node storage memory allocates at least `count`
  // spots for nodes. NOTE: this doesn't actually add operators. This is an
  // efficiency optimization that is subject to change.
  void ReserveNodes(int count);

  // Adds a node with the given parameters and returns the index of the new
  // node in `node_index` (optionally). Interpreter will take ownership of
  // `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  // remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);
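
  // A hedged sketch of adding a custom op node. `MyPrepare` and `MyInvoke`
  // are hypothetical callbacks; a real registration typically comes from an
  // op resolver such as BuiltinOpResolver:
  //
  //   TfLiteRegistration reg = {/*init=*/nullptr, /*free=*/nullptr,
  //                             MyPrepare, MyInvoke};
  //   int node_index = -1;
  //   interpreter.AddNodeWithParameters(
  //       /*inputs=*/{0}, /*outputs=*/{1},
  //       /*init_data=*/nullptr, /*init_data_size=*/0,
  //       /*builtin_data=*/nullptr, &reg, &node_index);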

  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  // Set parameters of tensor `tensor_index` to be read-only. This variant
  // assumes an external buffer of size `bytes` has been allocated. The
  // lifetime of `buffer` must be at least as long as the lifetime of the
  // Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  // Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  // Set parameters of tensor `tensor_index` to be read-write. This variant
  // lets the interpreter allocate and manage the tensor's memory; the buffer
  // is sized and (re)allocated by AllocateTensors().
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  // Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false) {
    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                        dims.data(), quantization, is_variable);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false);
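
  // A minimal sketch combining the two setters for unquantized float tensors.
  // `kWeights` is a hypothetical caller-owned float[16] buffer, which must
  // outlive the interpreter:
  //
  //   interpreter.SetTensorParametersReadWrite(
  //       0, kTfLiteFloat32, "input", {1, 4}, TfLiteQuantization());
  //   interpreter.SetTensorParametersReadOnly(
  //       1, kTfLiteFloat32, "weights", {4, 4}, TfLiteQuantization(),
  //       reinterpret_cast<const char*>(kWeights), 16 * sizeof(float));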

  // Functions to access tensor data

  // Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  // Return the name of a given input. The given index must be at least 0 and
  // less than inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  // Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  // Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  // Return the name of a given output. The given index must be at least 0 and
  // less than outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  // Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  // Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  // WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

  // WARNING: Experimental interface, subject to change
  // Overrides execution plan. This bounds checks indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
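
  // A hedged sketch of overriding the plan, e.g. to run only the first half
  // of the graph (node indices are taken from the current plan):
  //
  //   std::vector<int> plan = interpreter.execution_plan();
  //   plan.resize(plan.size() / 2);
  //   if (interpreter.SetExecutionPlan(plan) != kTfLiteOk) { /* handle */ }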

  // Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  // Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  // Get a pointer to an operation and registration data structure if in
  // bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  // Perform a checked cast to the appropriate tensor type (mutable pointer
  // version). Returns nullptr if the index is out of bounds or the tensor's
  // type does not match T.
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  // Perform a checked cast to the appropriate tensor type (immutable pointer
  // version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  // Return a mutable pointer into the data of a given input tensor. The given
  // index must be at least 0 and less than inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  // Return an immutable pointer into the data of a given input tensor. The
  // given index must be at least 0 and less than inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  // Return a mutable pointer into the data of a given output tensor. The given
  // index must be at least 0 and less than outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  // Return an immutable pointer into the data of a given output tensor. The
  // given index must be at least 0 and less than outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }
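
  // A minimal sketch of feeding inputs and reading outputs after
  // AllocateTensors() has succeeded (assumes the first input and output
  // tensors are float):
  //
  //   float* in = interpreter.typed_input_tensor<float>(0);
  //   if (in != nullptr) in[0] = 1.0f;
  //   if (interpreter.Invoke() == kTfLiteOk) {
  //     const float* out = interpreter.typed_output_tensor<float>(0);
  //     // ... consume out ...
  //   }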

  // Change the dimensionality of a given tensor. Note, this is only acceptable
  // for tensor indices that are inputs.
  // Returns status of failure or success.
  // TODO(aselle): Consider implementing ArraySlice equivalent to make this
  // more adept at accepting data without an extra copy. Use absl::ArraySlice
  // if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);
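
  // A hedged sketch of resizing the first input (the image-like shape is
  // hypothetical); any resize must be followed by AllocateTensors() before
  // the next Invoke():
  //
  //   int input_index = interpreter.inputs()[0];
  //   interpreter.ResizeInputTensor(input_index, {8, 224, 224, 3});
  //   interpreter.AllocateTensors();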

  // Update allocations for all tensors. This will redim dependent tensors
  // using the input tensor dimensionality as given. This is relatively
  // expensive. If you know that your sizes are not changing, you need not
  // call this. Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  // Invoke the interpreter (run the whole graph in dependency order).
  //
  // NOTE: It is possible that the interpreter is not in a ready state
  // to evaluate (e.g. if ResizeInputTensor() has been performed without a
  // subsequent AllocateTensors()).
  // Returns status of success or failure.
  TfLiteStatus Invoke();

  // Enable or disable the NN API (true to enable).
  void UseNNAPI(bool enable);

  // Set the number of threads available to the interpreter.
  void SetNumThreads(int num_threads);

  // Allow float16 precision for FP32 calculation when possible.
  // Default: not allowed.
  // WARNING: This is an experimental API and subject to change.
  void SetAllowFp16PrecisionForFp32(bool allow);

  // Get the half precision flag.
  // WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  // Sets the cancellation function pointer in order to cancel a request in the
  // middle of a call to Invoke(). The interpreter queries this function during
  // inference, between op invocations; when it returns true, the interpreter
  // will abort execution and return `kTfLiteError`. The `data` parameter
  // contains any data used by the cancellation function, and if non-null,
  // remains owned by the caller.
  // WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
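
  // A minimal sketch of cooperative cancellation via an atomic flag (the
  // flag is hypothetical, needs <atomic>, and must outlive any Invoke()
  // call that may consult it):
  //
  //   static std::atomic<bool> cancelled{false};
  //   interpreter.SetCancellationFunction(
  //       &cancelled,
  //       [](void* data) {
  //         return static_cast<std::atomic<bool>*>(data)->load();
  //       });
  //   // Another thread may later set `cancelled = true;` to abort Invoke().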

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  // Allow a delegate to look at the graph and modify it to handle parts of
  // the graph itself. After this is called, the graph may contain new nodes
  // that replace one or more of the original nodes.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
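
  // A hedged sketch of applying a delegate (`CreateMyDelegate` is a
  // hypothetical factory; concrete delegates such as the GPU delegate ship
  // with their own creation/deletion functions):
  //
  //   TfLiteDelegate* delegate = CreateMyDelegate();
  //   if (interpreter.ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
  //     // Fall back to CPU execution.
  //   }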

  // Ensure the data in `tensor.data` is readable. If a delegate is in use,
  // this may require copying the data from the delegate's buffer to raw
  // memory.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  // Set the delegate buffer handle to a tensor. It can be called in the
  // following cases:
  // 1. Set the buffer handle to a tensor that's not being written by a
  //    delegate. For example, feeding an OpenGL texture as the input of the
  //    inference graph.
  // 2. Set the buffer handle to a tensor that uses the same delegate.
  //    For example, set an OpenGL texture as the output of inference, while
  //    the node which produces the output is an OpenGL delegate node.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  // Get the delegate buffer handle, and the delegate which can process the
  // buffer handle.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);
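
  // A minimal sketch of retrieving a tensor's buffer handle; the result is
  // only meaningful for tensors previously bound via SetBufferHandle():
  //
  //   TfLiteBufferHandle handle = kTfLiteNullBufferHandle;
  //   TfLiteDelegate* owning_delegate = nullptr;
  //   interpreter.GetBufferHandle(interpreter.outputs()[0], &handle,
  //                               &owning_delegate);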

  void SetProfiler(profiling::Profiler* profiler);

  profiling::Profiler* GetProfiler();

  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  // The capacity headroom of the `tensors_` vector before calling ops'
  // `prepare` and `invoke` functions. In these functions, it is guaranteed
  // that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  // invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  // Set if buffer handle output is allowed.
  //
  // When using hardware delegation, Interpreter will make the data of output
  // tensors available in `tensor->data` by default. If the application can
  // consume the buffer handle directly (e.g. reading output from an OpenGL
  // texture), it can set this flag to true, so Interpreter won't copy the
  // data from the buffer handle to CPU memory.
  // WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  // Reset all variable tensors to the default value.
  // If a variable tensor doesn't have a buffer, reset it to zero.
  // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
  // to the value of the buffer.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  // Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  // Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  // entries. The value pointed to by `first_new_subgraph_index` will be set to
  // the index of the first new subgraph if `first_new_subgraph_index` is
  // non-null.
  // WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  // Return the number of subgraphs in the model.
  // WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  // Get a pointer to a subgraph if in bounds.
  // WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size())
      return nullptr;
    return &*subgraphs_[subgraph_index];
  }

  // WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has at least one
                                 // entry.
  }

  // WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has at least one
                                 // entry.
  }

 private:
  friend class InterpreterBuilder;
  friend class InterpreterTest;

  // Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Variant of the public ModifyGraphWithDelegate method that additionally
  // assumes ownership of the provided delegate.
  // WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) {
    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.push_back(std::move(delegate));
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph's context.
  TfLiteContext* context_;

  // The error reporter delegate to which tflite will forward errors.
  ErrorReporter* error_reporter_;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<TfLiteDelegatePtr> owned_delegates_;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_