1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 /// \file 16 /// Main abstraction controlling the tflite interpreter. 17 /// See context.h for the API for defining operations (TfLiteRegistration). 18 #ifndef TENSORFLOW_LITE_INTERPRETER_H_ 19 #define TENSORFLOW_LITE_INTERPRETER_H_ 20 21 #include <stddef.h> 22 #include <stdint.h> 23 24 #include <complex> 25 #include <cstdio> 26 #include <cstdlib> 27 #include <functional> 28 #include <map> 29 #include <memory> 30 #include <string> 31 #include <utility> 32 #include <vector> 33 34 #include "tensorflow/lite/allocation.h" 35 #include "tensorflow/lite/c/common.h" // IWYU pragma: export 36 #include "tensorflow/lite/core/api/error_reporter.h" 37 #include "tensorflow/lite/core/api/profiler.h" 38 #include "tensorflow/lite/core/subgraph.h" 39 #include "tensorflow/lite/experimental/resource/resource_base.h" 40 #include "tensorflow/lite/external_cpu_backend_context.h" 41 #include "tensorflow/lite/memory_planner.h" 42 #include "tensorflow/lite/portable_type_to_tflitetype.h" 43 #include "tensorflow/lite/stderr_reporter.h" 44 #include "tensorflow/lite/string_type.h" 45 #include "tensorflow/lite/type_to_tflitetype.h" 46 47 namespace tflite { 48 49 class InterpreterTest; // Class for friend declarations. 50 51 namespace delegates { 52 class InterpreterUtils; // Class for friend declarations. 53 54 namespace test_utils { 55 class TestDelegate; // Class for friend declarations. 56 } // namespace test_utils 57 } // namespace delegates 58 59 /// An interpreter for a graph of nodes that input and output from tensors. 60 /// Each node of the graph processes a set of input tensors and produces a 61 /// set of output Tensors. All inputs/output tensors are referenced by index. 62 /// 63 /// Usage: 64 /// 65 /// <pre><code> 66 /// // Create model from file. Note that the model instance must outlive the 67 /// // interpreter instance. 68 /// auto model = tflite::FlatBufferModel::BuildFromFile(...); 69 /// if (model == nullptr) { 70 /// // Return error. 71 /// } 72 /// // Create an Interpreter with an InterpreterBuilder. 73 /// std::unique_ptr<tflite::Interpreter> interpreter; 74 /// tflite::ops::builtin::BuiltinOpResolver resolver; 75 /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { 76 /// // Return failure. 77 /// } 78 /// interpreter->AllocateTensors(); 79 /// 80 /// auto input = interpreter->typed_tensor<float>(0); 81 /// for (int i = 0; i < input_size; i++) { 82 /// input[i] = ...; 83 // } 84 /// interpreter.Invoke(); 85 /// </code></pre> 86 /// 87 /// Note: for nearly all practical use cases, one should not directly construct 88 /// an Interpreter object, but rather use the InterpreterBuilder. 89 90 class Interpreter { 91 public: 92 // Instantiate an interpreter. All errors associated with reading and 93 // processing this model will be forwarded to the error_reporter object. 94 // 95 // Note, if error_reporter is nullptr, then a default StderrReporter is 96 // used. Ownership of 'error_reporter' remains with the caller. 97 // WARNING: Use of this constructor outside of an InterpreterBuilder is not 98 // recommended. 99 explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); 100 101 ~Interpreter(); 102 103 // Interpreters are not copyable as they have non-trivial memory semantics. 104 Interpreter(const Interpreter&) = delete; 105 Interpreter& operator=(const Interpreter&) = delete; 106 107 // Functions to build interpreter 108 #ifndef DOXYGEN_SKIP 109 /// Provide a list of tensor indexes that are inputs to the model. 110 /// Each index is bound check and this modifies the consistent_ flag of the 111 /// interpreter. 112 TfLiteStatus SetInputs(std::vector<int> inputs); 113 114 /// Provide a list of tensor indexes that are outputs to the model 115 /// Each index is bound check and this modifies the consistent_ flag of the 116 /// interpreter. 117 TfLiteStatus SetOutputs(std::vector<int> outputs); 118 119 /// Provide a list of tensor indexes that are variable tensors. 120 /// Each index is bound check and this modifies the consistent_ flag of the 121 /// interpreter. 122 TfLiteStatus SetVariables(std::vector<int> variables); 123 124 /// Ensure the internal node storage memory allocates at least `count` 125 /// spots for node. NOTE, this doesn't actually add operators. This is an 126 /// efficiency optimization that is subject to change. 127 void ReserveNodes(int count); 128 129 /// Adds a node with the given parameters and returns the index of the new 130 /// node in `node_index` (optionally). Interpreter will take ownership of 131 /// `builtin_data` and destroy it with `free`. Ownership of 'init_data' 132 /// remains with the caller. 133 TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs, 134 const std::vector<int>& outputs, 135 const char* init_data, 136 size_t init_data_size, void* builtin_data, 137 const TfLiteRegistration* registration, 138 int* node_index = nullptr); 139 140 /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. 141 /// The value pointed to by `first_new_tensor_index` will be set to the 142 /// index of the first new tensor if `first_new_tensor_index` is non-null. 143 TfLiteStatus AddTensors(int tensors_to_add, 144 int* first_new_tensor_index = nullptr); 145 146 /// Set description of inputs/outputs/data/fptrs for node `node_index`. 147 /// This variant assumes an external buffer has been allocated of size 148 /// bytes. The lifetime of buffer must be ensured to be greater or equal 149 /// to Interpreter. 150 TfLiteStatus SetTensorParametersReadOnly( 151 int tensor_index, TfLiteType type, const char* name, 152 const std::vector<int>& dims, TfLiteQuantization quantization, 153 const char* buffer, size_t bytes, const Allocation* allocation = nullptr); 154 155 /// Legacy. Deprecated in favor of above. 156 inline TfLiteStatus SetTensorParametersReadOnly( 157 int tensor_index, TfLiteType type, const char* name, 158 const std::vector<int>& dims, TfLiteQuantizationParams quantization, 159 const char* buffer, size_t bytes, 160 const Allocation* allocation = nullptr) { 161 return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), 162 dims.data(), quantization, buffer, bytes, 163 allocation); 164 } 165 166 TfLiteStatus SetTensorParametersReadOnly( 167 int tensor_index, TfLiteType type, const char* name, const size_t rank, 168 const int* dims, TfLiteQuantizationParams quantization, 169 const char* buffer, size_t bytes, const Allocation* allocation = nullptr); 170 171 /// Set description of inputs/outputs/data/fptrs for node `node_index`. 172 /// This variant assumes an external buffer has been allocated of size 173 /// bytes. The lifetime of buffer must be ensured to be greater or equal 174 /// to Interpreter. 175 TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type, 176 const char* name, 177 const std::vector<int>& dims, 178 TfLiteQuantization quantization, 179 bool is_variable = false); 180 181 /// Legacy. Deprecated in favor of above. 182 inline TfLiteStatus SetTensorParametersReadWrite( 183 int tensor_index, TfLiteType type, const char* name, 184 const std::vector<int>& dims, TfLiteQuantizationParams quantization, 185 bool is_variable = false, 186 const std::vector<int>* dims_signature = nullptr) { 187 size_t rank_dims_signature = 0; 188 const int* dims_signature_pointer = nullptr; 189 if (dims_signature) { 190 rank_dims_signature = dims_signature->size(); 191 dims_signature_pointer = dims_signature->data(); 192 } 193 return SetTensorParametersReadWrite( 194 tensor_index, type, name, dims.size(), dims.data(), quantization, 195 is_variable, rank_dims_signature, dims_signature_pointer); 196 } 197 TfLiteStatus SetTensorParametersReadWrite( 198 int tensor_index, TfLiteType type, const char* name, const size_t rank, 199 const int* dims, TfLiteQuantizationParams quantization, 200 bool is_variable = false, const size_t rank_dims_signature = 0, 201 const int* dims_signature = nullptr); 202 #endif // DOXYGEN_SKIP 203 // Functions to access tensor data 204 205 /// Read only access to list of inputs. inputs()206 const std::vector<int>& inputs() const { return primary_subgraph().inputs(); } 207 208 /// Return the name of a given input. The given index must be between 0 and 209 /// inputs().size(). GetInputName(int index)210 const char* GetInputName(int index) const { 211 return context_->tensors[inputs()[index]].name; 212 } 213 214 /// Read only access to list of outputs. outputs()215 const std::vector<int>& outputs() const { 216 return primary_subgraph().outputs(); 217 } 218 219 /// Read only access to list of variable tensors. variables()220 const std::vector<int>& variables() const { 221 return primary_subgraph().variables(); 222 } 223 224 /// Return the name of a given output. The given index must be between 0 and 225 /// outputs().size(). GetOutputName(int index)226 const char* GetOutputName(int index) const { 227 return context_->tensors[outputs()[index]].name; 228 } 229 230 /// Return the number of tensors in the model. tensors_size()231 size_t tensors_size() const { return context_->tensors_size; } 232 233 /// Return the number of ops in the model. nodes_size()234 size_t nodes_size() const { return primary_subgraph().nodes_size(); } 235 236 /// WARNING: Experimental interface, subject to change execution_plan()237 const std::vector<int>& execution_plan() const { 238 return primary_subgraph().execution_plan(); 239 } 240 241 #ifndef DOXYGEN_ 242 /// WARNING: Experimental interface, subject to change 243 /// Overrides execution plan. This bounds checks indices sent in. 244 TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan); 245 #endif // DOXYGEN_SKIP 246 247 /// Get a mutable tensor data structure. 248 // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this 249 // read/write access to structure tensor(int tensor_index)250 TfLiteTensor* tensor(int tensor_index) { 251 return primary_subgraph().tensor(tensor_index); 252 } 253 254 /// Get an immutable tensor data structure. tensor(int tensor_index)255 const TfLiteTensor* tensor(int tensor_index) const { 256 return primary_subgraph().tensor(tensor_index); 257 } 258 259 /// Get a pointer to an operation and registration data structure if in 260 /// bounds. node_and_registration(int node_index)261 const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration( 262 int node_index) const { 263 return primary_subgraph().node_and_registration(node_index); 264 } 265 266 /// Perform a checked cast to the appropriate tensor type (mutable pointer 267 /// version). 268 template <class T> typed_tensor(int tensor_index)269 T* typed_tensor(int tensor_index) { 270 if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { 271 if (tensor_ptr->type == typeToTfLiteType<T>()) { 272 return reinterpret_cast<T*>(tensor_ptr->data.raw); 273 } 274 } 275 return nullptr; 276 } 277 278 /// Perform a checked cast to the appropriate tensor type (immutable pointer 279 /// version). 280 template <class T> typed_tensor(int tensor_index)281 const T* typed_tensor(int tensor_index) const { 282 if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { 283 if (tensor_ptr->type == typeToTfLiteType<T>()) { 284 return reinterpret_cast<const T*>(tensor_ptr->data.raw); 285 } 286 } 287 return nullptr; 288 } 289 290 /// WARNING: Experimental interface, subject to change 291 /// Returns list of all names of different method signatures defined 292 /// in the model. 293 /// Note, pointers returned have lifetime same as the Interpreter object. signature_def_names()294 std::vector<const std::string*> signature_def_names() const { 295 std::vector<const std::string*> method_names; 296 method_names.reserve(signature_defs_.size()); 297 for (const auto& sig_def : signature_defs_) { 298 method_names.emplace_back(&sig_def.method_name); 299 } 300 return method_names; 301 } 302 303 /// WARNING: Experimental interface, subject to change 304 /// Returns the mapping of inputs to tensor index in the signature 305 /// specified through 'method_name'. 306 /// If invalid name passed, an empty list will be returned. signature_inputs(const char * method_name)307 const std::map<std::string, uint32_t>& signature_inputs( 308 const char* method_name) const { 309 for (const auto& sig_def : signature_defs_) { 310 if (sig_def.method_name == method_name) return sig_def.inputs; 311 } 312 static const std::map<std::string, uint32_t>* default_empty_list = 313 new std::map<std::string, uint32_t>(); 314 return *default_empty_list; 315 } 316 317 /// WARNING: Experimental interface, subject to change 318 /// Returns the mapping of outputs to tensor index in the signature 319 /// specified through 'method_name'. 320 /// If invalid name passed, an empty list will be returned. signature_outputs(const char * method_name)321 const std::map<std::string, uint32_t>& signature_outputs( 322 const char* method_name) const { 323 for (const auto& sig_def : signature_defs_) { 324 if (sig_def.method_name == method_name) return sig_def.outputs; 325 } 326 static const std::map<std::string, uint32_t>* default_empty_list = 327 new std::map<std::string, uint32_t>(); 328 return *default_empty_list; 329 } 330 331 /// WARNING: Experimental interface, subject to change 332 /// Returns the input tensor identified by 'signature_input_name' in the 333 /// signature identified by 'signature_method_name'. 334 /// Returns nullptr if not found. input_tensor_by_signature_name(const char * signature_input_name,const char * signature_method_name)335 TfLiteTensor* input_tensor_by_signature_name( 336 const char* signature_input_name, const char* signature_method_name) { 337 const int tensor_index = GetTensorIndexFromSignatureDefName( 338 signature_input_name, signature_method_name, /*is_input=*/true); 339 return tensor_index == -1 ? nullptr : tensor(tensor_index); 340 } 341 342 /// WARNING: Experimental interface, subject to change 343 /// Returns the output tensor identified by 'signature_output_name' in the 344 /// signature identified by 'signature_method_name'. 345 /// Returns nullptr if not found. output_tensor_by_signature_name(const char * signature_output_name,const char * signature_method_name)346 const TfLiteTensor* output_tensor_by_signature_name( 347 const char* signature_output_name, 348 const char* signature_method_name) const { 349 const int tensor_index = GetTensorIndexFromSignatureDefName( 350 signature_output_name, signature_method_name, /*is_input=*/false); 351 return tensor_index == -1 ? nullptr : tensor(tensor_index); 352 } 353 354 /// Return a mutable pointer to the given input tensor. The given index must 355 /// be between 0 and inputs().size(). input_tensor(size_t index)356 TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); } 357 358 /// Return an immutable pointerto the given input tensor. The given index must 359 /// be between 0 and inputs().size(). input_tensor(size_t index)360 const TfLiteTensor* input_tensor(size_t index) const { 361 return tensor(inputs()[index]); 362 } 363 364 /// Return a mutable pointer into the data of a given input tensor. The given 365 /// index must be between 0 and inputs().size(). 366 template <class T> typed_input_tensor(int index)367 T* typed_input_tensor(int index) { 368 return typed_tensor<T>(inputs()[index]); 369 } 370 371 /// Return an immutable pointer into the data of a given input tensor. The 372 /// given index must be between 0 and inputs().size(). 373 template <class T> typed_input_tensor(int index)374 const T* typed_input_tensor(int index) const { 375 return typed_tensor<T>(inputs()[index]); 376 } 377 378 /// Return a mutable pointer to the given output tensor. The given index must 379 /// be between 0 and outputs().size(). output_tensor(size_t index)380 TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); } 381 382 /// Return an immutable pointer to the given output tensor. The given index 383 /// must be between 0 and outputs().size(). output_tensor(size_t index)384 const TfLiteTensor* output_tensor(size_t index) const { 385 return tensor(outputs()[index]); 386 } 387 388 /// Return a mutable pointer into the data of a given output tensor. The given 389 /// index must be between 0 and outputs().size(). 390 template <class T> typed_output_tensor(int index)391 T* typed_output_tensor(int index) { 392 return typed_tensor<T>(outputs()[index]); 393 } 394 395 /// Return an immutable pointer into the data of a given output tensor. The 396 /// given index must be between 0 and outputs().size(). 397 template <class T> typed_output_tensor(int index)398 const T* typed_output_tensor(int index) const { 399 return typed_tensor<T>(outputs()[index]); 400 } 401 402 /// Change the dimensionality of a given tensor. Note, this is only acceptable 403 /// for tensor indices that are inputs or variables. 404 /// Returns status of failure or success. Note that this doesn't actually 405 /// resize any existing buffers. A call to AllocateTensors() is required to 406 /// change the tensor input buffer. 407 TfLiteStatus ResizeInputTensor(int tensor_index, 408 const std::vector<int>& dims); 409 410 // WARNING: Experimental interface, subject to change 411 // Change the dimensionality of a given tensor. This is only acceptable for 412 // tensor indices that are inputs or variables. Only unknown dimensions can be 413 // resized with this function. Unknown dimensions are indicated as `-1` in the 414 // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure 415 // or success. Note that this doesn't actually resize any existing buffers. 416 /// A call to AllocateTensors() is required to change the tensor input buffer. 417 TfLiteStatus ResizeInputTensorStrict(int tensor_index, 418 const std::vector<int>& dims); 419 420 // This releases memory held by non-persistent tensors. It does NOT re-perform 421 // memory planning. 422 // AllocateTensors needs to be called before next invocation. 423 /// WARNING: Experimental interface, subject to change 424 TfLiteStatus ReleaseNonPersistentMemory(); 425 426 // Update allocations for all tensors. This will redim dependent tensors 427 // using the input tensor dimensionality as given. This is relatively 428 // expensive. This *must be* called after the interpreter has been created 429 // and before running inference (and accessing tensor buffers), and *must be* 430 // called again if (and only if) an input tensor is resized. Returns status of 431 // success or failure. 432 TfLiteStatus AllocateTensors(); 433 434 /// Invoke the interpreter (run the whole graph in dependency order). 435 /// 436 /// NOTE: It is possible that the interpreter is not in a ready state 437 /// to evaluate (i.e. if a ResizeTensor() has been performed without an 438 /// AllocateTensors(). 439 /// Returns status of success or failure. 440 TfLiteStatus Invoke(); 441 442 /// Set the number of threads available to the interpreter. 443 /// 444 /// NOTE: num_threads should be >= -1. 445 /// User may pass -1 to let the TFLite interpreter set the no of threads 446 /// available to itself. 447 TfLiteStatus SetNumThreads(int num_threads); 448 449 /// Allow float16 precision for FP32 calculation when possible. 450 /// Default: not allow. 451 /// 452 /// WARNING: This API is deprecated: prefer controlling this via delegate 453 /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16' or 454 /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. 455 /// This method will be removed in a future release. 456 void SetAllowFp16PrecisionForFp32(bool allow); 457 458 /// Get the half precision flag. 459 /// WARNING: This is an experimental API and subject to change. GetAllowFp16PrecisionForFp32()460 bool GetAllowFp16PrecisionForFp32() const { 461 return context_->allow_fp32_relax_to_fp16; 462 } 463 464 /// Sets the cancellation function pointer in order to cancel a request in the 465 /// middle of a call to Invoke(). The interpreter queries this function during 466 /// inference, between op invocations; when it returns true, the interpreter 467 /// will abort execution and return `kTfLiteError`. The `data` parameter 468 /// contains any data used by the cancellation function, and if non-null, 469 /// remains owned by the caller. 470 /// WARNING: This is an experimental API and subject to change. 471 void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*)); 472 473 /// Allow a delegate to look at the graph and modify the graph to handle 474 /// parts of the graph themselves. After this is called, the graph may 475 /// contain new nodes that replace 1 more nodes. 476 /// 'delegate' must outlive the interpreter. 477 /// Returns one of the following four status codes: 478 /// 1. kTfLiteOk: Success. 479 /// 2. kTfLiteDelegateError: Delegation failed due to an error in the 480 /// delegate, or the delegate parameter was null. The Interpreter has been 481 /// restored to its pre-delegation state. 482 /// NOTE: This undoes all delegates previously applied to the Interpreter. 483 /// 3. kTfLiteApplicationError : Delegation failed to be applied due to the 484 /// incompatibility with the TfLite runtime, e.g., the model graph is already 485 /// immutable when applying the delegate. However, the interpreter could still 486 /// be invoked. 487 /// 4. kTfLiteError: Unexpected/runtime failure. 488 /// WARNING: This is an experimental API and subject to change. 489 TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); 490 491 // Owning handle to a TfLiteDelegate instance. 492 using TfLiteDelegatePtr = 493 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; 494 495 /// Same as ModifyGraphWithDelegate except this interpreter takes 496 /// ownership of the provided delegate. 497 /// WARNING: This is an experimental API and subject to change. 498 template <typename Delegate, typename Deleter> ModifyGraphWithDelegate(std::unique_ptr<Delegate,Deleter> delegate)499 inline TfLiteStatus ModifyGraphWithDelegate( 500 std::unique_ptr<Delegate, Deleter> delegate) { 501 Deleter deleter = std::move(delegate.get_deleter()); 502 503 // Note that we retain ownership of the delegate even if graph modification 504 // fails, as delegate use will be in an indeterminate state at that point. 505 owned_delegates_.emplace_back( 506 delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) { 507 deleter( 508 static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>( 509 delegate_to_delete)); 510 }); 511 return ModifyGraphWithDelegate(owned_delegates_.back().get()); 512 } 513 514 /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has no 515 /// virtual destructor. The default deleter of the unique_ptr does not know 516 /// how to delete C++ objects deriving from TfLiteDelegate. 517 TfLiteStatus ModifyGraphWithDelegate( 518 std::unique_ptr<TfLiteDelegate> delegate) = delete; 519 520 /// Ensure the data in `tensor.data` is readable. In case delegate is used, 521 /// it might require to copy the data from delegate buffer to raw memory. 522 /// WARNING: This is an experimental API and subject to change. EnsureTensorDataIsReadable(int tensor_index)523 TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) { 524 return primary_subgraph().EnsureTensorDataIsReadable(tensor_index); 525 } 526 527 /// Set the delegate buffer handle to a tensor. It can be called in the 528 /// following cases: 529 /// 1. Set the buffer handle to a tensor that's not being written by a 530 /// delegate. For example, feeding an OpenGL texture as the input of the 531 /// inference graph. 532 /// 2. Set the buffer handle to a tensor that uses the same delegate. 533 /// For example, set an OpenGL texture as the output of inference, while 534 /// the node which produces output is an OpenGL delegate node. 535 /// WARNING: This is an experimental API and subject to change. 536 TfLiteStatus SetBufferHandle(int tensor_index, 537 TfLiteBufferHandle buffer_handle, 538 TfLiteDelegate* delegate); 539 540 /// Get the delegate buffer handle, and the delegate which can process the 541 /// buffer handle. 542 /// WARNING: This is an experimental API and subject to change. 543 TfLiteStatus GetBufferHandle(int tensor_index, 544 TfLiteBufferHandle* buffer_handle, 545 TfLiteDelegate** delegate); 546 547 /// Sets the profiler to tracing execution. The caller retains ownership 548 /// of the profiler and must ensure its validity. 549 /// WARNING: This is an experimental API and subject to change. 550 void SetProfiler(Profiler* profiler); 551 552 /// Same as SetProfiler except this interpreter takes ownership 553 /// of the provided profiler. 554 /// WARNING: This is an experimental API and subject to change. 555 void SetProfiler(std::unique_ptr<Profiler> profiler); 556 557 /// Gets the profiler used for op tracing. 558 /// WARNING: This is an experimental API and subject to change. 559 Profiler* GetProfiler(); 560 561 // The default capacity of `tensors_` vector. 562 static constexpr int kTensorsReservedCapacity = 128; 563 /// The capacity headroom of `tensors_` vector before calling ops' 564 /// `prepare` and `invoke` function. In these functions, it's guaranteed 565 /// allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate 566 /// pointers to existing tensors. 567 static constexpr int kTensorsCapacityHeadroom = 16; 568 569 /// Set if buffer handle output is allowed. 570 // 571 /// When using hardware delegation, Interpreter will make the data of output 572 /// tensors available in `tensor->data` by default. If the application can 573 /// consume the buffer handle directly (e.g. reading output from OpenGL 574 /// texture), it can set this flag to false, so Interpreter won't copy the 575 /// data from buffer handle to CPU memory. WARNING: This is an experimental 576 /// API and subject to change. SetAllowBufferHandleOutput(bool allow_buffer_handle_output)577 void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) { 578 allow_buffer_handle_output_ = allow_buffer_handle_output; 579 } 580 581 /// Reset all variable tensors to the default value. 582 /// If a variable tensor doesn't have a buffer, reset it to zero. 583 /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it 584 /// to the value of the buffer. 585 /// WARNING: This is an experimental API and subject to change. 586 TfLiteStatus ResetVariableTensors(); 587 588 /// Retrieve an operator's description of its work, for profiling purposes. OpProfilingString(const TfLiteRegistration & op_reg,const TfLiteNode * node)589 const char* OpProfilingString(const TfLiteRegistration& op_reg, 590 const TfLiteNode* node) const { 591 if (op_reg.profiling_string == nullptr) return nullptr; 592 return op_reg.profiling_string(context_, node); 593 } 594 595 // Set the value of an external context. TFLite interpreter doesn't take the 596 // memory ownership of this external context 'ctx', and the context should 597 // outlive the TFLite interpreter. 598 void SetExternalContext(TfLiteExternalContextType type, 599 TfLiteExternalContext* ctx); 600 601 // Assigns (or reassigns) a custom memory allocation for the given tensor. 602 // If AllocateTensors() is called after this, the runtime does not consider 603 // the tensor during internal memory planning and will continue using the 604 // provided allocation for the tensor (assuming it satisfies the expected 605 // tensor byte length). 606 // The runtime does NOT take ownership of the underlying memory. 607 // Note that while this function can be called again to set a new allocation 608 // for the tensor, it can no longer be reset to the TFLite arena memory. 609 // 610 // Parameters should satisfy the following conditions: 611 // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent 612 // In general, this is true for I/O tensors & variable tensors. 613 // 2. allocation->data has the appropriate permissions for runtime access 614 // (Read-only for inputs, Read-Write for others), and outlives Interpreter. 615 // 3. allocation->bytes >= tensor->bytes. 616 // This condition is checked again if any tensors are resized. 617 // 4. allocation->data should be aligned to kDefaultTensorAlignment 618 // defined in lite/util.h. (Currently 64 bytes) 619 // 620 // WARNING: This is an experimental interface that is subject to change. 621 TfLiteStatus SetCustomAllocationForTensor( 622 int tensor_index, const TfLiteCustomAllocation& allocation); 623 624 #ifndef DOXYGEN_SKIP 625 /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph 626 /// entries. The value pointed to by `first_new_subgraph_index` will be set to 627 /// the index of the first new subgraph if `first_new_subgraph_index` is 628 /// non-null. 629 /// WARNING: This is an experimental API and subject to change. 630 void AddSubgraphs(int subgraphs_to_add, 631 int* first_new_subgraph_index = nullptr); 632 633 /// Return the number of subgraphs in the model. 634 /// WARNING: This is an experimental API and subject to change. subgraphs_size()635 size_t subgraphs_size() const { return subgraphs_.size(); } 636 637 /// Get a pointer to a subgraph if in bounds. 638 /// WARNING: This is an experimental API and subject to change. subgraph(int subgraph_index)639 Subgraph* subgraph(int subgraph_index) { 640 if (subgraph_index < 0 || 641 static_cast<size_t>(subgraph_index) >= subgraphs_size()) 642 return nullptr; 643 return &*subgraphs_[subgraph_index]; 644 } 645 646 /// WARNING: Experimental interface, subject to change primary_subgraph()647 Subgraph& primary_subgraph() { 648 return *subgraphs_.front(); /// Safe as subgraphs_ always has 1 entry. 649 } 650 651 /// WARNING: Experimental interface, subject to change primary_subgraph()652 const Subgraph& primary_subgraph() const { 653 return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry. 654 } 655 656 /// WARNING: Experimental interface, subject to change 657 // Get the error reporter associated with this interpreter. error_reporter()658 ErrorReporter* error_reporter() const { return error_reporter_; } 659 660 #endif // DOXYGEN_SKIP 661 662 private: 663 // Structure representing SignatureDef inputs/outputs. 664 struct SignatureDef { 665 // Maps name in signature def as key to index of the tensor in the model. 666 std::map<std::string, uint32_t> inputs; 667 // Maps name in signature def as key to index of the tensor in the model. 668 std::map<std::string, uint32_t> outputs; 669 // The method name for this signature. 670 std::string method_name; 671 // The key of this SignatureDef in the SavedModel signature def map. 672 std::string signature_def_key; 673 }; 674 friend class InterpreterBuilder; 675 friend class tflite::InterpreterTest; 676 friend class tflite::delegates::InterpreterUtils; 677 friend class tflite::delegates::test_utils::TestDelegate; 678 679 /// Set the value of an external context. 680 static void SetExternalContext(struct TfLiteContext* context, 681 TfLiteExternalContextType type, 682 TfLiteExternalContext* ctx); 683 684 // Helper method that return the tensot index that corresponds to 685 // a name in a SignatureDef. Defined by 'signature_method_name', and 686 // 'signature_tensor_name'. 687 // If 'is_input' is true then the tensor is checked in input tensors, 688 // otherwise it will be checked in output tensors. 689 // Returns -1 if the tensor is not found. GetTensorIndexFromSignatureDefName(const char * signature_tensor_name,const char * signature_method_name,bool is_input)690 int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name, 691 const char* signature_method_name, 692 bool is_input) const { 693 // Iterate directly and don't use other methods to avoid extra allocation. 694 for (const auto& signature : signature_defs_) { 695 if (signature.method_name != signature_method_name) continue; 696 auto& signature_list = (is_input ? signature.inputs : signature.outputs); 697 auto tensor_iter = signature_list.find(signature_tensor_name); 698 if (tensor_iter == signature_list.end()) return -1; 699 return tensor_iter->second; 700 } 701 return -1; 702 } 703 704 // Sets the profiler to all subgraphs. 705 void SetSubgraphProfiler(); 706 707 // Remove delegates (for fallback behaviour). The interpreter is invokable 708 // afterwards. 709 TfLiteStatus RemoveAllDelegates(); 710 711 // Returns true if delegates have been applied. 712 bool HasDelegates(); 713 714 // Returns true if cancellation function returns true. 715 bool IsCancelled(); 716 717 // Sets the list of signature defs in the model. SetSignatureDef(std::vector<SignatureDef> signature_defs)718 void SetSignatureDef(std::vector<SignatureDef> signature_defs) { 719 signature_defs_ = std::move(signature_defs); 720 } 721 722 // A pure C data structure used to communicate with the pure C plugin 723 // interface. To avoid copying tensor metadata, this is also the definitive 724 // structure to store tensors. 725 // This is the primary subgraph context. 726 TfLiteContext* context_ = nullptr; 727 728 // The error reporter delegate that tflite will forward queries errors to. 729 ErrorReporter* error_reporter_ = nullptr; 730 731 // List of delegates that have been installed and are owned by this 732 // interpreter instance. Useful if client delegate ownership is burdensome. 733 // WARNING: This is an experimental API and subject to change. 734 // TODO(b/116667551): Use TfLiteExternalContext for storing state. 735 std::vector< 736 std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>> 737 owned_delegates_; 738 739 // Profiler that has been installed and is owned by this interpreter instance. 740 // Useful if client profiler ownership is burdensome. 741 std::unique_ptr<Profiler> owned_profiler_; 742 743 // Points to the installed Profiler instance. 744 Profiler* installed_profiler_ = nullptr; 745 746 bool allow_buffer_handle_output_ = false; 747 748 // List of active external contexts. 749 TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; 750 751 // The default external cpu backend context. After an TFLite interpreter is 752 // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to point 753 // to this object. However, if this element value is overwritten via calling 754 // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset this to 755 // nullptr if necessary. 756 std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_; 757 758 // Subgraphs 759 std::vector<std::unique_ptr<Subgraph>> subgraphs_; 760 761 // A map of resources. Owned by interpreter and shared by multiple subgraphs. 762 resource::ResourceMap resources_; 763 764 // Indicating delegates that the TFLite interpreter will apply by default. 765 // An empty one means there's no delegate to be applied by default or 766 // delegates have been applied and doesn't need to be applied again. 767 std::vector<TfLiteDelegatePtr> lazy_delegate_providers_; 768 769 // List of signature def mapping inputs/output to tensor ids. 770 // We just keep track of tensor index. 771 std::vector<SignatureDef> signature_defs_; 772 }; 773 774 } // namespace tflite 775 #endif // TENSORFLOW_LITE_INTERPRETER_H_ 776