• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 #define EIGEN_USE_GPU
18 #include "tensorflow/core/kernels/image/non_max_suppression_op.h"
19 
20 #include <limits>
21 
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/numeric_types.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/kernels/gpu_prim.h"
28 #include "tensorflow/core/util/gpu_kernel_helper.h"
29 #include "tensorflow/core/util/gpu_launch_config.h"
30 #include "tensorflow/stream_executor/stream_executor.h"
31 
// Axis-aligned box described by two corner points, (x1, y1) and (x2, y2).
// On CUDA builds the struct is declared 16-byte aligned so that a whole box
// can presumably be fetched with a single aligned access; NmsGpu rejects input
// pointers that do not honour this alignment before reinterpreting them.
struct
#if GOOGLE_CUDA
    __align__(16)
#endif
        Box {
  float x1;  // x of the first corner
  float y1;  // y of the first corner
  float x2;  // x of the second corner
  float y2;  // y of the second corner
};
39 
40 namespace tensorflow {
// Conventional TensorFlow aliases for the Eigen device types.
using GPUDevice = Eigen::GpuDevice;
using CPUDevice = Eigen::ThreadPoolDevice;
43 
// Number of candidate boxes an NMSKernel thread compares per step. This is
// also the width in bits of one word of the overlap bitmask (one bit per
// box). It must be a power of two (the bit width of a POD type usually is) so
// that division and modulo by it can be done with a shift and a mask during
// selection.
constexpr int kNmsBoxesPerThread = 8 * sizeof(int);

// Number of bits needed to represent n, e.g. NumBits(31) == 5 and
// NumBits(32) == 6. Returns 0 for n <= 0. Non-positive inputs would otherwise
// recurse forever, since an arithmetic right shift of a negative value never
// reaches 0.
constexpr int NumBits(int n) { return (n <= 0) ? 0 : NumBits(n >> 1) + 1; }

// Mask and shift derived from kNmsBoxesPerThread. For a width of 32 the mask
// is 31 (0x1F) and the shift is 5, so i % 32 == (i & 31) and
// i / 32 == (i >> 5). These bit operations avoid integer division and modulo.
// The identities only hold for a power-of-two width, which is enforced below.
constexpr int kNmsBoxesPerThreadModuloMask = kNmsBoxesPerThread - 1;
constexpr int kNmsBoxesPerThreadShiftBits =
    NumBits(kNmsBoxesPerThreadModuloMask);
static_assert(kNmsBoxesPerThread > 0 &&
                  (kNmsBoxesPerThread & kNmsBoxesPerThreadModuloMask) == 0,
              "kNmsBoxesPerThread must be a power of two");
static_assert((1 << kNmsBoxesPerThreadShiftBits) == kNmsBoxesPerThread,
              "shift bits are inconsistent with kNmsBoxesPerThread");
57 
// Launch geometry for NMSKernel. Thread blocks are kNmsBlockDim x kNmsBlockDim
// threads, and each grid dimension is clamped to at most kNmsBlockDimMax blocks.
constexpr int kNmsBlockDim{16};
constexpr int kNmsBlockDimMax{128};
// Chunk size; not referenced by code in this part of the file.
// NOTE(review): confirm whether it is still used elsewhere.
constexpr int kNmsChunkSize{2000};
61 
// Device-side exchange of two values (a stand-in for std::swap).
template <typename T>
__device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) {
  const T saved_a(a);
  a = b;
  b = saved_a;
}
68 
// Returns true when the IoU (intersection over union) of boxes a and b is
// strictly greater than iou_threshold.
//
// a_area is the caller-precomputed area of a; b's area is computed here. Both
// boxes are expected to satisfy x1 <= x2 and y1 <= y2. If either area is
// exactly zero, the pair is never reported as overlapping.
//
// The test IoU > t is evaluated as intersection > union * t, which avoids a
// division.
template <typename T>
__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b,
                                                  const float a_area,
                                                  const T iou_threshold) {
  const float b_area = (b->x2 - b->x1) * (b->y2 - b->y1);
  if (a_area == 0.0f || b_area == 0.0f) {
    return false;
  }
  // fdimf(p, q) is the positive difference: p - q when p > q, otherwise 0.
  // Disjoint boxes therefore get a zero overlap width or height.
  const float overlap_w = fdimf(fminf(a->x2, b->x2), fmaxf(a->x1, b->x1));
  const float overlap_h = fdimf(fminf(a->y2, b->y2), fmaxf(a->y1, b->y1));
  const float inter_area = overlap_w * overlap_h;
  const float union_area = a_area + b_area - inter_area;
  const float scaled_union = union_area * iou_threshold;
  return inter_area > scaled_union;
}
94 
// Optionally canonicalizes a box so that x1 <= x2 and y1 <= y2.
template <bool flip_box>
__device__ EIGEN_STRONG_INLINE void Flipped(Box& box);

