#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/Array.h>

using namespace at::native;
using namespace at::native::memory;

constexpr int buffer_size = 1024;

__managed__ double4 buffer1[buffer_size];
__managed__ double4 buffer2[buffer_size];

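// Initialize buffer1 with increasing values and buffer2 with their negations
// so the two buffers are guaranteed to differ before each copy test runs.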
void reset_buffers() {
  for (int i = 0; i < buffer_size; i++) {
    buffer1[i].x = i;
    buffer1[i].y = i + 0.1;
    buffer1[i].z = i + 0.2;
    buffer1[i].w = i + 0.3;

    buffer2[i].x = -i;
    buffer2[i].y = -(i + 0.1);
    buffer2[i].z = -(i + 0.2);
    buffer2[i].w = -(i + 0.3);
  }
}

#if defined(USE_ROCM)
TEST(TestLoops, HasSameArgTypes) {
  // This is a compile-time unit test. If this file compiles without error,
  // then the test passes and during runtime, we just need to return.
  using namespace at::native::modern::detail;
  using func1_t = int (*)(float, float);
  using func2_t = int (*)(bool, float, float);
  using func3_t = int (*)(float);
  using func4_t = int (*)();
  static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
  static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
  static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
  static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
  return;
}
#endif

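// can_vectorize_up_to<T>(ptr) reports how many elements of type T can be
// loaded or stored in one vectorized access given the alignment of ptr, so
// the expected values below follow from the byte offset into an aligned
// __managed__ buffer.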
TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
  char *ptr = reinterpret_cast<char *>(buffer1);

  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);

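  // A 1-byte offset breaks the alignment required for any vectorized access,
  // so only scalar (vec1) accesses are possible even for 1-byte types.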
  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);

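  // A 2-byte offset allows vec2 accesses for 1-byte types, but only scalar
  // accesses for anything wider.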
  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 2), 2);
  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 2), 2);
  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 2), 1);

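  // A 4-byte offset allows vec4 for 1-byte types, vec2 for 2-byte types, and
  // only scalar accesses for 4-byte and wider types.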
  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 4), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 4), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);

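  // An 8-byte offset allows vec4 up to 2-byte types, vec2 for 4-byte types,
  // and only scalar accesses for 8-byte types.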
  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
}

// The following kernel copies values using the vectorized policies
// defined in `ATen/native/cuda/MemoryAccess.cuh`.
template <typename scalar_t, int vec_size>
__global__ void vectorized_copy(scalar_t *dst, scalar_t *src) {
  static_assert(vec_size <= thread_work_size() && thread_work_size() % vec_size == 0, "Invalid vec_size");
  using array_t = at::detail::Array<char*, 2>;
  array_t data;
  data[0] = reinterpret_cast<char *>(dst);
  data[1] = reinterpret_cast<char *>(src);
  int idx = blockIdx.x;
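  // Each block handles block_work_size() contiguous elements: the vectorized
  // policy loads them from src into a per-thread register buffer, then stores
  // that buffer back out to dst (data[0]).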
  using vectorized = policies::vectorized<vec_size, array_t>;
  auto policy = vectorized(data);
  scalar_t buf[thread_work_size()];
#if !defined(USE_ROCM)
  // This fails only on CUDA 10.x, remove this after CUDA 10.x support is dropped
  scalar_t *buf_ = &buf[0];
  auto accessor = [&](int index) -> scalar_t & { return buf_[index]; };
#else
  auto accessor = [&](int index) -> scalar_t & { return buf[index]; };
#endif
  policy.load_single_arg(accessor, src + block_work_size() * blockIdx.x);
  policy.store(buf, idx);
}

TEST(TestVectorizedMemoryAccess, CopyKernel) {
  if (!at::cuda::is_available()) {
    return;
  }

  double *b1 = reinterpret_cast<double *>(buffer1);
  double *b2 = reinterpret_cast<double *>(buffer2);

  // vec4 copy
  reset_buffers();
  cudaDeviceSynchronize();
  constexpr int total_work_size = buffer_size * 4;
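  // One block copies block_work_size() doubles, so total_work_size /
  // block_work_size() blocks cover all 4 * buffer_size doubles in the buffers.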
  vectorized_copy<double, 4><<<total_work_size / block_work_size(), num_threads()>>>(b2, b1);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  for (int i = 0; i < 1024; i++) {
    ASSERT_EQ(buffer1[i].x, buffer2[i].x);
    ASSERT_EQ(buffer1[i].y, buffer2[i].y);
    ASSERT_EQ(buffer1[i].z, buffer2[i].z);
    ASSERT_EQ(buffer1[i].w, buffer2[i].w);
  }

  // vec2 copy
  reset_buffers();
  cudaDeviceSynchronize();
  vectorized_copy<double, 2><<<total_work_size / block_work_size(), num_threads()>>>(b2, b1);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  for (int i = 0; i < 1024; i++) {
    ASSERT_EQ(buffer1[i].x, buffer2[i].x);
    ASSERT_EQ(buffer1[i].y, buffer2[i].y);
    ASSERT_EQ(buffer1[i].z, buffer2[i].z);
    ASSERT_EQ(buffer1[i].w, buffer2[i].w);
  }

  // vec1 copy
  reset_buffers();
  cudaDeviceSynchronize();
  vectorized_copy<double, 1><<<total_work_size / block_work_size(), num_threads()>>>(b2, b1);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  for (int i = 0; i < 1024; i++) {
    ASSERT_EQ(buffer1[i].x, buffer2[i].x);
    ASSERT_EQ(buffer1[i].y, buffer2[i].y);
    ASSERT_EQ(buffer1[i].z, buffer2[i].z);
    ASSERT_EQ(buffer1[i].w, buffer2[i].w);
  }

// Skipping this part until https://github.com/pytorch/pytorch/issues/51863 is resolved
#if 0
  // unaligned
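  // Only the fully aligned case (i == 0 && j == 0) should launch successfully;
  // any other byte offset should fail with cudaErrorMisalignedAddress.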
  for (int i = 0; i < 16; i++) {
    for (int j = 0; j < 16; j++) {
      b1 = reinterpret_cast<double *>(reinterpret_cast<char *>(buffer1) + i);
      b2 = reinterpret_cast<double *>(reinterpret_cast<char *>(buffer2) + j);
      (void)cudaGetLastError();
      cudaDeviceSynchronize();
      vectorized_copy<double, 4><<<1, num_threads()>>>(b2, b1);
      cudaDeviceSynchronize();
      auto err = cudaGetLastError();
      if (i % 16 == 0 && j % 16 == 0) {
        ASSERT_EQ(err, cudaSuccess);
      } else {
        ASSERT_EQ(err, cudaErrorMisalignedAddress);
      }
    }
  }
#endif
}