/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64_t num_rows = input.dimension(0);  // generally batch_size
    const int64_t num_cols = input.dimension(1);
    OP_REQUIRES(
        context, num_rows <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "First dimension of flattened input must be <= INT_MAX, got ",
            num_rows));
    OP_REQUIRES(
        context, num_cols <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "Second dimension of flattened input must be <= INT_MAX, got ",
            num_cols));

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status Compute(
      OpKernelContext* context, bool sorted, int k,
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows,
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,
      typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
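    // Illustrative example (not part of the kernel logic): for an input row
    // {3, 5, 5}, this path produces values(r, 0) == 5 and indices(r, 0) == 1,
    // i.e. the first occurrence of the row maximum wins ties.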
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
        values(r, 0) = input(r, indices(r, 0));
      }

      return Status::OK();
    }

    auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) {
      for (int32_t b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32_t a,
                                              const int32_t b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32_t a, const int32_t b) {
          return input_data[b] < input_data[a];
        };
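        // Illustrative example: for input_data == {2, 7, 7, 1}, sorting the
        // indices {0, 1, 2, 3} with stable_comp yields {1, 2, 0, 3}:
        // descending by value, with ties broken by the smaller index.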
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
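        // A minimal sketch of that alternative (illustrative only, not used
        // here; `scratch` is a hypothetical per-batch buffer; stable_comp
        // keeps tie-breaking deterministic):
        //   std::vector<int32> scratch(num_cols);
        //   std::iota(scratch.begin(), scratch.end(), 0);
        //   std::partial_sort_copy(scratch.begin(), scratch.end(),
        //                          &indices(b, 0), &indices(b, k),
        //                          stable_comp);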
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
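          // E.g. (illustrative), for values {5, 5, 3} std::sort may leave the
          // tied indices as {1, 0, 2}; re-sorting the run of equal values
          // restores the stable order {0, 1, 2}.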
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32_t c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32_t i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(
            &indices(b, 0), &indices(b, k), &values(b, 0),
            [b, &input](const int32_t loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
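    // E.g. (illustrative), with num_cols == 1024 and k == 15, base_cost is
    // cmp_cost * 1024 * log2(16) == cmp_cost * 4096, and sort_cost is
    // 4 * base_cost because k != num_cols.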
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64_t final_cost = (total_cost >= static_cast<double>(kint64max))
                                   ? kint64max
                                   : static_cast<int64>(total_cost);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

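// Illustrative semantics of the registered ops (example only): with
// sorted == true, input == {{1, 3, 2}} and k == 2, TopK yields
// values == {{3, 2}} and indices == {{1, 2}}.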
#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
      TopK<GPUDevice, type>);                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                          \
                              .Device(DEVICE_GPU)                 \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("k"),                   \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow