1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ 16 #define TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ 17 18 #include <map> 19 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/lite/c/common.h" 22 23 namespace tflite { 24 namespace flex { 25 26 // Maps a TF Lite tensor index into a TensorFlow tensor. 27 // 28 // The TF Lite interpreter assigns integer indices to each of its tensors, but 29 // the Flex delegate deals in terms of TensorFlow tensors. This class maps 30 // from indices to tensors and allows the creation of new tensors to be 31 // associated with a given index. 32 class BufferMap { 33 public: 34 BufferMap(); 35 ~BufferMap(); 36 37 // Returns true if the given 'tensor_index' has a corresponding 38 // tensorflow::Tensor. 39 bool HasTensor(int tensor_index) const; 40 41 // Returns true if the given 'tensor_index' has a corresponding 42 // tensorflow::Tensor *and* the content is owned by TensorFlow (that is, the 43 // mapping was added by SetFromTensorFlow()). 44 bool IsTensorFlowTensor(int tensor_index) const; 45 46 // Returns the tensorflow::Tensor associated with the given 'tensor_index'. 47 // Precondition: HasTensor() is true. 48 tensorflow::Tensor GetTensor(int tensor_index) const; 49 50 // Returns the const pointer to tensorflow::Tensor associated with the given 51 // 'tensor_index'. 52 // Precondition: HasTensor() is true. 53 const tensorflow::Tensor* GetTensorPtr(int tensor_index) const; 54 55 // Associates the given tensorflow::Tensor with the given 'tensor_index'. 56 // Note that TensorFlow Tensors share data buffers, so this method is only a 57 // shallow copy. 58 void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor); 59 60 // Same as above but creates a new tensorflow::Tensor with a copy of the 61 // given TfLiteTensor's data. 62 void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); 63 64 private: 65 // Mapping from TL Lite tensor ID to TensorFlow's Tensor. All tensors that 66 // are inputs or outputs of a subgraph will be added here, irrespective of 67 // whether their data are managed by TF Lite or TensorFlow. 68 std::map<int, tensorflow::Tensor> id_to_tensor_; 69 // A list of tensors that are completely managed by TensorFlow. Most of the 70 // time, TF Lite will populate tensors that are inputs to subgraphs, while 71 // TensorFlow will populate output tensors. Occasionally, however, an input 72 // tensor is coming from a previous subgraph and could have been populated by 73 // TensorFlow. This set keeps track of all input or output tensors that have 74 // been populated by tensorflow. 75 std::set<int> owned_by_tf_; 76 }; 77 78 } // namespace flex 79 } // namespace tflite 80 81 #endif // TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ 82