/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input,
          const int64 num_rows, const int64 num_cols,
          typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
      }

      return Status::OK();
    }

    auto SortIndices = [&, context](int start_batch, int limit_batch) {
      for (int32 b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32 a, const int32 b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32 a, const int32 b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
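          // Push every column index into a bounded heap that retains only
          // the k best entries under stable_comp (largest values first,
          // ties broken by lower index), then read the winners back either
          // in sorted order (Extract) or in unspecified order
          // (unsorted_begin/unsorted_end), depending on the "sorted" attr.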
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(&indices(b, 0), &indices(b, k), &values(b, 0),
                       [b, &input](const int32 loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#ifdef GOOGLE_CUDA

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);

#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA

}  // end namespace tensorflow