/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Main abstraction controlling the tflite interpreter.
// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

/// An interpreter for a graph of nodes that input and output from tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create basic model
/// Interpreter foo;
/// foo.AddTensors(2);
/// foo.SetInputs({0});
/// foo.SetOutputs({1});
/// foo.SetTensorParametersReadWrite(0, ...);
/// foo.SetTensorParametersReadOnly(1, ...);
/// foo.AddNodeWithParameters(...);
/// // Resize input array to 1 length.
/// foo.ResizeInputTensor(0, {1});
/// foo.AllocateTensors();
/// // Install array data
/// foo.typed_tensor<float>(0)[0] = 3;
/// foo.Invoke();
/// foo.typed_tensor<float>(0)[0] = 4;
/// foo.Invoke();
/// // Resize input array and set data.
/// foo.ResizeInputTensor(0, {2});
/// foo.AllocateTensors();
/// foo.typed_tensor<float>(0)[0] = 4;
/// foo.typed_tensor<float>(0)[1] = 8;
/// foo.Invoke();
/// </code></pre>
///
class Interpreter {
 public:
  /// Instantiate an interpreter. All errors associated with reading and
  /// processing this model will be forwarded to the error_reporter object.
  //
  /// Note, if error_reporter is nullptr, then a default StderrReporter is
  /// used. Ownership of 'error_reporter' remains with the caller.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;
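
  // Note: application code usually obtains a ready-to-run Interpreter via
  // InterpreterBuilder rather than assembling one by hand with the functions
  // below. A minimal sketch, assuming a flatbuffer model on disk
  // ("model.tflite" is a placeholder) and the builtin op resolver from
  // tensorflow/lite/kernels/register.h:
  //
  //   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  //   tflite::ops::builtin::BuiltinOpResolver resolver;
  //   std::unique_ptr<tflite::Interpreter> interpreter;
  //   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  //   interpreter->AllocateTensors();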
  // Functions to build interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs to the model.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds-checked and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Ensure the internal node storage memory allocates at least `count`
  /// spots for nodes. NOTE: this doesn't actually add operators. This is an
  /// efficiency optimization that is subject to change.
  void ReserveNodes(int count);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). The Interpreter takes ownership of
  /// `builtin_data` and destroys it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set description of inputs/outputs/data/fptrs for node `node_index`.
  /// This variant assumes an external buffer of size `bytes` has been
  /// allocated. The lifetime of `buffer` must be at least as long as the
  /// Interpreter's.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set description of inputs/outputs/data/fptrs for node `node_index`.
  /// This variant takes no external buffer; the interpreter allocates and
  /// manages the tensor's memory itself.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false) {
    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
                                        dims.data(), quantization, is_variable);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false);
#endif  // DOXYGEN_SKIP
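
  // A minimal sketch of building a one-op graph with the functions above.
  // `my_add_registration` is a hypothetical TfLiteRegistration for a custom
  // binary-add op that needs no builtin_data; error checking is omitted:
  //
  //   Interpreter interp;
  //   interp.AddTensors(3);  // tensors 0 and 1 are inputs, 2 is the output
  //   interp.SetInputs({0, 1});
  //   interp.SetOutputs({2});
  //   TfLiteQuantization no_quant = {};  // kTfLiteNoQuantization
  //   interp.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {1},
  //                                       no_quant);
  //   interp.SetTensorParametersReadWrite(1, kTfLiteFloat32, "b", {1},
  //                                       no_quant);
  //   interp.SetTensorParametersReadWrite(2, kTfLiteFloat32, "sum", {1},
  //                                       no_quant);
  //   interp.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr,
  //                                &my_add_registration);
  //   interp.AllocateTensors();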
  // Functions to access tensor data

  /// Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

#ifndef DOXYGEN_SKIP
  /// WARNING: Experimental interface, subject to change
  /// Overrides execution plan. This bounds-checks the indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
#endif  // DOXYGEN_SKIP

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }
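
  // A minimal sketch of inspecting model I/O with the accessors above
  // (assumes `interpreter` is a fully built Interpreter):
  //
  //   for (int i = 0; i < static_cast<int>(interpreter.inputs().size()); ++i) {
  //     const TfLiteTensor* t = interpreter.tensor(interpreter.inputs()[i]);
  //     printf("input %d: %s (type %d)\n", i, interpreter.GetInputName(i),
  //            t->type);
  //   }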
  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Change the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns status of failure or success.
  /// TODO(aselle): Consider implementing ArraySlice equivalent to make this
  /// more adept at accepting data without an extra copy. Use absl::ArraySlice
  /// if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);
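
  // A minimal sketch of feeding inputs and reading outputs with the typed
  // accessors above (assumes a built float model with one input and one
  // output; AllocateTensors() and Invoke() are declared below):
  //
  //   interpreter.ResizeInputTensor(interpreter.inputs()[0],
  //                                 {1, 224, 224, 3});
  //   interpreter.AllocateTensors();
  //   float* in = interpreter.typed_input_tensor<float>(0);
  //   // ... fill `in` ...
  //   interpreter.Invoke();
  //   const float* out = interpreter.typed_output_tensor<float>(0);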
  // This releases memory held by non-persistent tensors. It does NOT
  // re-perform memory planning. AllocateTensors needs to be called before
  // the next invocation.
  /// WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. If you know that your sizes are not changing, you need not
  /// call this. Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if ResizeInputTensor() has been called without a
  /// subsequent AllocateTensors()).
  /// Returns status of success or failure.
  TfLiteStatus Invoke();

  /// Enable or disable the NN API (true to enable).
  void UseNNAPI(bool enable);

  /// Set the number of threads available to the interpreter.
  void SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph themselves. After this is called, the graph may
  /// contain new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate. Be sure to construct the unique_ptr
  /// with a suitable destruction function.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate);

  /// Ensure the data in `tensor.data` is readable. If a delegate is used, it
  /// may be necessary to copy the data from the delegate buffer to raw
  /// memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }
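
  // A minimal sketch of cooperative cancellation with
  // SetCancellationFunction above. The std::atomic flag and its wiring are
  // illustrative, not part of this API:
  //
  //   static std::atomic<bool> cancelled{false};
  //   interpreter.SetCancellationFunction(&cancelled, [](void* data) {
  //     return static_cast<std::atomic<bool>*>(data)->load();
  //   });
  //   // From another thread: `cancelled = true;` makes Invoke() return
  //   // kTfLiteError at the next op boundary.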
  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. In these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  //
  /// When using hardware delegation, the Interpreter will make the data of
  /// output tensors available in `tensor->data` by default. If the
  /// application can consume the buffer handle directly (e.g. reading output
  /// from an OpenGL texture), it can set this flag to true, so the
  /// Interpreter won't copy the data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter doesn't
  // take ownership of this external context 'ctx', and the context should
  // outlive the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);
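
  // A minimal sketch of the buffer-handle workflow above, assuming a
  // hypothetical `gl_delegate` already installed via
  // ModifyGraphWithDelegate and a delegate-specific `texture_handle`:
  //
  //   int out = interpreter.outputs()[0];
  //   interpreter.SetAllowBufferHandleOutput(true);  // skip copy to CPU
  //   interpreter.SetBufferHandle(out, texture_handle, gl_delegate);
  //   interpreter.Invoke();
  //   // To read the data on the CPU after all, force a copy back:
  //   interpreter.EnsureTensorDataIsReadable(out);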
#ifndef DOXYGEN_SKIP
  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set
  /// to the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  /// WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size())
      return nullptr;
    return &*subgraphs_[subgraph_index];
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }
#endif  // DOXYGEN_SKIP

 private:
  friend class InterpreterBuilder;
  friend class InterpreterTest;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<TfLiteDelegatePtr> owned_delegates_;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to
  // point to this object. However, if this element value is overwritten via
  // calling 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we reset
  // this to nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceMap resources_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_