// No-op: boxes are trusted to already be in canonical corner order.
template <>
__device__ EIGEN_STRONG_INLINE void Flipped<false>(Box& /*box*/) {}

// Swaps any coordinate pair that is given in reversed order.
template <>
__device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
  if (box.x2 < box.x1) Swap(box.x2, box.x1);
  if (box.y2 < box.y1) Swap(box.y2, box.y1);
}
// Reads bit number `bit` of a packed bitmask stored in words of type T. Bits
// are numbered from the least significant bit of bit_mask[0] upward: word
// bit / (8 * sizeof(T)), position bit % (8 * sizeof(T)). The word width is a
// power of two, so the division and modulo become a shift and a mask.
template <typename T>
__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
  constexpr int kBitsPerWord = 8 * sizeof(T);
  constexpr int kWordShift = NumBits(kBitsPerWord) - 1;  // log2(kBitsPerWord)
  const T word = bit_mask[bit >> kWordShift];
  return (word >> (bit & (kBitsPerWord - 1))) & 1;
}
113 
// Reduces the pairwise overlap bitmask produced by NMSKernel into a keep/drop
// flag per box (greedy NMS).
//
// bitmask holds num_boxes rows of bit_mask_len ints. As written by NMSKernel,
// bit j of row i is set when box j (j > i) overlaps box i above the IoU
// threshold. Boxes are visited in index order (the caller presumably passes
// them sorted by descending score). Every box that is still unmasked is
// accepted, and its row is cleared out of the running mask.
//
// The scan stops early once more than max_boxes boxes have been accepted.
// Boxes after that point keep whatever state the running mask had, so only
// the leading selections are exact; callers truncate to max_boxes.
//
// result_mask[i] is set to 1 for kept boxes and 0 otherwise. A char array is
// used because gpuprim::DeviceSelect::Flagged consumes it afterwards.
//
// Launch requirements: exactly one thread block, because the running mask
// lives in that block's shared memory, and bit_mask_len * sizeof(int) bytes of
// dynamic shared memory.
__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
                          const int num_boxes, const int max_boxes,
                          char* result_mask) {
  extern __shared__ int local[];
  // Start with every box accepted.
  for (int box : GpuGridRangeX(bit_mask_len)) {
    local[box] = 0xFFFFFFFF;
  }
  __syncthreads();
  int accepted_boxes = 0;
  // The last box is not visited: its row can only mask later boxes, and there
  // are none. All threads read the same shared bit and therefore take the same
  // path, so the __syncthreads() below is reached uniformly.
  for (int box = 0; box < num_boxes - 1; ++box) {
    // If the current box is masked by an earlier accepted box, skip it.
    if (!CheckBit(local, box)) {
      continue;
    }
    accepted_boxes += 1;
    // 64-bit offset: box * bit_mask_len can exceed INT_MAX for large inputs.
    const int64_t offset = static_cast<int64_t>(box) * bit_mask_len;
    // Remove every box overlapped by the current box from the running mask.
    for (int b : GpuGridRangeX(bit_mask_len)) {
      local[b] &= ~bitmask[offset + b];
    }
    __syncthreads();
    if (accepted_boxes > max_boxes) break;
  }
  // Copy the running mask into the result_mask char array.
  for (int box : GpuGridRangeX(num_boxes)) {
    result_mask[box] = CheckBit(local, box);
  }
}
147 
// For each box i, computes a bitmask of the later boxes j > i whose IoU with
// box i exceeds iou_threshold.
//
// Row i of d_delete_mask has bit_mask_len ints. Bit (j % kNmsBoxesPerThread)
// of word (j / kNmsBoxesPerThread) is set when box j overlaps box i above the
// threshold. Only words containing at least one in-range pair j > i are
// written, so the caller must zero-initialize the mask.
//
// Launch layout: 2-D blocks with blockDim.x <= kNmsBlockDim; threadIdx.x
// indexes the "i" boxes cached in shared memory. Both grid dimensions are
// covered by grid-stride loops, and each thread handles kNmsBoxesPerThread
// consecutive "j" boxes per step. __launch_bounds__ caps a block at
// kNmsBlockDim * kNmsBlockDim threads.
//
// If flip_box is true, boxes may have x1 > x2 and/or y1 > y2; their
// coordinates are swapped so that x1 <= x2 and y1 <= y2 before comparing.
// Otherwise the boxes must already satisfy x1 <= x2 and y1 <= y2.
template <bool flip_box>
__launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
    void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
                   const float iou_threshold, const int bit_mask_len,
                   int* d_delete_mask) {
  // The "i" boxes handled by this block, staged in shared memory.
  __shared__ Box shared_i_boxes[kNmsBlockDim];
  // The areas of the staged "i" boxes.
  __shared__ float shared_i_areas[kNmsBlockDim];
  // The loop condition depends only on block-uniform values, so every thread
  // in the block runs the same iterations. This is required for the
  // __syncthreads() calls inside the loop.
  for (int i_block_offset = blockIdx.x * blockDim.x; i_block_offset < num_boxes;
       i_block_offset += blockDim.x * gridDim.x) {
    const int i = i_block_offset + threadIdx.x;
    // One row of threads (threadIdx.y == 0) loads the boxes for the x-dimension.
    if (i < num_boxes && threadIdx.y == 0) {
      Box box = d_desc_sorted_boxes[i];
      Flipped<flip_box>(box);
      shared_i_boxes[threadIdx.x] = box;
      shared_i_areas[threadIdx.x] = (box.x2 - box.x1) * (box.y2 - box.y1);
    }
    __syncthreads();
    for (int j_thread_offset =
             kNmsBoxesPerThread * (blockIdx.y * blockDim.y + threadIdx.y);
         j_thread_offset < num_boxes;
         j_thread_offset += kNmsBoxesPerThread * blockDim.y * gridDim.y) {
      // Note: everything could be done with multiplication only, and even in
      // fp16, since we compare against a low-precision threshold.
      int above_threshold = 0;
      // Set once at least one (i, j) pair in this word is in range.
      bool valid = false;
      // Loop over the next kNmsBoxesPerThread boxes and set the corresponding
      // bit if the box overlaps the current box.
      for (int ib = 0; ib < kNmsBoxesPerThread; ++ib) {
        // This thread compares box i with box j.
        const int j = j_thread_offset + ib;
        if (i >= j || i >= num_boxes || j >= num_boxes) continue;
        valid = true;
        Box j_box = d_desc_sorted_boxes[j];
        const Box i_box = shared_i_boxes[threadIdx.x];
        Flipped<flip_box>(j_box);
        if (OverThreshold<float>(&i_box, &j_box, shared_i_areas[threadIdx.x],
                                 iou_threshold)) {
          // j sorts after i, so score[j] <= score[i]; flag j for suppression.
          above_threshold |= (1U << ib);
        }
      }
      if (valid) {
        // 64-bit row offset: i * bit_mask_len can exceed INT_MAX for large
        // inputs.
        const int64_t row_offset = static_cast<int64_t>(i) * bit_mask_len;
        d_delete_mask[row_offset + j_thread_offset / kNmsBoxesPerThread] =
            above_threshold;
      }
    }
    __syncthreads();  // Make sure everyone is done reading shared memory.
  }
}
// Recursive helpers behind IndexMultiSelect. For every (original, selected)
// pointer pair in the argument list, copy original[i_original] into
// selected[i_selected]. The overload without array pairs ends the recursion.
template <typename Index>
__device__ EIGEN_STRONG_INLINE void SelectHelper(const Index /*i_selected*/,
                                                 const Index /*i_original*/) {}

template <typename Index, typename T, typename... Args>
__device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
                                                 const Index i_original,
                                                 const T* original, T* selected,
                                                 Args... rest) {
  selected[i_selected] = original[i_original];
  // Peel off the next (original, selected) pair, if any remain.
  SelectHelper(i_selected, i_original, rest...);
}
230 
// Gathers one or more arrays through a shared index map. For each idx in
// [0, num_elements) and every (original, selected) pointer pair passed, it sets
// selected[idx] = original[indices[idx]]. For example, to gather two arrays:
//   IndexMultiSelect(num_elements, indices, original1, selected1,
//                    original2, selected2)
template <typename Index, typename T, typename... Args>
__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
                                 const T* original, T* selected, Args... args) {
  for (const int out_idx : GpuGridRangeX(num_elements)) {
    const Index& src_idx = indices[out_idx];
    SelectHelper(out_idx, src_idx, original, selected, args...);
  }
}
244 
// Fills to_fill[pos] = pos + offset for every pos in [0, num_elements).
template <typename T>
__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
  for (int pos : GpuGridRangeX(num_elements)) {
    const T value = static_cast<T>(pos) + offset;
    to_fill[pos] = value;
  }
}
251 
// Runs greedy non-max suppression on the GPU.
//
// Inputs and outputs:
// - d_sorted_boxes_float_ptr points to num_boxes boxes, each stored as 4
//   consecutive floats (x1, y1, x2, y2). The boxes are presumably already
//   sorted by descending score, and the pointer must be 16-byte aligned.
// - On success, d_selected_indices (room for num_boxes ints) receives, in
//   ascending order, the positions in the sorted box array of the boxes that
//   survive suppression.
// - *h_nkeep receives how many positions were written.
// - max_boxes lets the reduction stop early; only the first max_boxes
//   selections are guaranteed exact.
// - If flip_boxes is true, corner coordinates given in reversed order are
//   swapped before computing IoU.
//
// All work is enqueued on the context's GPU stream. The function blocks until
// the selected count has been copied back to the host.
//
// Returns InvalidArgument for misaligned input, or an error status if an
// allocation, kernel launch (e.g. too much shared memory for NMSReduce) or GPU
// runtime call fails.
Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
              const float iou_threshold, int* d_selected_indices, int* h_nkeep,
              OpKernelContext* context, const int max_boxes, bool flip_boxes) {
  // Making sure we respect the __align(16)__ we promised to the compiler.
  auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
  if ((iptr & 15) != 0) {
    return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
  }
  // Nothing to select. This also avoids zero-sized launch configurations.
  if (num_boxes <= 0) {
    *h_nkeep = 0;
    return Status::OK();
  }
  auto device = context->eigen_gpu_device();

  // Pairwise overlap mask: one row of bit_mask_len words per box.
  // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 by
  // using it as a ring buffer. However, the savings should be a few MB.
  const int bit_mask_len =
      (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;
  // Widen before multiplying: num_boxes * bit_mask_len can overflow int.
  const int64 max_nms_mask_size = static_cast<int64>(num_boxes) * bit_mask_len;
  Tensor d_nms_mask;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({max_nms_mask_size}), &d_nms_mask));
  // NMSKernel only writes words holding at least one valid pair; zero the rest.
  auto config = GetGpuLaunchConfig(d_nms_mask.NumElements(), device);
  TF_RETURN_IF_ERROR(GpuLaunchKernel(SetZero<int>, config.block_count,
                                     config.thread_per_block, 0,
                                     device.stream(),
                                     config.virtual_thread_count,
                                     d_nms_mask.flat<int32>().data()));

  // Host-visible, GPU-compatible slot for the number of selected boxes.
  Tensor h_num_selected;
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({1}), &h_num_selected, alloc_attr));

  int* d_delete_mask = d_nms_mask.flat<int>().data();
  int* h_selected_count = h_num_selected.flat<int>().data();
  const Box* d_sorted_boxes =
      reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);

  // Square grid of square blocks, clamped to [1, kNmsBlockDimMax] blocks per
  // grid dimension.
  int num_blocks = (num_boxes + kNmsBlockDim - 1) / kNmsBlockDim;
  num_blocks = std::max(std::min(num_blocks, kNmsBlockDimMax), 1);
  const dim3 block_dim(num_blocks, num_blocks, 1);
  const dim3 thread_block(kNmsBlockDim, kNmsBlockDim, 1);
  if (flip_boxes) {
    TF_RETURN_IF_ERROR(GpuLaunchKernel(NMSKernel<true>, block_dim, thread_block,
                                       0, device.stream(), d_sorted_boxes,
                                       num_boxes, iou_threshold, bit_mask_len,
                                       d_delete_mask));
  } else {
    TF_RETURN_IF_ERROR(GpuLaunchKernel(NMSKernel<false>, block_dim,
                                       thread_block, 0, device.stream(),
                                       d_sorted_boxes, num_boxes, iou_threshold,
                                       bit_mask_len, d_delete_mask));
  }
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());

  config = GetGpuLaunchConfig(num_boxes, device);
  Tensor selected_boxes;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
  Tensor d_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
  TF_RETURN_IF_ERROR(GpuLaunchKernel(Iota<int>, config.block_count,
                                     config.thread_per_block, 0,
                                     device.stream(),
                                     config.virtual_thread_count, 0,
                                     d_indices.flat<int>().data()));

  char* selected = reinterpret_cast<char*>(selected_boxes.flat<int8>().data());
  // NMSReduce must run as a single block: its running mask is kept in
  // bit_mask_len ints of shared memory.
  TF_RETURN_IF_ERROR(GpuLaunchKernel(NMSReduce, 1, 1024,
                                     bit_mask_len * sizeof(int),
                                     device.stream(), d_delete_mask,
                                     bit_mask_len, num_boxes, max_boxes,
                                     selected));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());

  // Compact the indices of kept boxes with DeviceSelect::Flagged. The first
  // call (null temp storage) only computes the scratch size.
  size_t flagged_buffer_size = 0;
  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::Flagged(
      static_cast<void*>(nullptr),  // temp_storage
      flagged_buffer_size,
      static_cast<int*>(nullptr),   // input
      static_cast<char*>(nullptr),  // selection flag
      static_cast<int*>(nullptr),   // selected items
      static_cast<int*>(nullptr),   // num_selected
      num_boxes, device.stream()));
  Tensor cub_scratch;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
      &cub_scratch));
  Tensor d_num_selected;
  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                            TensorShape({1}), &d_num_selected));
  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::Flagged(
      static_cast<void*>(cub_scratch.flat<int8>().data()),  // temp_storage
      flagged_buffer_size,
      d_indices.flat<int>().data(),  // input
      selected,                      // selection flag
      d_selected_indices,            // selected items
      d_num_selected.flat<int>().data(), num_boxes, device.stream()));

  // Copy the selected count back to the host and wait for it.
  gpuEvent_t copy_done;
  TF_RETURN_IF_CUDA_ERROR(
      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
  // Destroy the event on every exit path, including the error returns below.
  struct EventReleaser {
    gpuEvent_t event;
    ~EventReleaser() { gpuEventDestroy(event); }
  } release_copy_done{copy_done};
  device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
                            sizeof(int));
  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
  *h_nkeep = *h_selected_count;
  return Status::OK();
}
363 
// Unary predicate for gpuprim::DeviceSelect::If. It accepts values strictly
// greater than the stored threshold; NaN values never pass.
struct GreaterThanCubOp {
  __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold)
      : threshold_(threshold) {}
  __host__ __device__ __forceinline__ bool operator()(const float& val) const {
    return threshold_ < val;
  }
  // Exclusive lower bound for accepted values.
  float threshold_;
};
// Counts how many of the num_elements floats in device array dev_array satisfy
// op, and writes the count to the host int *result.
//
// Uses gpuprim::DeviceSelect::If, which also materializes the selected values
// into a scratch buffer that is then discarded.
// TODO(sami) Not really a good way. Perhaps consider using thrust?
//
// Blocks until the count has been copied back. Returns an error status if an
// allocation or GPU runtime call fails.
template <typename Op>
Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
               int num_elements, int* result) {
  auto cuda_stream = tensorflow::GetGpuStream(context);
  auto device = context->eigen_gpu_device();

  // A null temp-storage pointer makes cub only report the scratch size.
  size_t workspace_size = 0;
  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
      nullptr, workspace_size, static_cast<float*>(nullptr),
      static_cast<float*>(nullptr), static_cast<int*>(nullptr), num_elements,
      op, cuda_stream));

  Tensor scratch_output;
  Tensor workspace;
  Tensor element_count;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                            TensorShape({1}), &element_count));

  gpuEvent_t copy_done;
  TF_RETURN_IF_CUDA_ERROR(
      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
  // The event used to leak on every call; release it on all exit paths.
  struct EventReleaser {
    gpuEvent_t event;
    ~EventReleaser() { gpuEventDestroy(event); }
  } release_copy_done{copy_done};

  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
      workspace.flat<int8>().data(), workspace_size, dev_array,
      scratch_output.flat<float>().data(), element_count.flat<int32>().data(),
      num_elements, op, cuda_stream));
  device.memcpyDeviceToHost(result, element_count.flat<int32>().data(),
                            sizeof(int));
  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
  return Status::OK();
}
407 
// Shared implementation of the NonMaxSuppression V2/V3/V4 GPU kernels.
//
// Steps:
// 1. Sort boxes by descending score.
// 2. When score_threshold is greater than std::numeric_limits<float>::lowest(),
//    drop boxes whose score is not strictly greater than it.
// 3. Run NmsGpu.
// 4. Write to output 0 the indices (into the original `boxes` tensor) of the
//    selected boxes, in descending score order.
//
// At most max_output_size indices are produced. If pad_to_max_output is true,
// the output has max_output_size entries, with unused trailing entries set to
// 0. The exception is an empty `boxes` input, which always yields an empty
// output. *num_saved_outputs receives the number of valid (non-padding)
// indices.
//
// Assumes boxes is [num_boxes, 4] float and scores is [num_boxes] float (see
// CheckValidInputs). Returns InvalidArgument for a negative max_output_size.
Status DoNMS(OpKernelContext* context, const Tensor& boxes,
             const Tensor& scores, const int64_t max_output_size,
             const float iou_threshold_val, const float score_threshold,
             bool pad_to_max_output, int* num_saved_outputs) {
  // A negative size cannot be used as an output dimension below.
  if (max_output_size < 0) {
    return errors::InvalidArgument(
        "max_output_size must be non-negative, got ", max_output_size);
  }
  int num_boxes = boxes.dim_size(0);
  if (num_boxes == 0) {
    Tensor* output_indices = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({0}), &output_indices));
    *num_saved_outputs = 0;
    return Status::OK();
  }
  auto cuda_stream = GetGpuStream(context);
  auto device = context->eigen_gpu_device();

  // Allocates output 0 with max_output_size entries, all set to zero. Used by
  // the padded code paths.
  auto allocate_zeroed_padded_output = [&](Tensor** output) -> Status {
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({max_output_size}), output));
    if (max_output_size > 0) {
      auto pad_config = GetGpuLaunchConfig(max_output_size, device);
      TF_RETURN_IF_ERROR(GpuLaunchKernel(
          SetZero<int>, pad_config.block_count, pad_config.thread_per_block, 0,
          device.stream(), pad_config.virtual_thread_count,
          (*output)->flat<int>().data()));
      TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
    }
    return Status::OK();
  };

  // Calling cub with nullptrs as inputs makes it return the workspace size
  // needed for the sort (in cub_sort_temp_storage_bytes) instead of sorting.
  size_t cub_sort_temp_storage_bytes = 0;
  cudaError_t cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
      nullptr, cub_sort_temp_storage_bytes,
      static_cast<float*>(nullptr),  // scores
      static_cast<float*>(nullptr),  // sorted scores
      static_cast<int*>(nullptr),    // input indices
      static_cast<int*>(nullptr),    // sorted indices
      num_boxes,                     // num items
      0, 8 * sizeof(float),          // sort all bits
      cuda_stream);
  TF_RETURN_IF_CUDA_ERROR(cuda_ret);
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());

  Tensor d_cub_sort_buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}),
      &d_cub_sort_buffer));
  Tensor d_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
  Tensor d_sorted_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices));
  Tensor d_selected_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices));
  Tensor d_sorted_scores;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores));
  Tensor d_sorted_boxes;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes));

  auto config = GetGpuLaunchConfig(num_boxes, device);
  // Initialize box and score indices to 0 .. num_boxes - 1.
  TF_RETURN_IF_ERROR(GpuLaunchKernel(Iota<int>, config.block_count,
                                     config.thread_per_block, 0,
                                     device.stream(),
                                     config.virtual_thread_count, 0,
                                     d_indices.flat<int>().data()));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  // Sort scores in descending order, carrying the box indices along.
  cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
      d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
      scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
      d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
      num_boxes, 0,
      8 * sizeof(float),  // sort all bits
      cuda_stream);
  TF_RETURN_IF_CUDA_ERROR(cuda_ret);

  // Gather the boxes into score order; each 4-float box moves as one float4.
  const float4* original_boxes =
      reinterpret_cast<const float4*>(boxes.flat<float>().data());
  float4* sorted_boxes =
      reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
  const int* sorted_indices = d_sorted_indices.flat<int>().data();
  TF_RETURN_IF_ERROR(GpuLaunchKernel(IndexMultiSelect<int, float4>,
                                     config.block_count,
                                     config.thread_per_block, 0,
                                     device.stream(),
                                     config.virtual_thread_count,
                                     sorted_indices, original_boxes,
                                     sorted_boxes));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());

  int limited_num_boxes = num_boxes;
  // Filter boxes by score (NMS v3 and later). Scores are sorted, so the kept
  // boxes form a prefix of the sorted array.
  if (score_threshold > std::numeric_limits<float>::lowest()) {
    GreaterThanCubOp score_limit(score_threshold);
    TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat<float>().data(),
                               score_limit, num_boxes, &limited_num_boxes));
    if (limited_num_boxes == 0) {
      VLOG(1) << "Number of boxes above score threshold " << score_threshold
              << " is 0";
      *num_saved_outputs = 0;
      Tensor* output_indices = nullptr;
      if (pad_to_max_output) {
        // The padded output must be zero-filled, not left uninitialized.
        return allocate_zeroed_padded_output(&output_indices);
      }
      return context->allocate_output(0, TensorShape({0}), &output_indices);
    }
    VLOG(2) << "Number of boxes above threshold=" << score_threshold << " is "
            << limited_num_boxes;
  }

  int num_to_keep = 0;
  // There is no guarantee that boxes are given in the form x1<x2 and/or y1<y2,
  // flip boxes if necessary!
  const bool flip_boxes = true;
  const Status status =
      NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
             iou_threshold_val, d_selected_indices.flat<int>().data(),
             &num_to_keep, context, max_output_size, flip_boxes);
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  if (!status.ok()) {
    context->SetStatus(status);
    return status;
  }

  Tensor* output_indices = nullptr;
  const int num_outputs =
      std::min(num_to_keep, static_cast<int>(max_output_size));
  if (pad_to_max_output && num_outputs != max_output_size) {
    TF_RETURN_IF_ERROR(allocate_zeroed_padded_output(&output_indices));
  } else {
    TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({num_outputs}),
                                                &output_indices));
  }
  if (num_outputs == 0) {
    *num_saved_outputs = num_outputs;
    return Status::OK();
  }
  // Map selected positions (in sorted order) back to original box indices.
  config = GetGpuLaunchConfig(num_outputs, device);
  TF_RETURN_IF_ERROR(GpuLaunchKernel(
      IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
      0, device.stream(), config.virtual_thread_count,
      d_selected_indices.flat<int>().data(), sorted_indices,
      output_indices->flat<int>().data()));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  *num_saved_outputs = num_outputs;
  return Status::OK();
}
548 
// Validates the inputs shared by the NonMaxSuppression GPU kernels:
// - max_output_size and iou_threshold must be scalars;
// - max_output_size must be non-negative;
// - iou_threshold must lie in [0, 1], and NaN is rejected;
// - boxes must be [num_boxes, 4] and scores must be [num_boxes].
// Returns InvalidArgument describing the first violation found, or OK.
//
// Assumes max_output_size holds int32 and iou_threshold holds float, matching
// how they are read here and by the op kernels.
Status CheckValidInputs(const Tensor& boxes, const Tensor& scores,
                        const Tensor& max_output_size,
                        const Tensor& iou_threshold) {
  if (!TensorShapeUtils::IsScalar(max_output_size.shape())) {
    return errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                   max_output_size.shape().DebugString(),
                                   " (Shape must be rank 0 but is ", "rank ",
                                   max_output_size.dims(), ")");
  }
  if (!TensorShapeUtils::IsScalar(iou_threshold.shape())) {
    return errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                   iou_threshold.shape().DebugString(),
                                   " (Shape must be rank 0 but is rank ",
                                   iou_threshold.dims(), ")");
  }
  // A negative size would later be used as an output tensor dimension.
  const int max_output_size_val = max_output_size.scalar<int>()();
  if (max_output_size_val < 0) {
    return errors::InvalidArgument("max_output_size must be non-negative, got ",
                                   max_output_size_val);
  }
  const float iou_threshold_val = iou_threshold.scalar<float>()();
  // Written as a positive range test so that NaN is rejected as well.
  if (!(iou_threshold_val >= 0 && iou_threshold_val <= 1)) {
    return errors::InvalidArgument("iou_threshold must be in [0, 1]");
  }
  if (boxes.dims() != 2) {
    return errors::InvalidArgument(
        "boxes must be a rank 2 tensor! (Shape must "
        "be rank 2 but is rank ",
        boxes.dims(), ")");
  }
  const int64_t num_boxes = boxes.dim_size(0);
  if (boxes.dim_size(1) != 4) {
    return errors::InvalidArgument(
        "boxes must be Nx4 (Dimension must be 4 but"
        " is ",
        boxes.dim_size(1), ")");
  }
  if (scores.dims() != 1) {
    return errors::InvalidArgument(
        "scores must be a vector! (Shape must be "
        "rank 1 but is rank ",
        scores.dims(), ")");
  }
  if (scores.dim_size(0) != num_boxes) {
    return errors::InvalidArgument(
        "scores has incompatible shape "        // message must be exactly this
        "(Dimensions must be equal, but are ",  // otherwise tests fail!
        num_boxes, " and ", scores.dim_size(0), ")");
  }
  return Status::OK();
}
// GPU kernel for NonMaxSuppressionV2.
// Inputs: boxes [num_boxes, 4], scores [num_boxes], scalar max_output_size and
// scalar iou_threshold. Output 0 holds the indices of the selected boxes. No
// score filtering or padding is applied.
class NonMaxSuppressionV2GPUOp : public OpKernel {
 public:
  explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& boxes = context->input(0);            // [num_boxes, 4]
    const Tensor& scores = context->input(1);           // [num_boxes]
    const Tensor& max_output_size = context->input(2);  // scalar
    const Tensor& iou_threshold = context->input(3);    // scalar

    const Status input_status =
        CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
    if (!input_status.ok()) {
      context->SetStatus(input_status);
      return;
    }

    // No boxes: emit an empty index vector.
    const int num_boxes = boxes.dim_size(0);
    if (num_boxes == 0) {
      Tensor* output_indices = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                       &output_indices));
      return;
    }

    const float iou_threshold_val = iou_threshold.scalar<float>()();
    const int64_t output_size = max_output_size.scalar<int>()();
    // Filled in by DoNMS; not needed by this op.
    int num_saved_outputs = 0;
    OP_REQUIRES_OK(
        context,
        DoNMS(context, boxes, scores, output_size, iou_threshold_val,
              // The lowest float disables score filtering in DoNMS.
              /*score_threshold=*/std::numeric_limits<float>::lowest(),
              /*pad_to_max_output=*/false, &num_saved_outputs));
  }
};
633 
// GPU kernel for NonMaxSuppressionV3. It takes the same inputs as V2, plus a
// scalar score_threshold (input 4) that DoNMS uses to drop low-scoring boxes
// before suppression. No padding is applied.
class NonMaxSuppressionV3GPUOp : public OpKernel {
 public:
  explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& boxes = context->input(0);            // [num_boxes, 4]
    const Tensor& scores = context->input(1);           // [num_boxes]
    const Tensor& max_output_size = context->input(2);  // scalar
    const Tensor& iou_threshold = context->input(3);    // scalar

    const Status input_status =
        CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
    if (!input_status.ok()) {
      context->SetStatus(input_status);
      return;
    }

    const Tensor& score_threshold = context->input(4);  // scalar
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = score_threshold.scalar<float>()();

    // No boxes: emit an empty index vector.
    const int num_boxes = boxes.dim_size(0);
    if (num_boxes == 0) {
      Tensor* output_indices = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                       &output_indices));
      return;
    }

    const float iou_threshold_val = iou_threshold.scalar<float>()();
    const int64_t output_size = max_output_size.scalar<int>()();
    // Filled in by DoNMS; not needed by this op.
    int num_saved_outputs = 0;
    OP_REQUIRES_OK(context,
                   DoNMS(context, boxes, scores, output_size,
                         iou_threshold_val, score_threshold_val,
                         /*pad_to_max_output=*/false, &num_saved_outputs));
  }
};
675 
676 class NonMaxSuppressionV4GPUOp : public OpKernel {
677  public:
NonMaxSuppressionV4GPUOp(OpKernelConstruction * context)678   explicit NonMaxSuppressionV4GPUOp(OpKernelConstruction* context)
679       : OpKernel(context) {
680     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
681                                              &pad_to_max_output_size_));
682   }
683 
Compute(OpKernelContext * context)684   void Compute(OpKernelContext* context) override {
685     // boxes: [num_boxes, 4]
686     const Tensor& boxes = context->input(0);
687     // scores: [num_boxes]
688     const Tensor& scores = context->input(1);
689     // max_output_size: scalar
690     const Tensor& max_output_size = context->input(2);
691     // iou_threshold: scalar
692     const Tensor& iou_threshold = context->input(3);
693     auto valid =
694         CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
695     if (!valid.ok()) {
696       context->SetStatus(valid);
697       return;
698     }
699 
700     const Tensor& score_threshold = context->input(4);
701     OP_REQUIRES(
702         context, TensorShapeUtils::IsScalar(score_threshold.shape()),
703         errors::InvalidArgument("score_threshold must be 0-D, got shape ",
704                                 score_threshold.shape().DebugString()));
705     const float score_threshold_val = score_threshold.scalar<float>()();
706 
707     Tensor* num_outputs_t = nullptr;
708     OP_REQUIRES_OK(context,
709                    context->allocate_output(1, tensorflow::TensorShape({}),
710                                             &num_outputs_t));
711     auto device = context->eigen_gpu_device();
712     int num_boxes = boxes.dim_size(0);
713     if (num_boxes == 0) {
714       Tensor* output_indices = nullptr;
715       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
716                                                        &output_indices));
717       device.memcpy((num_outputs_t->flat<int>().data()), &num_boxes,
718                     sizeof(int));
719       return;
720     }
721 
722     const float iou_threshold_val = iou_threshold.scalar<float>()();
723     const int64_t output_size = max_output_size.scalar<int>()();
724     int num_outputs = 0;
725     OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size,
726                                   iou_threshold_val, score_threshold_val,
727                                   pad_to_max_output_size_, &num_outputs));
728     device.memcpyHostToDevice((num_outputs_t->flat<int>().data()), &num_outputs,
729                               sizeof(int));
730     return;
731   }
732 
733  private:
734   bool pad_to_max_output_size_;
735 };
736 
737 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
738                             .TypeConstraint<float>("T")
739                             .Device(DEVICE_GPU)
740                             .HostMemory("iou_threshold")
741                             .HostMemory("max_output_size"),
742                         NonMaxSuppressionV2GPUOp);
743 
744 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
745                             .TypeConstraint<float>("T")
746                             .Device(DEVICE_GPU)
747                             .HostMemory("iou_threshold")
748                             .HostMemory("max_output_size")
749                             .HostMemory("score_threshold"),
750                         NonMaxSuppressionV3GPUOp);
751 
752 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
753                             .TypeConstraint<float>("T")
754                             .Device(DEVICE_GPU)
755                             .HostMemory("iou_threshold")
756                             .HostMemory("max_output_size")
757                             .HostMemory("score_threshold"),
758                         NonMaxSuppressionV4GPUOp);
759 
760 }  // namespace tensorflow
761 #endif
762