/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <cmath>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

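// HeapType selects whether IndexedHeap below orders as a min-heap or a
// max-heap; PreferIndices controls which index wins when two values compare
// equal.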
enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

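// Per-thread heap data interleaved in one shared buffer: logical slot `i` of
// thread `threadIdx.x` is stored at data[i * blockDim.x + threadIdx.x].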
template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;
  __device__ IndexedHeap(const Data<T>& d) : data(d) {}

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAP in Cormen
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // in-place HEAPSORT in Cormen
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root but we insert the element at the end.
      swap(slot, 0);
      // Heap is now an element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}

// heapTopK walks over [input, input+length) with stride `step_size`, starting
// at `start_index`.
// It builds a top-`k` heap that is stored in `heap_entries`, using `Data` to
// access elements in `heap_entries`. If sorted=true, the elements are sorted
// at the end.
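// For example, TopKKernel below has each thread build its own heap over a
// strided slice of a batch row:
//   heapTopK<T, StridedData>(batch_input, length, k, shared_entries,
//                            /*sorted=*/true, threadIdx.x, blockDim.x);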
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices, which is automatic here: later
    // elements have higher indices, so ties can simply be discarded.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are stored in `entries` in a strided way:
//  |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
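// For example, with num_shards == 4, entries[0..3] hold the first (largest)
// element of each shard, entries[4..7] the second-largest of each shard, and
// so on.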
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-heap part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max-heap.
    max_heap.build(heap_size);

    // Now extract the maximum k-1 times.
    // The last element (rank k - 1) is handled specially after the loop.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top k heap still contains at least 1 element,
      // so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}

#if GOOGLE_CUDA
extern __shared__ char shared_memory[];
#endif  // GOOGLE_CUDA

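// Each block handles one batch row. The launcher below provides
// (num_shards + 1) * k * sizeof(Entry<T>) bytes of dynamic shared memory,
// where num_shards == blockDim.x: one k-entry heap per thread plus scratch
// space for the merge heap.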
template <typename T>
#if TENSORFLOW_USE_ROCM
__attribute__((amdgpu_flat_work_group_size(1, 256)))
#endif  // TENSORFLOW_USE_ROCM
__global__ void
TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
           T* __restrict__ output, int* __restrict__ indices) {
#if TENSORFLOW_USE_ROCM
  HIP_DYNAMIC_SHARED(char, shared_memory);
#endif  // TENSORFLOW_USE_ROCM

  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather
    // than just one. ModernGPU has some nice primitives that could help with
    // this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB). In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The calculation is:
  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).
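  //   (e.g. for T = float: 48 * 1024 / (2 * (4 + 4)) = 3072).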

  // Use as many shards as possible.
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // shared_memory_size = (num_shards + 1) * heap_size <=>
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
#if GOOGLE_CUDA
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
#elif TENSORFLOW_USE_ROCM
    // ROCm can't execute 1024 threads per block here and would require an
    // explicit amdgpu_flat_work_group_size attribute to go above 256.
    } else if (num_shards > 256) {
      num_shards = 256;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
                              shared_memory_size, stream, input, length, k,
                              sorted, output, indices));
  return cudaGetLastError();
}

struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

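// Maps a flat element index to its column index within a row; used below to
// build the per-element index tensor that is sorted alongside the values.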
struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const auto& cu_stream = GetGpuStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once gpuprim supports iterators for ValueT replace that
  // tensor with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

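  // segment_offsets_t[i] == i * num_cols, i.e. the row boundaries used as
  // segment offsets for the segmented radix sort below.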
  gpuprim::CountingInputIterator<int> counting_iter(0);
  gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
                                  gpuprim::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort, no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

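  // First call with d_temp_storage == nullptr only computes the required
  // temp_storage_bytes; the sort itself runs in the second call below.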
  auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ nullptr,
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ temp_storage.flat<int8>().data(),
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_outputs to
    // indices and outputs.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation. For larger k, use
    // the in-place gpuprim sort. For k == num_cols, always use the
    // in-place gpuprim sort. The thresholds for n and k were determined
    // empirically.
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const auto& cu_stream = GetGpuStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_