1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/c_api_internal.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace topk_v2 {
// Tensor indexes within the node: two inputs (the values tensor and the
// scalar k) and two outputs (the top-k values and their column indexes).
constexpr int kInputTensor = 0;
constexpr int kInputTopK = 1;
constexpr int kOutputValues = 0;
constexpr int kOutputIndexes = 1;
30
31 namespace {
ResizeOutput(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
33 const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
34 // INT32 number of top results is supported.
35 TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
36 // Check that the tensor contains only one value.
37 TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
38 const int32 k = *GetTensorData<int32_t>(top_k);
39
40 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
41 const int num_dimensions = NumDimensions(input);
42 // Check that input has one or more dimensions.
43 TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
44 "TopK k input must have 1 or more dimensions.");
45 // Check that k is less or equal the internal dimension.
46 TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
47 "TopK k is higher than the internal dimension.");
48
49 TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
50 TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
51 for (int i = 0; i < num_dimensions - 1; ++i) {
52 output_indexes_shape->data[i] = input->dims->data[i];
53 output_values_shape->data[i] = input->dims->data[i];
54 }
55 output_indexes_shape->data[num_dimensions - 1] = k;
56 output_values_shape->data[num_dimensions - 1] = k;
57 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
58 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
59 // Force output types.
60 output_indexes->type = kTfLiteInt32;
61 output_values->type = input->type;
62 auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
63 TfLiteIntArray* delete_on_error) {
64 TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
65 if (status != kTfLiteOk) {
66 if (delete_on_error != nullptr) {
67 TfLiteIntArrayFree(delete_on_error);
68 }
69 }
70 return status;
71 };
72 TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
73 output_values_shape));
74 TF_LITE_ENSURE_OK(context,
75 resize_tensor(output_values, output_values_shape, nullptr));
76 return kTfLiteOk;
77 }
78
79 // The class that collects top indexes of k values. Based on template
80 // tensorflow::gtl::TopN<> but, for optimization,
81 // it re-uses the same container.
82 template <typename T>
83 class TopContainer {
84 public:
85 TopContainer() = delete;
TopContainer(int32 k,int32 row_size)86 TopContainer(int32 k, int32 row_size) : k_(k) {
87 container_.reserve(std::min(k, row_size) + 1);
88 }
89
start_collecting(const T * values)90 void start_collecting(const T* values) {
91 values_ = values;
92 container_.clear();
93 }
push(int32 a)94 void push(int32 a) {
95 auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
96 if (container_.size() <= k_) {
97 container_.push_back(a);
98 if (container_.size() == k_ + 1) {
99 std::make_heap(container_.begin(), container_.end(), comparator);
100 std::pop_heap(container_.begin(), container_.end(), comparator);
101 }
102 } else if (comparator(a, container_.front())) {
103 container_.back() = a;
104 std::push_heap(container_.begin(), container_.end(), comparator);
105 std::pop_heap(container_.begin(), container_.end(), comparator);
106 }
107 }
108
sorted_result()109 const std::vector<int32>& sorted_result() {
110 auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); };
111 if (container_.size() <= k_) {
112 std::sort(container_.begin(), container_.end(), comparator);
113 } else {
114 std::sort_heap(container_.begin(), container_.end() - 1, comparator);
115 container_.resize(k_);
116 }
117 return container_;
118 }
119
120 private:
121 int32 k_;
122 std::vector<int32> container_;
123 const T* values_ = nullptr;
124
compare_fun(int32 a,int32 b) const125 bool compare_fun(int32 a, int32 b) const {
126 if (values_[b] < values_[a]) {
127 return true;
128 } else if (values_[b] > values_[a]) {
129 return false;
130 } else {
131 return a < b;
132 }
133 }
134 };
135
136 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
137 template <typename T>
TopK(int32 row_size,int32 num_rows,const T * data,int32 k,int32 * output_indexes,T * output_values)138 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
139 int32* output_indexes, T* output_values) {
140 TopContainer<T> topc(k, row_size);
141 for (int row = 0; row < num_rows; ++row) {
142 const T* values_row = data + row * row_size;
143 topc.start_collecting(values_row);
144 for (int32 c = 0; c < row_size; ++c) {
145 topc.push(c);
146 }
147
148 // Prepare output buffers.
149 int32* indexes_row = output_indexes + row * k;
150 T* output_row = output_values + row * k;
151 // We always assume that the output is sorted.
152 const auto& top_k = topc.sorted_result();
153 std::copy(top_k.begin(), top_k.end(), indexes_row);
154 std::transform(top_k.begin(), top_k.end(), output_row,
155 [values_row](const int32 loc) { return values_row[loc]; });
156 }
157 }
158
159 } // namespace
160
Prepare(TfLiteContext * context,TfLiteNode * node)161 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
162 // Check that the inputs and outputs have the right sizes and types.
163 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
164 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
165
166 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
167 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
168 TF_LITE_ENSURE_EQ(context, input->type, output_values->type);
169
170 const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
171 TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
172
173 // Set output dynamic if the input is not const.
174 if (IsConstantTensor(top_k)) {
175 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
176 } else {
177 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
178 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
179 SetTensorToDynamic(output_indexes);
180 SetTensorToDynamic(output_values);
181 }
182 return kTfLiteOk;
183 }
184
Eval(TfLiteContext * context,TfLiteNode * node)185 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
186 TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
187 TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
188 if (IsDynamicTensor(output_values)) {
189 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
190 }
191 const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
192 const int32 k = top_k->data.i32[0];
193 // The tensor can have more than 2 dimensions or even be a vector, the code
194 // anyway calls the internal dimension as row;
195 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
196 const int32 row_size = input->dims->data[input->dims->size - 1];
197 int32 num_rows = 1;
198 for (int i = 0; i < input->dims->size - 1; ++i) {
199 num_rows *= input->dims->data[i];
200 }
201 switch (output_values->type) {
202 case kTfLiteFloat32:
203 TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32,
204 output_values->data.f);
205 break;
206 case kTfLiteUInt8:
207 TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
208 output_values->data.uint8);
209 break;
210 case kTfLiteInt8:
211 TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
212 output_values->data.int8);
213 break;
214 case kTfLiteInt32:
215 TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
216 output_values->data.i32);
217 break;
218 case kTfLiteInt64:
219 TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
220 output_values->data.i64);
221 break;
222 default:
223 context->ReportError(context,
224 "Type %d is currently not supported by TopK.",
225 output_values->type);
226 return kTfLiteError;
227 }
228
229 return kTfLiteOk;
230 }
231 } // namespace topk_v2
Register_TOPK_V2()232 TfLiteRegistration* Register_TOPK_V2() {
233 static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
234 topk_v2::Eval};
235 return &r;
236 }
237 } // namespace builtin
238 } // namespace ops
239 } // namespace tflite
240