/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

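// TopK returns the values and (int32) indices of the k largest entries along
// the last dimension of the input. For the TopK op, k is a compile-time attr;
// for TopKV2 it arrives as a scalar runtime input.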
template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

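    // Collapse all leading dimensions into one: the input is viewed as a
    // [num_rows, num_cols] matrix, and the top-k is taken along each row.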
    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
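    // k == 1 is a row-wise argmax: an Eigen reduction computes each row's
    // maximum, then a linear scan recovers the column index of that value.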
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
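      // Taking the first matching column keeps the lowest index on ties,
      // consistent with the stable comparator used in the general path below.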
      for (int r = 0; r < num_rows; ++r) {
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
      }

      return Status::OK();
    }

    auto SortIndices = [&, context](int start_batch, int limit_batch) {
      for (int32 b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
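        // Orders indices by descending value, breaking ties by ascending
        // index, so equal values keep their original relative order.
        // E.g. for row values {3, 1, 3, 2} and k = 3, the resulting index
        // order is {0, 2, 3}.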
        const auto stable_comp = [input_data](const int32 a, const int32 b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
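        // Non-stable variant: cheaper per comparison. The full-sort path
        // below restores ascending index order within equal-value runs.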
        const auto comp = [input_data](const int32 a, const int32 b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
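          // E.g. for row values {5, 7, 5, 7}, std::sort with comp may yield
          // indices {3, 1, 2, 0}; re-sorting each equal-value run gives the
          // stable result {1, 3, 0, 2}.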
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
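          // gtl::TopN retains only the k best candidates seen so far, so one
          // pass over the row costs roughly O(num_cols * log k).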
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(&indices(b, 0), &indices(b, k), &values(b, 0),
                       [b, &input](const int32 loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
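    // Shard distributes the num_rows independent rows across the CPU
    // threadpool, using final_cost as the estimated per-row work.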
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

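// For example, REGISTER_KERNELS(float) registers TopK<CPUDevice, float> as
// the CPU kernel for both the "TopK" and "TopKV2" ops with T = float.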
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#ifdef GOOGLE_CUDA

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)
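// HostMemory("k") pins the scalar k input to host memory, so Compute can read
// its value directly without a device-to-host copy.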

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);

#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA

}  // end namespace tensorflow