/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <cmath>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };
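// HeapType selects the comparison direction used by IndexedHeap below;
// PreferIndices controls how ties between equal values are broken (kLower
// favors the entry with the smaller original index, kHigher the larger one).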

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};
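// Note on StridedData: it interleaves the per-thread heaps kept in shared
// memory, so thread t's i-th entry lives at data[i * blockDim.x + t] and
// consecutive threads touch consecutive addresses for the same heap slot.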

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;
  __device__ IndexedHeap(const Data<T>& d) : data(d) {}

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAP in Cormen
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // in-place HEAPSORT in Cormen
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root, but the extracted root is kept at `slot`.
      swap(slot, 0);
      // The heap is now one element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}

// heapTopK walks over [input, input+length) with `step_size` stride starting at
// `start_index`.
// It builds a top-`k` heap that is stored in `heap_entries`, using the `Data`
// accessor template to access elements in `heap_entries`. If sorted=true, the
// elements will be sorted at the end.
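// Example (hypothetical values): with length = 8, k = 3, start_index = 1, and
// step_size = 2, a thread scans indices 1, 3, 5, 7, seeds the min-heap with
// the entries at 1, 3, 5, and then only replaces the heap's minimum when a
// later value is strictly larger.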
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and restore the heap property
  // by pushing it down.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices. Because the comparison below is
    // strict, an element equal to the current minimum is discarded, and since
    // later elements always have higher indices, this keeps the lower index.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
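// Example (hypothetical values): with num_shards = 2 and per-shard top-3 lists
// s_1 = [9, 7, 4] and s_2 = [8, 6, 5], `entries` holds the values
// [9, 8, 7, 6, 4, 5], and the merge emits 9, 8, 7 for k = 3.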
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-heap part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max heap.
    max_heap.build(heap_size);

    // Now extract the maximum k-1 times.
    // The last element (rank k-1) is handled separately below, since no
    // replacement has to be drawn for it.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top k heap still contains at least 1 element,
      // so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}

#if GOOGLE_CUDA
extern __shared__ char shared_memory[];
#endif  // GOOGLE_CUDA

template <typename T>
#if TENSORFLOW_USE_ROCM
__attribute__((amdgpu_flat_work_group_size(1, 256)))
#endif  // TENSORFLOW_USE_ROCM
__global__ void
TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
           T* __restrict__ output, int* __restrict__ indices) {
#if TENSORFLOW_USE_ROCM
  HIP_DYNAMIC_SHARED(char, shared_memory);
#endif  // TENSORFLOW_USE_ROCM

  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
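    // Shared memory layout (sized as (num_shards + 1) * k entries by
    // LaunchTopKKernel): the first thread_count * k entries hold the
    // interleaved per-thread heaps, and the final k entries serve as scratch
    // space for the merge heap used by mergeShards below.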
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather than
    // just one.  ModernGPU has some nice primitives that could help with this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB).  In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The calculation is:
  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).
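  // Worked example of that bound: with 48KB = 49152 bytes of shared memory,
  // T=float gives 49152 / (2 * (4 + 4)) = 3072 and T=double gives
  // 49152 / (2 * (4 + 8)) = 2048, matching the limits stated above.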

  // Use as many shards as possible.
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // Solving shared_memory_size = (num_shards + 1) * heap_size for num_shards:
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
#if GOOGLE_CUDA
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
#elif TENSORFLOW_USE_ROCM
    // ROCm cannot execute with 1024 threads per block and requires an explicit
    // amdgpu_flat_work_group_size attribute for work-group sizes above 256.
    } else if (num_shards > 256) {
      num_shards = 256;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TF_CHECK_OK(GpuLaunchKernel(TopKKernel<T>, batch_size, num_shards,
                              shared_memory_size, stream, input, length, k,
                              sorted, output, indices));
  return cudaGetLastError();
}

struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};
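// For example (hypothetical shapes), with num_cols = 4: SegmentOffsetCreator
// maps row indices 0, 1, 2, ... to the row-start offsets 0, 4, 8, ... that
// bound the segments of the segmented sort below, while ColumnIndexCreator
// maps flat indices 0..7 to the within-row positions 0, 1, 2, 3, 0, 1, 2, 3
// that become the reported indices.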

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const auto& cu_stream = GetGpuStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once gpuprim supports iterators for ValueT replace that
  // tensor with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  gpuprim::CountingInputIterator<int> counting_iter(0);
  gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
                                  gpuprim::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort, no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

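  // The first call passes d_temp_storage == nullptr and only computes the
  // required temp_storage_bytes; the actual sort happens in the second call.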
  auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ nullptr,
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ temp_storage.flat<int8>().data(),
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_outputs to
    // indices and outputs.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k (and large num_cols), use the heap implementation.  For
    // larger k or small num_cols, use the segmented gpuprim sort.  For
    // k == num_cols, always use the gpuprim sort.  The thresholds for n and k
    // were determined empirically.
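    // For example (hypothetical shapes): num_cols = 100000 with k = 10 takes
    // the heap path, while k = 500, or a full sort with k == num_cols, takes
    // the sort path.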
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const auto& cu_stream = GetGpuStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_