#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/ReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/NumericLimits.cuh>

namespace at::native {

template <typename acc_t>
void argmax_kernel_cuda_impl(TensorIterator& iter) {
  // Reduce to a (value, index) pair, seeded with the lowest representable
  // value of acc_t and index 0.
  gpu_reduce_kernel<acc_t, int64_t>(
      iter,
      ArgMaxOps<acc_t>{},
      thrust::pair<acc_t, int64_t>(
          at::numeric_limits<acc_t>::lower_bound(), 0));
}

void argmax_kernel_cuda(TensorIterator& iter) {
  // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down,
  // we can convert float16 & bfloat16 to float and do all the operations in
  // float.
  if (iter.dtype(1) == kHalf) {
    argmax_kernel_cuda_impl<float>(iter);
  } else if (iter.dtype(1) == kBFloat16) {
    argmax_kernel_cuda_impl<float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() {
      argmax_kernel_cuda_impl<scalar_t>(iter);
    });
  }
}

REGISTER_DISPATCH(argmax_stub, &argmax_kernel_cuda);

} // namespace at::native
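
// Usage sketch (illustrative only, not part of this kernel file; the tensors
// and call below are assumptions about how the registered stub is reached):
//
//   at::Tensor t   = at::randn({4, 5}, at::kCUDA);
//   at::Tensor idx = at::argmax(t, /*dim=*/1);  // routes through argmax_stub
//                                               // to argmax_kernel_cuda
//
// The caller builds the TensorIterator so that dtype(0) is the int64 index
// output and dtype(1) is the input dtype, which is what the dispatch above
// switches on.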