/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Main abstraction controlling the tflite interpreter.
/// See context.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

class InterpreterTest;  // Class for friend declarations.

namespace delegates {
class InterpreterUtils;  // Class for friend declarations.

namespace test_utils {
class TestDelegation;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

namespace interpreter_wrapper {
class InterpreterWrapper;  // Class for friend declarations.
}  // namespace interpreter_wrapper

/// An interpreter for a graph of nodes that input and output from tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
///   // Return error.
/// }
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
///   // Return failure.
/// }
/// if (interpreter->AllocateTensors() != kTfLiteOk) {
///   // Return failure.
/// }
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
///   input[i] = ...;
/// }
/// interpreter->Invoke();
/// </code></pre>
///
/// Note: For nearly all practical use cases, one should not directly construct
/// an Interpreter object, but rather use the InterpreterBuilder.
///
/// WARNING: This class is *not* thread-safe. The client is responsible for
/// ensuring serialized interaction to avoid data races and undefined behavior.
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  // WARNING: Use of this constructor outside of an InterpreterBuilder is not
  // recommended.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build the interpreter
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). The Interpreter takes ownership of
  /// `builtin_data` and destroys it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set description of inputs/outputs/data/fptrs for the tensor at
  /// `tensor_index`. This variant assumes an external buffer of size `bytes`
  /// has been allocated. The lifetime of `buffer` must be at least as long as
  /// that of the Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set description of inputs/outputs/data/fptrs for the tensor at
  /// `tensor_index`. This variant takes no external buffer; the tensor's
  /// buffer is allocated and managed by the interpreter.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);
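
  /// For illustration, a minimal sketch of assembling a single-op graph with
  /// the builder functions above; `reg` stands for a TfLiteRegistration
  /// obtained elsewhere (e.g. from an op resolver), and error checking is
  /// omitted:
  /// <pre><code>
  /// Interpreter interpreter;
  /// int base_index = 0;
  /// interpreter.AddTensors(2, &base_index);  // Tensors 0 and 1.
  /// interpreter.SetInputs({0});
  /// interpreter.SetOutputs({1});
  /// TfLiteQuantization quant = {};  // Zero-init: kTfLiteNoQuantization.
  /// interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in", {3},
  ///                                          quant);
  /// interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out", {3},
  ///                                          quant);
  /// interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
  /// </code></pre>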
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data

  /// Read only access to list of inputs.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read only access to list of outputs.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read only access to list of variable tensors.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }

  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds from the primary subgraph (subgraphs_[0]).
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int subgraph_index, int node_index) const {
    return subgraph(subgraph_index)->node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }
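
  /// For example, a small sketch of the checked cast; it yields nullptr on a
  /// type mismatch or an out-of-bounds index:
  /// <pre><code>
  /// if (float* data = interpreter->typed_tensor<float>(tensor_index)) {
  ///   data[0] = 1.0f;
  /// } else {
  ///   // Tensor is not float32, or tensor_index is invalid.
  /// }
  /// </code></pre>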

  /// WARNING: Experimental interface, subject to change
  /// Returns list of all keys of different method signatures defined in the
  /// model.
  /// Note, pointers returned have lifetime same as the Interpreter object.
  std::vector<const std::string*> signature_keys() const {
    std::vector<const std::string*> signature_keys;
    signature_keys.reserve(signature_defs_.size());
    for (const auto& sig_def : signature_defs_) {
      signature_keys.emplace_back(&sig_def.signature_key);
    }
    return signature_keys;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns a pointer to the SignatureRunner instance to run the part of the
  /// graph identified by a SignatureDef. nullptr is returned if the given
  /// signature key is not valid.
  /// If you need to specify delegates, you have to do that before calling this
  /// function. This function will additionally apply default delegates. Thus,
  /// applying delegates after that might lead to undesirable behaviors.
  /// Note, the pointed instance has lifetime same as the Interpreter object
  /// and the SignatureRunner class is *not* thread-safe.
  SignatureRunner* GetSignatureRunner(const char* signature_key);
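
  /// For example, a sketch of running a signature; the key "serving_default"
  /// and the tensor names are illustrative and depend on the model:
  /// <pre><code>
  /// SignatureRunner* runner =
  ///     interpreter->GetSignatureRunner("serving_default");
  /// if (runner == nullptr) { /* Unknown signature key. */ }
  /// runner->ResizeInputTensor("input", {1, 224, 224, 3});
  /// runner->AllocateTensors();
  /// float* input = runner->input_tensor("input")->data.f;
  /// // ... fill input ...
  /// runner->Invoke();
  /// const float* output = runner->output_tensor("output")->data.f;
  /// </code></pre>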

  /// WARNING: Experimental interface, subject to change
  // Return the subgraph index that corresponds to a SignatureDef, defined by
  // 'signature_key'.
  // If an invalid name is passed, -1 will be returned.
  int GetSubgraphIndexFromSignature(const char* signature_key) const {
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key == signature_key) {
        return signature.subgraph_index;
      }
    }
    return -1;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of inputs to tensor index in the signature
  /// specified through 'signature_key'.
  /// If an invalid name is passed, an empty list will be returned.
  const std::map<std::string, uint32_t>& signature_inputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.inputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the mapping of outputs to tensor index in the signature
  /// specified through 'signature_key'.
  /// If an invalid name is passed, an empty list will be returned.
  const std::map<std::string, uint32_t>& signature_outputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.outputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the input tensor identified by 'signature_input_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  TfLiteTensor* input_tensor_by_signature(const char* signature_input_name,
                                          const char* signature_key) {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_input_name, signature_key, /*is_input=*/true);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// WARNING: Experimental interface, subject to change
  /// Returns the output tensor identified by 'signature_output_name' in the
  /// signature identified by 'signature_key'.
  /// Returns nullptr if not found.
  const TfLiteTensor* output_tensor_by_signature(
      const char* signature_output_name, const char* signature_key) const {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_output_name, signature_key, /*is_input=*/false);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }
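
  /// For example (a sketch; the signature key and tensor names are
  /// illustrative):
  /// <pre><code>
  /// TfLiteTensor* in =
  ///     interpreter->input_tensor_by_signature("x", "serving_default");
  /// const TfLiteTensor* out =
  ///     interpreter->output_tensor_by_signature("y", "serving_default");
  /// if (in == nullptr || out == nullptr) { /* Unknown name or key. */ }
  /// </code></pre>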

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Change the dimensionality of a given tensor. Note, this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns status of failure or success. Note that this doesn't actually
  /// resize any existing buffers. A call to AllocateTensors() is required to
  /// change the tensor input buffer.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  /// Change the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1`
  /// in the `dims_signature` attribute of a `TfLiteTensor`. Returns status of
  /// failure or success. Note that this doesn't actually resize any existing
  /// buffers. A call to AllocateTensors() is required to change the tensor
  /// input buffer.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);
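
  /// A typical resize flow, sketched (the index and shape are illustrative):
  /// <pre><code>
  /// interpreter->ResizeInputTensor(interpreter->inputs()[0], {1, 128});
  /// interpreter->AllocateTensors();  // Required before the next Invoke().
  /// // Buffers may have moved, so re-fetch any cached data pointers.
  /// float* input = interpreter->typed_input_tensor<float>(0);
  /// </code></pre>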

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before
  /// the next invocation.
  /// WARNING: Experimental interface, subject to change
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. This *must be* called after the interpreter has been created
  /// and before running inference (and accessing tensor buffers), and *must
  /// be* called again if (and only if) an input tensor is resized. Returns
  /// status of success or failure. Will fail if any of the ops in the model
  /// (other than those which were rewritten by delegates, if any) are not
  /// supported by the Interpreter's OpResolver.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state
  /// to evaluate (i.e. if a ResizeTensor() has been performed without an
  /// AllocateTensors()).
  /// Returns status of success or failure.
  TfLiteStatus Invoke();

  /// Set the number of threads available to the interpreter.
  ///
  /// NOTE: num_threads should be >= -1. Setting num_threads to 0 has the
  /// effect of disabling multithreading, which is equivalent to setting
  /// num_threads to 1. If set to the value -1, the number of threads used
  /// will be implementation-defined and platform-dependent.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  ///
  /// WARNING: This API is deprecated: prefer controlling this via delegate
  /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
  /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
  /// This method will be removed in a future release.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
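
  /// A minimal sketch wiring the callback to a caller-owned flag; the
  /// std::atomic<bool> is an illustrative choice, not part of this API:
  /// <pre><code>
  /// static std::atomic<bool> cancelled(false);
  /// interpreter->SetCancellationFunction(&cancelled, [](void* data) {
  ///   return static_cast<std::atomic<bool>*>(data)->load();
  /// });
  /// // Elsewhere: cancelled = true;  // Aborts at the next op boundary.
  /// </code></pre>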

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph themselves. After this is called, the graph may
  /// contain new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// Returns one of the following four status codes:
  /// 1. kTfLiteOk: Success.
  /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
  /// delegate, or the delegate parameter was null. The Interpreter has been
  /// restored to its pre-delegation state.
  /// NOTE: This undoes all delegates previously applied to the Interpreter.
  /// 3. kTfLiteApplicationError: Delegation failed to be applied due to
  /// incompatibility with the TfLite runtime, e.g., the model graph is
  /// already immutable when applying the delegate. However, the interpreter
  /// could still be invoked.
  /// 4. kTfLiteError: Unexpected/runtime failure.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate.
  /// WARNING: This is an experimental API and subject to change.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph
    // modification fails, as delegate use will be in an indeterminate state
    // at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(static_cast<
                  typename std::unique_ptr<Delegate, Deleter>::pointer>(
              delegate_to_delete));
        });
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }
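
  /// For example, a sketch of handing an owned delegate to the interpreter;
  /// MakeMyDelegate is a hypothetical factory returning a
  /// std::unique_ptr<TfLiteDelegate, MyDeleter>:
  /// <pre><code>
  /// auto delegate = MakeMyDelegate();
  /// if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
  ///     kTfLiteOk) {
  ///   // Delegation failed, but the interpreter retains ownership of the
  ///   // delegate and remains usable (see the status codes above).
  /// }
  /// </code></pre>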

  /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has
  /// no virtual destructor. The default deleter of the unique_ptr does not
  /// know how to delete C++ objects deriving from TfLiteDelegate.
  TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<TfLiteDelegate> delegate) = delete;

  /// Ensure the data in `tensor.data` is readable. If a delegate is used, it
  /// might be necessary to copy the data from the delegate buffer to raw
  /// memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle to a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle to a tensor that's not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle to a tensor that uses the same delegate.
  ///    For example, set an OpenGL texture as the output of inference, while
  ///    the node which produces output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Same as SetProfiler except this interpreter takes ownership
  /// of the provided profiler.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(std::unique_ptr<Profiler> profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();

  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. In these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, the Interpreter will make the data of
  /// output tensors available in `tensor->data` by default. If the
  /// application can consume the buffer handle directly (e.g. reading output
  /// from an OpenGL texture), it can set this flag to false, so the
  /// Interpreter won't copy the data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter doesn't
  // take ownership of this external context 'ctx', and the context should
  // outlive the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  // Assigns (or reassigns) a custom memory allocation for the given tensor.
  // `flags` is a bitmask, see TfLiteCustomAllocationFlags.
  // The runtime does NOT take ownership of the underlying memory.
  //
  // NOTE: User needs to call AllocateTensors() after this. In case of input
  // resizing, buffers will be checked for required data size during
  // AllocateTensors().
  //
  // Parameters should satisfy the following conditions:
  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  //    In general, this is true for I/O tensors & variable tensors.
  // 2. allocation->data has the appropriate permissions for runtime access
  //    (Read-only for inputs, Read-Write for others), and outlives the
  //    Interpreter.
  // 3. allocation->bytes >= tensor->bytes.
  //    This condition is checked again if any tensors are resized.
  // 4. allocation->data should be aligned to kDefaultTensorAlignment
  //    defined in lite/util.h. (Currently 64 bytes.)
  //    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  //    set through `flags`.
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);
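
  // A hedged sketch of pointing an input tensor at caller-owned memory, under
  // the conditions above (the rounding keeps aligned_alloc's size/alignment
  // contract and satisfies condition 3):
  // <pre><code>
  // const TfLiteTensor* t = interpreter->input_tensor(0);
  // size_t bytes = (t->bytes + 63) & ~size_t{63};  // Round up to 64.
  // TfLiteCustomAllocation alloc{std::aligned_alloc(64, bytes), bytes};
  // interpreter->SetCustomAllocationForTensor(interpreter->inputs()[0],
  //                                           alloc);
  // interpreter->AllocateTensors();  // Validates size and alignment.
  // </code></pre>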

#ifndef DOXYGEN_SKIP
  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set
  /// to the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  /// WARNING: This is an experimental API and subject to change.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  const Subgraph* subgraph(int subgraph_index) const {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size()) {
      return nullptr;
    }
    return subgraphs_[subgraph_index].get();
  }

  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    return const_cast<Subgraph*>(
        static_cast<const Interpreter*>(this)->subgraph(subgraph_index));
  }

  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  // Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

#endif  // DOXYGEN_SKIP

 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Helper method that returns the tensor index that corresponds to
  // a name in a SignatureDef, defined by 'signature_key' and
  // 'signature_tensor_name'.
  // If 'is_input' is true then the tensor is looked up among the input
  // tensors, otherwise among the output tensors.
  // Returns -1 if the tensor is not found.
  int GetTensorIndexFromSignature(const char* signature_tensor_name,
                                  const char* signature_key,
                                  bool is_input) const {
    // Iterate directly and don't use other methods to avoid extra allocation.
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key != signature_key) continue;
      auto& signature_list = (is_input ? signature.inputs : signature.outputs);
      auto tensor_iter = signature_list.find(signature_tensor_name);
      if (tensor_iter == signature_list.end()) return -1;
      return tensor_iter->second;
    }
    return -1;
  }

  // Applies TFLite default delegates.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Overrides the execution plan. This bounds checks indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler on all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Sets the list of signature defs in the model.
  void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }

  // Enables preserving intermediates for debugging. Should only be set by
  // InterpreterBuilder before allocating any tensors.
  TfLiteStatus PreserveAllTensorsExperimental();

  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder, should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph's context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // Profiler that has been installed and is owned by this interpreter
  // instance. Useful if client profiler ownership is burdensome.
  std::unique_ptr<Profiler> owned_profiler_;

  // Points to the installed Profiler instance.
  Profiler* installed_profiler_ = nullptr;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to
  // point to this object.
  // However, if this element value is overwritten via calling
  // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we reset this to
  // nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by interpreter and shared by multiple subgraphs.
  resource::ResourceMap resources_;

  // A map of resource Ids. Owned by interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses that indicate whether the initialization
  // subgraph invocation is done or not. Owned by interpreter and shared by
  // multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Delegates that the TFLite interpreter will apply by default.
  // An empty list means either there is no delegate to be applied by default,
  // or the delegates have already been applied and don't need to be applied
  // again.
  std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is basically a wrapper of the Subgraph corresponding to
  // its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as a mapping of name (key) to buffer (value).
  // Data is mapped from the Metadata in the TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_