/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input,
          const int64 num_rows, const int64 num_cols,
          typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
        values(r, 0) = input(r, indices(r, 0));
      }

      return Status::OK();
    }

    auto SortIndices = [&](int64 start_batch, int64 limit_batch) {
      for (int32 b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32 a, const int32 b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32 a, const int32 b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted. First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
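          // gtl::TopN keeps at most k candidate indices in a bounded heap, so
          // pushing all num_cols column indices costs roughly
          // O(num_cols * log k) per row instead of a full sort. stable_comp
          // breaks ties by index, so equal values come back in ascending
          // column order; Extract() yields the retained indices in sorted
          // order, while unsorted_begin()/unsorted_end() skip that final sort.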
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(&indices(b, 0), &indices(b, k), &values(b, 0),
                       [b, &input](const int32 loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
    auto worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow