/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));
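
    // flat_inner_dims<T>() views the input as a [num_rows, num_cols] matrix,
    // collapsing every leading dimension into the first; e.g. a [2, 3, 4]
    // input is seen here as [6, 4], and top-k runs along the last dimension.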
    const auto& input = input_in.flat_inner_dims<T>();

    const int64_t num_rows = input.dimension(0);  // generally batch_size
    const int64_t num_cols = input.dimension(1);
    OP_REQUIRES(
        context, num_rows <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "First dimension of flattened input must be <= INT_MAX, got ",
            num_rows));
    OP_REQUIRES(
        context, num_cols <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "Second dimension of flattened input must be <= INT_MAX, got ",
            num_cols));

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status Compute(
      OpKernelContext* context, bool sorted, int k,
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows,
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,
      typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
        values(r, 0) = input(r, indices(r, 0));
      }

      return Status::OK();
    }

    auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) {
      for (int32_t b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32_t a,
                                              const int32_t b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32_t a, const int32_t b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
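        // A rough sketch of that alternative (hypothetical, not used here):
        //   std::vector<int32> cols(num_cols);
        //   std::iota(cols.begin(), cols.end(), 0);
        //   std::partial_sort_copy(cols.begin(), cols.end(), &indices(b, 0),
        //                          &indices(b, k), stable_comp);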
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted. First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32_t c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32_t i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(
            &indices(b, 0), &indices(b, k), &values(b, 0),
            [b, &input](const int32_t loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
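    // Clamp before converting back to an integer so a very large estimate
    // cannot overflow; Shard() below uses final_cost as the per-row cost when
    // splitting the num_rows units of work across threads.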
    const int64_t final_cost = (total_cost >= static_cast<double>(kint64max))
                                   ? kint64max
                                   : static_cast<int64>(total_cost);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  Status TopKFunctor<GPUDevice, T>::Compute(                                  \
      OpKernelContext* context, bool sorted, int k,                           \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,  \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,             \
      typename TTypes<int, 2>::Tensor indices);                               \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"),  \
      TopK<GPUDevice, type>)                                       \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                          \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<type>("T")           \
                              .HostMemory("k"),                    \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow