1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <algorithm>
18 #include <iterator>
19 #include <vector>
20
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace builtin {
30 namespace topk_v2 {
// Tensor indices for the TOPK_V2 node: two inputs (the data tensor and the
// scalar number of top results k) and two outputs (the top values and the
// indices of those values within the innermost dimension).
constexpr int kInputTensor = 0;
constexpr int kInputTopK = 1;
constexpr int kOutputValues = 0;
constexpr int kOutputIndexes = 1;
35
36 namespace {
ResizeOutput(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
38 const TfLiteTensor* top_k;
39 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
40 // INT32 number of top results is supported.
41 TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
42 // Check that the tensor contains only one value.
43 TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
44 const int32 k = *GetTensorData<int32_t>(top_k);
45
46 const TfLiteTensor* input;
47 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
48 const int num_dimensions = NumDimensions(input);
49 // Check that input has one or more dimensions.
50 TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
51 "TopK k input must have 1 or more dimensions.");
52 // Check that k is less or equal the internal dimension.
53 TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
54 "TopK k is higher than the internal dimension.");
55
56 TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
57 TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
58 for (int i = 0; i < num_dimensions - 1; ++i) {
59 output_indexes_shape->data[i] = input->dims->data[i];
60 output_values_shape->data[i] = input->dims->data[i];
61 }
62 output_indexes_shape->data[num_dimensions - 1] = k;
63 output_values_shape->data[num_dimensions - 1] = k;
64 TfLiteTensor* output_indexes;
65 TF_LITE_ENSURE_OK(
66 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
67 TfLiteTensor* output_values;
68 TF_LITE_ENSURE_OK(
69 context, GetOutputSafe(context, node, kOutputValues, &output_values));
70 // Force output types.
71 output_indexes->type = kTfLiteInt32;
72 output_values->type = input->type;
73 auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
74 TfLiteIntArray* delete_on_error) {
75 TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
76 if (status != kTfLiteOk) {
77 if (delete_on_error != nullptr) {
78 TfLiteIntArrayFree(delete_on_error);
79 }
80 }
81 return status;
82 };
83 TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
84 output_values_shape));
85 TF_LITE_ENSURE_OK(context,
86 resize_tensor(output_values, output_values_shape, nullptr));
87 return kTfLiteOk;
88 }
89
// Class that collects indices of top k values. Based on template
// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // Reserves room for the top-k indices plus one scratch slot (container_[k])
  // used during heap updates; min(k, row_size) caps the reservation when k
  // exceeds the number of elements in a row.
  TopContainer(int32 k, int32 row_size) : k_(k) {
    container_.reserve(std::min(k, row_size) + 1);
  }

  // Starts a fresh collection pass over `values`; previously collected
  // indices are discarded. `values` must outlive the pass (this class only
  // borrows the pointer).
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
  }

  // Considers index `a` for inclusion in the top-k set. While fewer than
  // k + 1 indices have been seen, they are appended verbatim; once the
  // scratch slot fills, container_[0,k) is kept as a min-heap of the current
  // top-k (smallest top-k element at the front).
  void push(int32 a) {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      container_.push_back(a);
      if (container_.size() == k_ + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
      }
    } else if (comparator(a, container_.front())) {
      // Due to how we defined comparator / compare_fun, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // If control reaches this point, we know that the current index a
      // corresponds to an element which is bigger than the smallest of the
      // top-k elements seen so far. Hence, we have to update the indices of
      // the top-k elements, by removing the index of the smallest top-k
      // element, adding a, and making sure container_[0:k] is still a heap.

      // Store index a into container_[k].
      container_.back() = a;

      // Swap container_[0] and container_[k], and rearrange elements from
      // container_[0,k) such that they are a heap according to comparator. For
      // more info, see https://en.cppreference.com/w/cpp/algorithm/pop_heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indices, sorted so the corresponding values are in
  // decreasing order (ties broken by lower index first). This destroys the
  // heap invariant: call start_collecting() before pushing again.
  const std::vector<int32>& sorted_result() {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      // Note: due to the way we defined compare_fun (see comments for that
      // function) std::sort puts the indices from container_ in decreasing
      // order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      // container_ holds k_ + 1 entries: sort the heap prefix [0, k_) and
      // drop the scratch slot at the end.
      std::sort_heap(container_.begin(), container_.end() - 1, comparator);
      container_.resize(k_);
    }
    return container_;
  }

 private:
  const int32 k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far and are maintained in a min-heap order: container_.front() is
  // the index of the smallest of the top-k elements seen so far.
  //
  // container_[k] is used as temporary space (not part of the min-heap).
  std::vector<int32> container_;

  // Borrowed (non-owning) pointer to the row currently being scanned; set by
  // start_collecting().
  const T* values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // Intuitively, compare_fun(a, b) returns true iff values_[b] < values_[a]
  // (notice the inversion of direction, not a typo); ties (==) are broken in
  // favor of earlier elements (i.e., a < b).
  bool compare_fun(int32 a, int32 b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
173
174 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
175 template <typename T>
TopK(int32 row_size,int32 num_rows,const T * data,int32 k,int32 * output_indexes,T * output_values)176 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
177 int32* output_indexes, T* output_values) {
178 TopContainer<T> topc(k, row_size);
179 for (int row = 0; row < num_rows; ++row) {
180 const T* values_row = data + row * row_size;
181 topc.start_collecting(values_row);
182 for (int32 c = 0; c < row_size; ++c) {
183 topc.push(c);
184 }
185
186 // Prepare output buffers.
187 int32* indexes_row = output_indexes + row * k;
188 T* output_row = output_values + row * k;
189 // We always assume that the output is sorted.
190 const auto& top_k = topc.sorted_result();
191 std::copy(top_k.begin(), top_k.end(), indexes_row);
192 std::transform(top_k.begin(), top_k.end(), output_row,
193 [values_row](const int32 loc) { return values_row[loc]; });
194 }
195 }
196
197 } // namespace
198
Prepare(TfLiteContext * context,TfLiteNode * node)199 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
200 // Check that the inputs and outputs have the right sizes and types.
201 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
202 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
203
204 const TfLiteTensor* input;
205 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
206 TfLiteTensor* output_values;
207 TF_LITE_ENSURE_OK(
208 context, GetOutputSafe(context, node, kOutputValues, &output_values));
209 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
210
211 const TfLiteTensor* top_k;
212 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
213 TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
214
215 // Set output dynamic if the input is not const.
216 if (IsConstantTensor(top_k)) {
217 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
218 } else {
219 TfLiteTensor* output_indexes;
220 TF_LITE_ENSURE_OK(
221 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
222 TfLiteTensor* output_values;
223 TF_LITE_ENSURE_OK(
224 context, GetOutputSafe(context, node, kOutputValues, &output_values));
225 SetTensorToDynamic(output_indexes);
226 SetTensorToDynamic(output_values);
227 }
228 return kTfLiteOk;
229 }
230
Eval(TfLiteContext * context,TfLiteNode * node)231 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
232 TfLiteTensor* output_values;
233 TF_LITE_ENSURE_OK(
234 context, GetOutputSafe(context, node, kOutputValues, &output_values));
235 TfLiteTensor* output_indexes;
236 TF_LITE_ENSURE_OK(
237 context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
238 if (IsDynamicTensor(output_values)) {
239 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
240 }
241 const TfLiteTensor* top_k;
242 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
243 const int32 k = top_k->data.i32[0];
244 // The tensor can have more than 2 dimensions or even be a vector, the code
245 // anyway calls the internal dimension as row;
246 const TfLiteTensor* input;
247 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
248 const int32 row_size = input->dims->data[input->dims->size - 1];
249 int32 num_rows = 1;
250 for (int i = 0; i < input->dims->size - 1; ++i) {
251 num_rows *= input->dims->data[i];
252 }
253 switch (output_values->type) {
254 case kTfLiteFloat32:
255 TopK(row_size, num_rows, GetTensorData<float>(input), k,
256 output_indexes->data.i32, GetTensorData<float>(output_values));
257 break;
258 case kTfLiteUInt8:
259 TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
260 output_values->data.uint8);
261 break;
262 case kTfLiteInt8:
263 TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
264 output_values->data.int8);
265 break;
266 case kTfLiteInt32:
267 TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
268 output_values->data.i32);
269 break;
270 case kTfLiteInt64:
271 TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
272 output_values->data.i64);
273 break;
274 default:
275 TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK.",
276 TfLiteTypeGetName(output_values->type));
277 return kTfLiteError;
278 }
279
280 return kTfLiteOk;
281 }
282 } // namespace topk_v2
// Returns the registration record for the built-in TOPK_V2 operator. The
// kernel keeps no per-instance state, so the init and free slots are null.
TfLiteRegistration* Register_TOPK_V2() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 topk_v2::Prepare, topk_v2::Eval};
  return &r;
}
288 } // namespace builtin
289 } // namespace ops
290 } // namespace tflite
291