• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <ATen/cuda/NumericLimits.cuh>
7 #include <cuda.h>
8 #include <cuda_fp16.h>
9 #include <cuda_runtime.h>
10 
11 #include <assert.h>
12 
13 using namespace at;
14 
// Device-side checks of the Half API: construction, conversions, arithmetic
// through __half, the common math functions, and c10::complex<Half>.
// Every check is a device `assert`; a failed assert traps the kernel and the
// error is reported by the next synchronizing CUDA call on the host.
// NOTE(review): `assert` is compiled out when NDEBUG is defined, in which
// case this function verifies nothing -- confirm the test build keeps
// assertions enabled.
__device__ void test() {
  // test half construction and implicit conversions in device
  assert(Half(3) == Half(3.0f));
  assert(static_cast<Half>(3.0f) == Half(3.0f));
  // there is no float <=> __half implicit conversion, but Half itself can be
  // compared against a float directly
  assert(static_cast<Half>(3.0f) == 3.0f);

  // Round-trip through the raw CUDA __half type: 3 - 2 == 1.
  __half a = __float2half(3.0f);
  __half b = __float2half(2.0f);
  __half c = Half(a) - Half(b);
  assert(static_cast<Half>(c) == Half(1.0));

  // Check that the math functions applied to Half give almost the same
  // result as the same functions applied to float.
  // The purpose of these asserts is to test the device side
  // half API for the common mathematical functions.
  // Note: When calling std math functions from device, don't
  // use the std namespace, but just "::" so that the function
  // gets resolved from nvcc math_functions.hpp

  // Absolute tolerance for every comparison below.
  const float threshold = 0.00001f;
  assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
  assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
  assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
  assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
  assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
  assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
  assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
  assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
  assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
  assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
  assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
  assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
  assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
  assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
  assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
  assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
  assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
  // atanh(1) is infinite. Skipped under MSVC, like the isnan/isinf checks
  // further down.
#ifndef _MSC_VER
  assert(::isinf(::atanh(Half(1.0))));
#endif
  assert(::abs(::atanh(Half(.5)) - ::atanh(.5f)) <= threshold);
  assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
  assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
  assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
  assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
  assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
  assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
  assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
  assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
  assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
  assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
  assert(
      ::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
  // note: can't use the std namespace on isnan and isinf in device code

  // Skipped under MSVC: there, calling isnan/isinf on a Half requires an
  // explicit conversion. The reason is unclear;
  // related issue with clang: https://reviews.llvm.org/D37906
#ifndef _MSC_VER
  // The classification of a Half must match that of the equal float.
  assert(::isnan(Half(0.0)) == ::isnan(0.0f));
  assert(::isinf(Half(0.0)) == ::isinf(0.0f));
#endif

  // test complex<Half>: the components must round-trip unchanged
  Half real = 3.0f;
  Half imag = -10.0f;
  auto complex = c10::complex<Half>(real, imag);
  assert(complex.real() == real);
  assert(complex.imag() == imag);
}
89 
// Kernel entry point: each launched thread runs the device-side Half checks
// in test(). It takes no arguments and writes no output; failures are
// signalled only through the device asserts inside test().
__global__ void kernel() {
  test();
}
93 
// Host helper: launches `kernel` with a single block of a single thread and
// then runs C10_CUDA_KERNEL_LAUNCH_CHECK() on the launch. It does not
// synchronize with the device.
void launch_function() {
  const dim3 grid_dim(1);
  const dim3 block_dim(1);
  kernel<<<grid_dim, block_dim>>>();
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
98 
// Half common math functions tests in device.
// Launches the checking kernel and then synchronizes; a device-side assert
// failure (or any other asynchronous error) shows up as a non-success code
// from cudaDeviceSynchronize().
TEST(HalfCuda, HalfCuda) {
  // Without a CUDA device there is nothing to run; the test returns early and
  // is reported as passed.
  if (!at::cuda::is_available()) {
    return;
  }
  launch_function();
  const cudaError_t err = cudaDeviceSynchronize();
  // Stream the CUDA error text so a failure says what went wrong on the
  // device instead of only "false".
  ASSERT_TRUE(err == cudaSuccess) << cudaGetErrorString(err);
}
107