1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <ATen/cuda/NumericLimits.cuh>
7 #include <cuda.h>
8 #include <cuda_fp16.h>
9 #include <cuda_runtime.h>
10
11 #include <assert.h>
12
13 using namespace at;
14
// Device-side checks for at::Half: construction, conversions to/from CUDA's
// __half, the common math functions called with Half arguments, and
// c10::complex<Half> construction.
//
// Every check is a device `assert`; a failing check traps the kernel, and the
// error is reported by the host's next synchronizing CUDA call.
__device__ void test() {
  // Half construction and implicit conversions in device code.
  assert(Half(3) == Half(3.0f));
  assert(static_cast<Half>(3.0f) == Half(3.0f));
  // A Half can be compared directly with a float. CUDA's __half, by contrast,
  // has no implicit float <=> __half conversion (hence __float2half below).
  assert(static_cast<Half>(3.0f) == 3.0f);

  // Arithmetic across the __half <-> Half boundary: build Half values from
  // __half, subtract them as Half, and assign the result back to a __half.
  __half a = __float2half(3.0f);
  __half b = __float2half(2.0f);
  __half c = Half(a) - Half(b);
  assert(static_cast<Half>(c) == Half(1.0));

  // Check that each math function called with a Half argument gives (almost)
  // the same result as the same function called with the equivalent float.
  // The purpose of these asserts is to exercise the device-side Half API for
  // the common mathematical functions.
  // Note: when calling std math functions from device code, don't use the std
  // namespace, but just "::" so that the function gets resolved from nvcc's
  // math_functions.hpp.
  const float threshold = 0.00001f;  // absolute tolerance for every comparison
  assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
  assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
  assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
  assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
  assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
  assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
  assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
  assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
  assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
  assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
  assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
  assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
  assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
  assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
  assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
  assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
  assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
  // See the note further down about VC++ and isinf.
#ifndef _MSC_VER
  assert(::isinf(::atanh(Half(1.0))));
#endif
  assert(::abs(::atanh(Half(.5)) - ::atanh(.5f)) <= threshold);
  assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
  assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
  assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
  assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
  assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
  assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
  assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
  assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
  assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
  assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
  assert(
      ::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);

  // Note: isnan and isinf can't be called through a namespace (e.g. std::)
  // in device code.
#ifdef _MSC_VER
  // VC++ (Windows) requires an explicit conversion to float here; the reason
  // is unclear. Related issue with clang: https://reviews.llvm.org/D37906
  assert(
      ::abs(::isnan(static_cast<float>(Half(0.0))) - ::isnan(0.0f)) <=
      threshold);
  assert(
      ::abs(::isinf(static_cast<float>(Half(0.0))) - ::isinf(0.0f)) <=
      threshold);
#else
  assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
  assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
#endif

  // c10::complex<Half>: construction and real/imag accessors.
  Half real = 3.0f;
  Half imag = -10.0f;
  auto complex = c10::complex<Half>(real, imag);
  assert(complex.real() == real);
  assert(complex.imag() == imag);
}
89
// Kernel entry point whose only job is to run the device-side Half checks in
// test(). launch_function() launches it as <<<1, 1>>>, so test() executes on
// exactly one thread.
__global__ void kernel() { test(); }
93
// Launches `kernel` on a one-block, one-thread grid, then runs c10's
// post-launch error check. The launch itself is asynchronous: failures of the
// device-side asserts in test() only surface at a later synchronizing call,
// which the caller is responsible for making.
void launch_function() {
  const dim3 grid_dim(1);
  const dim3 block_dim(1);
  kernel<<<grid_dim, block_dim>>>();
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
98
// Runs the device-side Half construction/conversion and common-math checks
// from test(). Returns early (the test passes trivially) when no CUDA device
// is available.
TEST(HalfCuda, HalfCuda) {
  if (!at::cuda::is_available()) {
    return;
  }
  launch_function();
  // Device-side assert failures and other asynchronous kernel errors are only
  // reported at the first synchronizing call after the launch.
  const cudaError_t err = cudaDeviceSynchronize();
  // Report the CUDA error string on failure instead of a bare boolean mismatch.
  ASSERT_EQ(err, cudaSuccess) << "CUDA error: " << cudaGetErrorString(err);
}
107