• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
18 
19 #include <algorithm>
20 #include <cstring>
21 #include <numeric>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
27 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28 #include "tensorflow/lite/kernels/op_macros.h"
29 #include "tensorflow/lite/string_util.h"
30 #include "tensorflow/lite/type_to_tflitetype.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32 
33 namespace tflite {
34 namespace task {
35 namespace core {
36 
37 // Checks if data type of tensor is T and returns the pointer casted to T if
38 // applicable, returns nullptr if tensor type is not T.
39 // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
40 template <typename T>
TypedTensor(const TfLiteTensor * tensor_ptr)41 T* TypedTensor(const TfLiteTensor* tensor_ptr) {
42   if (tensor_ptr->type == typeToTfLiteType<T>()) {
43     return reinterpret_cast<T*>(tensor_ptr->data.raw);
44   }
45   return nullptr;
46 }
47 
48 // Checks and returns type of a tensor, fails if tensor type is not T.
49 template <typename T>
AssertAndReturnTypedTensor(const TfLiteTensor * tensor)50 T* AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
51   if (T* v = TypedTensor<T>(tensor)) return v;
52   // TODO(b/150903834): throw exceptions instead
53   TF_LITE_ASSERT(tensor->data.raw);
54   TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name,
55                              ". Requested ",
56                              TfLiteTypeGetName(typeToTfLiteType<T>()), ", got ",
57                              TfLiteTypeGetName(tensor->type), ".")
58                     .c_str());
59 }
60 
61 // Populates tensor with array of data, fails if data type doesn't match tensor
62 // type or has not the same number of elements.
63 template <typename T>
PopulateTensor(const T * data,int num_elements,TfLiteTensor * tensor)64 inline void PopulateTensor(const T* data, int num_elements,
65                            TfLiteTensor* tensor) {
66   T* v = AssertAndReturnTypedTensor<T>(tensor);
67   size_t bytes = num_elements * sizeof(T);
68   // TODO(b/150903834): throw exceptions instead
69   TF_LITE_ASSERT(tensor->bytes == bytes);
70   memcpy(v, data, bytes);
71 }
72 
73 // Populates tensor with vector of data, fails if data type doesn't match tensor
74 // type or has not the same number of elements.
75 template <typename T>
PopulateTensor(const std::vector<T> & data,TfLiteTensor * tensor)76 inline void PopulateTensor(const std::vector<T>& data, TfLiteTensor* tensor) {
77   return PopulateTensor<T>(data.data(), data.size(), tensor);
78 }
79 
80 template <>
81 inline void PopulateTensor<std::string>(const std::vector<std::string>& data,
82                                         TfLiteTensor* tensor) {
83   if (tensor->type != kTfLiteString) {
84     TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name,
85                                ". Requested STRING, got ",
86                                TfLiteTypeGetName(tensor->type), ".")
87                       .c_str());
88   }
89   tflite::DynamicBuffer input_buf;
90   for (const auto& value : data) {
91     input_buf.AddString(value.data(), value.length());
92   }
93   input_buf.WriteToTensorAsVector(tensor);
94 }
95 
96 // Populates tensor one data item, fails if data type doesn't match tensor
97 // type.
98 template <typename T>
PopulateTensor(const T & data,TfLiteTensor * tensor)99 inline void PopulateTensor(const T& data, TfLiteTensor* tensor) {
100   T* v = AssertAndReturnTypedTensor<T>(tensor);
101   *v = data;
102 }
103 
104 template <>
105 inline void PopulateTensor<std::string>(const std::string& data,
106                                         TfLiteTensor* tensor) {
107   tflite::DynamicBuffer input_buf;
108   input_buf.AddString(data.data(), data.length());
109   input_buf.WriteToTensorAsVector(tensor);
110 }
111 
112 // Populates a vector from the tensor, fails if data type doesn't match tensor
113 // type.
114 template <typename T>
PopulateVector(const TfLiteTensor * tensor,std::vector<T> * data)115 inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) {
116   AssertAndReturnTypedTensor<T>(tensor);
117   const T* results = GetTensorData<T>(tensor);
118   size_t num = tensor->bytes / sizeof(tensor->type);
119   data->reserve(num);
120   for (int i = 0; i < num; i++) {
121     data->emplace_back(results[i]);
122   }
123 }
124 
125 template <>
126 inline void PopulateVector<std::string>(const TfLiteTensor* tensor,
127                                         std::vector<std::string>* data) {
128   AssertAndReturnTypedTensor<std::string>(tensor);
129   int num = GetStringCount(tensor);
130   data->reserve(num);
131   for (int i = 0; i < num; i++) {
132     const auto& strref = tflite::GetString(tensor, i);
133     data->emplace_back(strref.str, strref.len);
134   }
135 }
136 
// Returns the indices of `v` ordered so that the referenced values are
// non-increasing (largest first). A stable sort is used, so equal values keep
// their original relative order. Only `operator<` on T is required.
template <typename T>
std::vector<size_t> ReverseSortIndices(const std::vector<T>& v) {
  std::vector<size_t> order(v.size());
  std::iota(order.begin(), order.end(), size_t{0});
  const auto value_is_greater = [&v](size_t lhs, size_t rhs) {
    return v[rhs] < v[lhs];
  };
  std::stable_sort(order.begin(), order.end(), value_is_greater);
  return order;
}
148 
// Returns the original (dequantized) value of the 'index'-th element of
// 'tensor'.
151 double Dequantize(const TfLiteTensor& tensor, int index);
152 
153 // Returns the index-th string from the tensor.
154 std::string GetStringAtIndex(const TfLiteTensor* labels, int index);
155 
156 // Loads binary content of a file into a string.
157 std::string LoadBinaryContent(const char* filename);
158 
159 // Gets the tensor from a vector of tensors with name specified inside metadata.
160 template <typename TensorType>
FindTensorByName(const std::vector<TensorType * > & tensors,const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>> * tensor_metadatas,const std::string & name)161 static TensorType* FindTensorByName(
162     const std::vector<TensorType*>& tensors,
163     const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
164         tensor_metadatas,
165     const std::string& name) {
166   if (tensor_metadatas == nullptr ||
167       tensor_metadatas->size() != tensors.size()) {
168     return nullptr;
169   }
170   for (int i = 0; i < tensor_metadatas->size(); i++) {
171     if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) {
172       return tensors[i];
173     }
174   }
175   return nullptr;
176 }
177 
178 }  // namespace core
179 }  // namespace task
180 }  // namespace tflite
181 
182 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
183