• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
16 #define TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
17 
#include <map>
#include <set>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/lite/c/common.h"
22 
23 namespace tflite {
24 namespace flex {
25 
26 // Maps a TF Lite tensor index into a TensorFlow tensor.
27 //
28 // The TF Lite interpreter assigns integer indices to each of its tensors, but
29 // the Flex delegate deals in terms of TensorFlow tensors. This class maps
30 // from indices to tensors and allows the creation of new tensors to be
31 // associated with a given index.
32 class BufferMap {
33  public:
34   BufferMap();
35   ~BufferMap();
36 
37   // Returns true if the given 'tensor_index' has a corresponding
38   // tensorflow::Tensor.
39   bool HasTensor(int tensor_index) const;
40 
41   // Returns true if the given 'tensor_index' has a corresponding
42   // tensorflow::Tensor *and* the content is owned by TensorFlow (that is, the
43   // mapping was added by SetFromTensorFlow()).
44   bool IsTensorFlowTensor(int tensor_index) const;
45 
46   // Returns the tensorflow::Tensor associated with the given 'tensor_index'.
47   // Precondition: HasTensor() is true.
48   tensorflow::Tensor GetTensor(int tensor_index) const;
49 
50   // Returns the const pointer to tensorflow::Tensor associated with the given
51   // 'tensor_index'.
52   // Precondition: HasTensor() is true.
53   const tensorflow::Tensor* GetTensorPtr(int tensor_index) const;
54 
55   // Associates the given tensorflow::Tensor with the given 'tensor_index'.
56   // Note that TensorFlow Tensors share data buffers, so this method is only a
57   // shallow copy.
58   void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor);
59 
60   // Same as above but creates a new tensorflow::Tensor with a copy of the
61   // given TfLiteTensor's data.
62   void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor);
63 
64  private:
65   // Mapping from TL Lite tensor ID to TensorFlow's Tensor. All tensors that
66   // are inputs or outputs of a subgraph will be added here, irrespective of
67   // whether their data are managed by TF Lite or TensorFlow.
68   std::map<int, tensorflow::Tensor> id_to_tensor_;
69   // A list of tensors that are completely managed by TensorFlow. Most of the
70   // time, TF Lite will populate tensors that are inputs to subgraphs, while
71   // TensorFlow will populate output tensors. Occasionally, however, an input
72   // tensor is coming from a previous subgraph and could have been populated by
73   // TensorFlow. This set keeps track of all input or output tensors that have
74   // been populated by tensorflow.
75   std::set<int> owned_by_tf_;
76 };
77 
78 }  // namespace flex
79 }  // namespace tflite
80 
81 #endif  // TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
82