• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <algorithm>
18 #include <iterator>
19 #include <vector>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace builtin {
30 namespace topk_v2 {
// Tensor indices into the node's input/output arrays.
constexpr int kInputTensor = 0;   // Values to take the top-k of.
constexpr int kInputTopK = 1;     // Scalar k; must be INT32 (checked in Prepare).
constexpr int kOutputValues = 0;  // Top-k values per row, in decreasing order.
constexpr int kOutputIndexes = 1;  // INT32 indices of those values within the row.
35 
36 namespace {
ResizeOutput(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
38   const TfLiteTensor* top_k;
39   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
40   // INT32 number of top results is supported.
41   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
42   // Check that the tensor contains only one value.
43   TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
44   const int32 k = *GetTensorData<int32_t>(top_k);
45 
46   const TfLiteTensor* input;
47   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
48   const int num_dimensions = NumDimensions(input);
49   // Check that input has one or more dimensions.
50   TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
51                      "TopK k input must have 1 or more dimensions.");
52   // Check that k is less or equal the internal dimension.
53   TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
54                      "TopK k is higher than the internal dimension.");
55 
56   TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
57   TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
58   for (int i = 0; i < num_dimensions - 1; ++i) {
59     output_indexes_shape->data[i] = input->dims->data[i];
60     output_values_shape->data[i] = input->dims->data[i];
61   }
62   output_indexes_shape->data[num_dimensions - 1] = k;
63   output_values_shape->data[num_dimensions - 1] = k;
64   TfLiteTensor* output_indexes;
65   TF_LITE_ENSURE_OK(
66       context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
67   TfLiteTensor* output_values;
68   TF_LITE_ENSURE_OK(
69       context, GetOutputSafe(context, node, kOutputValues, &output_values));
70   // Force output types.
71   output_indexes->type = kTfLiteInt32;
72   output_values->type = input->type;
73   auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
74                                  TfLiteIntArray* delete_on_error) {
75     TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
76     if (status != kTfLiteOk) {
77       if (delete_on_error != nullptr) {
78         TfLiteIntArrayFree(delete_on_error);
79       }
80     }
81     return status;
82   };
83   TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
84                                            output_values_shape));
85   TF_LITE_ENSURE_OK(context,
86                     resize_tensor(output_values, output_values_shape, nullptr));
87   return kTfLiteOk;
88 }
89 
90 // Class that collects indices of top k values.  Based on template
91 // tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
// Class that collects indices of top k values.  Based on template
// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
//
// Usage per row: start_collecting(row), push(0..row_size), sorted_result().
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // Reserves space for k indices plus one scratch slot (container_[k]) used
  // by push() while re-heapifying; row_size caps the reservation when k is
  // larger than the row.
  TopContainer(int32 k, int32 row_size) : k_(k) {
    container_.reserve(std::min(k, row_size) + 1);
  }

  // Starts a fresh collection pass over `values` (the caller keeps ownership;
  // the pointer must stay valid until sorted_result() is consumed).  Any
  // indices from a previous pass are discarded; capacity is retained.
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
  }

  // Offers index `a` as a candidate top-k element.  The first k+1 pushes just
  // append; on the (k+1)-th the container is turned into a min-heap over
  // container_[0,k) (smallest top-k element at front) with container_[k] as
  // scratch.  Later pushes replace the current minimum only when values_[a]
  // beats it under compare_fun.
  void push(int32 a) {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      container_.push_back(a);
      if (container_.size() == k_ + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
      }
    } else if (comparator(a, container_.front())) {
      // Due to how we defined comparator / compare_fun, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // If control reaches this point, we know that the current index a
      // corresponds to an element which is bigger than the smallest of the
      // top-k elements seen so far.  Hence, we have to update the indices of
      // the top-k elements, by removing the index of the smallest top-k
      // element, adding a, and making sure container_[0:k] is still a heap.

      // Store index a into container_[k].
      container_.back() = a;

      // Swap container_[0] and container_[k], and rearrange elements from
      // container_[0,k) such that they are a heap according to comparator.  For
      // more info, see https://en.cppreference.com/w/cpp/algorithm/pop_heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indices ordered so the corresponding values are
  // decreasing (ties broken by smaller index first; see compare_fun).  This
  // destructively reorders / truncates container_, so call it once per
  // collection pass, after all pushes.
  const std::vector<int32>& sorted_result() {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      // Note: due to the way we defined compare_fun (see comments for that
      // function) std::sort puts the indices from container_ in decreasing
      // order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      // Heap case: sort container_[0,k) (sort_heap excludes the scratch slot)
      // and drop the scratch element.
      std::sort_heap(container_.begin(), container_.end() - 1, comparator);
      container_.resize(k_);
    }
    return container_;
  }

 private:
  const int32 k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far and are maintained in a min-heap order: container_.front() is
  // the index of the smallest of the top-k elements see so far.
  //
  // container_[k] is used as temporary space (not part of the min-heap).
  std::vector<int32> container_;

  // Current row being collected; set by start_collecting(), not owned.
  const T* values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // Intuitively, compare_fun(a, b) returns true iff values_[b] < values_[a]
  // (notice the inversion of direction, not a typo); ties (==) are broken in
  // favor of earlier elements (i.e., a < b).
  bool compare_fun(int32 a, int32 b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
173 
174 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
175 template <typename T>
TopK(int32 row_size,int32 num_rows,const T * data,int32 k,int32 * output_indexes,T * output_values)176 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
177           int32* output_indexes, T* output_values) {
178   TopContainer<T> topc(k, row_size);
179   for (int row = 0; row < num_rows; ++row) {
180     const T* values_row = data + row * row_size;
181     topc.start_collecting(values_row);
182     for (int32 c = 0; c < row_size; ++c) {
183       topc.push(c);
184     }
185 
186     // Prepare output buffers.
187     int32* indexes_row = output_indexes + row * k;
188     T* output_row = output_values + row * k;
189     // We always assume that the output is sorted.
190     const auto& top_k = topc.sorted_result();
191     std::copy(top_k.begin(), top_k.end(), indexes_row);
192     std::transform(top_k.begin(), top_k.end(), output_row,
193                    [values_row](const int32 loc) { return values_row[loc]; });
194   }
195 }
196 
197 }  // namespace
198 
Prepare(TfLiteContext * context,TfLiteNode * node)199 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
200   // Check that the inputs and outputs have the right sizes and types.
201   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
202   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
203 
204   const TfLiteTensor* input;
205   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
206   TfLiteTensor* output_values;
207   TF_LITE_ENSURE_OK(
208       context, GetOutputSafe(context, node, kOutputValues, &output_values));
209   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
210 
211   const TfLiteTensor* top_k;
212   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
213   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
214 
215   // Set output dynamic if the input is not const.
216   if (IsConstantTensor(top_k)) {
217     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
218   } else {
219     TfLiteTensor* output_indexes;
220     TF_LITE_ENSURE_OK(
221         context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
222     TfLiteTensor* output_values;
223     TF_LITE_ENSURE_OK(
224         context, GetOutputSafe(context, node, kOutputValues, &output_values));
225     SetTensorToDynamic(output_indexes);
226     SetTensorToDynamic(output_values);
227   }
228   return kTfLiteOk;
229 }
230 
Eval(TfLiteContext * context,TfLiteNode * node)231 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
232   TfLiteTensor* output_values;
233   TF_LITE_ENSURE_OK(
234       context, GetOutputSafe(context, node, kOutputValues, &output_values));
235   TfLiteTensor* output_indexes;
236   TF_LITE_ENSURE_OK(
237       context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
238   if (IsDynamicTensor(output_values)) {
239     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
240   }
241   const TfLiteTensor* top_k;
242   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
243   const int32 k = top_k->data.i32[0];
244   // The tensor can have more than 2 dimensions or even be a vector, the code
245   // anyway calls the internal dimension as row;
246   const TfLiteTensor* input;
247   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
248   const int32 row_size = input->dims->data[input->dims->size - 1];
249   int32 num_rows = 1;
250   for (int i = 0; i < input->dims->size - 1; ++i) {
251     num_rows *= input->dims->data[i];
252   }
253   switch (output_values->type) {
254     case kTfLiteFloat32:
255       TopK(row_size, num_rows, GetTensorData<float>(input), k,
256            output_indexes->data.i32, GetTensorData<float>(output_values));
257       break;
258     case kTfLiteUInt8:
259       TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
260            output_values->data.uint8);
261       break;
262     case kTfLiteInt8:
263       TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
264            output_values->data.int8);
265       break;
266     case kTfLiteInt32:
267       TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
268            output_values->data.i32);
269       break;
270     case kTfLiteInt64:
271       TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
272            output_values->data.i64);
273       break;
274     default:
275       TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK.",
276                          TfLiteTypeGetName(output_values->type));
277       return kTfLiteError;
278   }
279 
280   return kTfLiteOk;
281 }
282 }  // namespace topk_v2
Register_TOPK_V2()283 TfLiteRegistration* Register_TOPK_V2() {
284   static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
285                                  topk_v2::Eval};
286   return &r;
287 }
288 }  // namespace builtin
289 }  // namespace ops
290 }  // namespace tflite
291