/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
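  // For each of the num_rows rows of the 2-D input, computes the k largest
  // values and their column indices, writing them into the pre-allocated
  // `values` and `indices` outputs.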
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
        values(r, 0) = input(r, indices(r, 0));
      }

      return Status::OK();
    }

    auto SortIndices = [&](int64 start_batch, int64 limit_batch) {
      for (int32 b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
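        // Orders column indices by descending value, breaking ties by
        // ascending index so equal values keep a deterministic, stable order.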
        const auto stable_comp = [input_data](const int32 a, const int32 b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
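        // Orders purely by descending value; cheaper than stable_comp, with
        // ties among equal values fixed up separately below.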
        const auto comp = [input_data](const int32 a, const int32 b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
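        // A rough, untested sketch of that alternative might look like:
        //   std::vector<int32> col_ids(num_cols);
        //   std::iota(col_ids.begin(), col_ids.end(), 0);
        //   std::partial_sort_copy(col_ids.begin(), col_ids.end(),
        //                          &indices(b, 0), &indices(b, k),
        //                          stable_comp);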
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(&indices(b, 0), &indices(b, k), &values(b, 0),
                       [b, &input](const int32 loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
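    // Shard the per-row work across the intra-op thread pool; final_cost is
    // the estimated cost of processing one row, which the sharder uses to
    // decide how many rows to hand to each thread.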
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow