• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/c_api_internal.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace topk_v2 {
26 constexpr int kInputTensor = 0;
27 constexpr int kInputTopK = 1;
28 constexpr int kOutputValues = 0;
29 constexpr int kOutputIndexes = 1;
30 
31 namespace {
ResizeOutput(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
33   const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
34   // INT32 number of top results is supported.
35   TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
36   // Check that the tensor contains only one value.
37   TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
38   const int32 k = *GetTensorData<int32_t>(top_k);
39 
40   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
41   const int num_dimensions = NumDimensions(input);
42   // Check that input has one or more dimensions.
43   TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
44                      "TopK k input must have 1 or more dimensions.");
45   // Check that k is less or equal the internal dimension.
46   TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
47                      "TopK k is higher than the internal dimension.");
48 
49   TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
50   TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
51   for (int i = 0; i < num_dimensions - 1; ++i) {
52     output_indexes_shape->data[i] = input->dims->data[i];
53     output_values_shape->data[i] = input->dims->data[i];
54   }
55   output_indexes_shape->data[num_dimensions - 1] = k;
56   output_values_shape->data[num_dimensions - 1] = k;
57   TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
58   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
59   // Force output types.
60   output_indexes->type = kTfLiteInt32;
61   output_values->type = input->type;
62   auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
63                                  TfLiteIntArray* delete_on_error) {
64     TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
65     if (status != kTfLiteOk) {
66       if (delete_on_error != nullptr) {
67         TfLiteIntArrayFree(delete_on_error);
68       }
69     }
70     return status;
71   };
72   TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
73                                            output_values_shape));
74   TF_LITE_ENSURE_OK(context,
75                     resize_tensor(output_values, output_values_shape, nullptr));
76   return kTfLiteOk;
77 }
78 
// Collects the indexes of the top-k values of a row. Based on the template
// tensorflow::gtl::TopN<>, but re-uses one container across rows to avoid
// per-row allocations.
//
// Usage: start_collecting(row), push(col) for every column index, then
// sorted_result() to get the best-first index list. sorted_result() mutates
// the container, so it must be called once per collected row.
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // `k` is the number of results to keep; `row_size` bounds how many indexes
  // can ever be pushed, so reserve min(k, row_size) + 1 slots (the +1 is the
  // scratch slot used by the heap push/pop dance in push()).
  TopContainer(int32 k, int32 row_size) : k_(k) {
    container_.reserve(std::min(k, row_size) + 1);
  }

  // Points the comparator at a new row of values and resets the collection.
  // `values` must stay alive until the next start_collecting() call.
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
  }
  // Offers index `a` as a candidate. While fewer than k+1 candidates have
  // been seen, just append; on the (k+1)-th, heapify so that front() is the
  // *worst* of the current best and pop it to back(). Afterwards, a new
  // candidate that beats front() overwrites the parked back() slot and is
  // pushed through the heap, evicting the new worst to back().
  // NOTE(review): `container_.size()` (size_t) is compared against the
  // signed int32 `k_` — relies on k_ being non-negative.
  void push(int32 a) {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      container_.push_back(a);
      if (container_.size() == k_ + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
      }
    } else if (comparator(a, container_.front())) {
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
      std::pop_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indexes sorted best-first (descending value,
  // ascending index on ties). Destructive: sorts container_ in place and,
  // in the heap case, drops the extra scratch element.
  const std::vector<int32>& sorted_result() {
    auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
    if (container_.size() <= k_) {
      // Fewer than k+1 candidates were pushed: plain vector, plain sort.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      // First k elements form a heap; back() holds the last evicted index.
      std::sort_heap(container_.begin(), container_.end() - 1, comparator);
      container_.resize(k_);
    }
    return container_;
  }

 private:
  int32 k_;                          // number of results to keep
  std::vector<int32> container_;     // candidate indexes (heap when full)
  const T* values_ = nullptr;        // current row, set by start_collecting()

  // Strict-weak ordering on indexes: true when `a` ranks before `b`, i.e.
  // values_[a] is larger, or equal with the smaller index winning.
  bool compare_fun(int32 a, int32 b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
135 
136 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
137 template <typename T>
TopK(int32 row_size,int32 num_rows,const T * data,int32 k,int32 * output_indexes,T * output_values)138 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
139           int32* output_indexes, T* output_values) {
140   TopContainer<T> topc(k, row_size);
141   for (int row = 0; row < num_rows; ++row) {
142     const T* values_row = data + row * row_size;
143     topc.start_collecting(values_row);
144     for (int32 c = 0; c < row_size; ++c) {
145       topc.push(c);
146     }
147 
148     // Prepare output buffers.
149     int32* indexes_row = output_indexes + row * k;
150     T* output_row = output_values + row * k;
151     // We always assume that the output is sorted.
152     const auto& top_k = topc.sorted_result();
153     std::copy(top_k.begin(), top_k.end(), indexes_row);
154     std::transform(top_k.begin(), top_k.end(), output_row,
155                    [values_row](const int32 loc) { return values_row[loc]; });
156   }
157 }
158 
159 }  // namespace
160 
Prepare(TfLiteContext * context,TfLiteNode * node)161 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
162   // Check that the inputs and outputs have the right sizes and types.
163   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
164   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
165 
166   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
167   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
168   TF_LITE_ENSURE_EQ(context, input->type, output_values->type);
169 
170   const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
171   TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
172 
173   // Set output dynamic if the input is not const.
174   if (IsConstantTensor(top_k)) {
175     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
176   } else {
177     TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
178     TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
179     SetTensorToDynamic(output_indexes);
180     SetTensorToDynamic(output_values);
181   }
182   return kTfLiteOk;
183 }
184 
Eval(TfLiteContext * context,TfLiteNode * node)185 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
186   TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
187   TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
188   if (IsDynamicTensor(output_values)) {
189     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
190   }
191   const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
192   const int32 k = top_k->data.i32[0];
193   // The tensor can have more than 2 dimensions or even be a vector, the code
194   // anyway calls the internal dimension as row;
195   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
196   const int32 row_size = input->dims->data[input->dims->size - 1];
197   int32 num_rows = 1;
198   for (int i = 0; i < input->dims->size - 1; ++i) {
199     num_rows *= input->dims->data[i];
200   }
201   switch (output_values->type) {
202     case kTfLiteFloat32:
203       TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32,
204            output_values->data.f);
205       break;
206     case kTfLiteUInt8:
207       TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
208            output_values->data.uint8);
209       break;
210     case kTfLiteInt8:
211       TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
212            output_values->data.int8);
213       break;
214     case kTfLiteInt32:
215       TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
216            output_values->data.i32);
217       break;
218     case kTfLiteInt64:
219       TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
220            output_values->data.i64);
221       break;
222     default:
223       context->ReportError(context,
224                            "Type %d is currently not supported by TopK.",
225                            output_values->type);
226       return kTfLiteError;
227   }
228 
229   return kTfLiteOk;
230 }
231 }  // namespace topk_v2
Register_TOPK_V2()232 TfLiteRegistration* Register_TOPK_V2() {
233   static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
234                                  topk_v2::Eval};
235   return &r;
236 }
237 }  // namespace builtin
238 }  // namespace ops
239 }  // namespace tflite
240