1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/AccumulateType.h>
4 #include <ATen/cuda/CUDAGeneratorImpl.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/Utils.h>
7 #include <ATen/cuda/detail/IndexUtils.cuh>
8 #include <ATen/cuda/detail/TensorInfo.cuh>
9 #include <ATen/cuda/CUDAGraphsUtils.cuh>
10 #include <c10/macros/Macros.h>
11 #include <curand_kernel.h>
12 
13 #include <ATen/native/TensorIterator.h>
14 #include <ATen/native/cuda/Loops.cuh>
15 #include <ATen/native/cuda/MemoryAccess.cuh>
16 
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #include <ATen/NativeFunctions.h>
20 #else
21 #include <ATen/ops/_masked_scale_native.h>
22 #include <ATen/ops/empty_like.h>
23 #include <ATen/ops/native_dropout_backward_native.h>
24 #include <ATen/ops/ones_like.h>
25 #include <ATen/ops/zeros_like.h>
26 #endif
27 
28 namespace at::native {
29 
30 namespace {
31 
// Philox (curand_uniform4) yields 128 bits -- four 32-bit values -- per call.
// The kernels store the transformed result of one call in a float4 and use one
// member per unrolled element, so every member is consumed only when UNROLL is
// exactly 4. Do not change this value.
// The vectorized kernel requires VEC <= 4 (in practice usually 4), so the same
// reasoning applies there.
constexpr int UNROLL = 4;
36 
// Vectorized Philox dropout kernel.
//
// Each thread handles VEC consecutive elements per iteration, starting at
// idx * VEC and advancing by gridDim.x * blockDim.x * VEC. Loads and stores go
// through memory::aligned_vector using the flat linear index instead of
// IndexToOffset. The caller must therefore only pick this kernel when a flat
// index addresses a, b and c consistently, and when totalElements is a
// multiple of VEC (launcher decides this via get_vector_size). ADims is not
// used by this kernel.
//
// p is the keep probability: an element is kept when its uniform draw is < p,
// kept values are scaled by 1/p, and c receives the keep decision.
template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int VEC,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                         at::cuda::detail::TensorInfo<scalar_t, IndexType> b,
                         at::cuda::detail::TensorInfo<mask_t, IndexType> c,
                         IndexType totalElements, accscalar_t p,
                         PhiloxCudaState philox_args) {
  // One curand_uniform4 call provides 4 values. The reuse logic below covers
  // VEC == 4 (one call per iteration) and VEC == 2 (one call per two
  // iterations) only.
  static_assert(VEC == 2 || VEC == 4, "Value of VEC must be 2 or 4");

  using LoadT = memory::aligned_vector<scalar_t, VEC>;
  using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

  auto seeds = at::cuda::philox::unpack(philox_args);
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  // For VEC == 2, each float4 of uniforms is split across two consecutive
  // iterations: x/y on the first, z/w on the second. This keeps the number of
  // curand_uniform4 calls per thread the same as for VEC == 4, which matches
  // the Philox counter offset reserved on the host.
  //   false -> draw a fresh float4 this iteration
  //   true  -> reuse the z/w left over from the previous draw
  bool gridxvec_loop_state = false;
  accscalar_t scale = 1.0 / p;

  float4 rand;

  // Vectorized loads: each thread strides by an extra factor of VEC because it
  // handles VEC elements at a time.
  for (IndexType linearIndex = idx * VEC;
      linearIndex < totalElements;
      linearIndex += gridDim.x * blockDim.x * VEC) {
    // Local storage, reinterpreted below so the load is a single vector load.
    scalar_t src[VEC];
    LoadT *value = reinterpret_cast<LoadT*>(&src);

    // Random numbers are always generated as float (there is no half variant,
    // and curand_uniform_double is not used).
    if ((VEC == 4) || !gridxvec_loop_state) {
      rand = curand_uniform4(&state);
    } else {
      // Use the two values generated but not consumed last iteration.
      rand.x = rand.z;
      rand.y = rand.w;
    }
    // Flip unconditionally. If the flag were only flipped inside the reuse
    // branch, it could never leave its initial state, and the VEC == 2 path
    // would draw a fresh float4 every iteration, exceeding the reserved offset.
    gridxvec_loop_state = !gridxvec_loop_state;

    rand.x = rand.x < p;
    rand.y = rand.y < p;
    // For VEC == 2, z/w must stay raw uniforms so the next iteration can use them.
    if (VEC == 4) {
      rand.z = rand.z < p;
      rand.w = rand.w < p;
    }

    // Flat linear indexing (no IndexToOffset) is what allows vectorizing
    // non-contiguous-but-dense orderings such as NHWC.
    // Single vectorized load.
    *value = *reinterpret_cast<const LoadT*>(&a.data[linearIndex]);

    scalar_t r[VEC];
    mask_t mask[VEC];

    #pragma unroll
    for (int ii = 0; ii < VEC; ii++) {
      r[ii] = src[ii]*(&rand.x)[ii]*scale;
      mask[ii] = (mask_t)(&rand.x)[ii];
    }
    // Vectorized writes for both the result and the mask.
    *(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
    *(reinterpret_cast<MaskLoadT*>(&c.data[linearIndex])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);

    // No __syncthreads() here. The kernel uses no shared memory, each thread
    // writes only its own elements, and the per-thread trip count of this loop
    // can differ near the end of the range. A barrier inside the loop would
    // therefore be reached divergently.
  }
}
124 
// Generic (strided, non-vectorized) Philox dropout kernel.
//
// Each outer iteration handles UNROLL elements per thread, spaced one full
// grid (blockDim.x * gridDim.x) apart, and draws exactly one float4 of
// uniforms: one value per element. The outer loop bound is rounded up to a
// multiple of blockDim.x * gridDim.x * UNROLL, so every thread executes the
// same number of outer iterations. Per-element bounds checks skip the padding.
// Because the trip count is uniform, the trailing __syncthreads() is reached
// by all threads.
//
// Offsets into `a` use ADims and offsets into `b` use BDims. The same offset
// is reused for the mask `c`, which assumes c has b's layout.
// NOTE(review): dropout_cuda allocates both with empty_like(self) -- confirm
// for any other caller.
//
// p is the keep probability: an element is kept when its draw is < p, and
// kept values are scaled by 1/p.
template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int BDims = ADims,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                     cuda::detail::TensorInfo<scalar_t, IndexType> b,
                     cuda::detail::TensorInfo<mask_t, IndexType> c,
                     IndexType totalElements, accscalar_t p,
                     PhiloxCudaState philox_args) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
  accscalar_t scale = 1.0 / p;

  // Gap between the UNROLL elements a thread touches within one iteration.
  const unsigned int grid_threads = blockDim.x * gridDim.x;
  // Elements covered by the whole grid in one outer iteration.
  const unsigned int grid_step = grid_threads * UNROLL;

  const IndexType num_sweeps = (totalElements - 1) / grid_step + 1;
  const IndexType rounded_size = num_sweeps * blockDim.x * gridDim.x * UNROLL;

  for (IndexType base = idx; base < rounded_size; base += grid_step) {
    // Always generate float randomness (there is no half generator, and
    // curand_uniform_double is avoided).
    float4 rand = curand_uniform4(&state);
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    const float* keep = &rand.x;

    // Gather: read every in-range input before writing any output.
    scalar_t src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = base + grid_threads * ii;
      if (li < totalElements) {
        const IndexType aOffset =
            cuda::detail::IndexToOffset<const scalar_t, IndexType, ADims>::get(li, a);
        src[ii] = a.data[aOffset];
      }
    }

    // Scatter: the scaled value goes to b and the keep decision to c, both at b's offset.
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = base + grid_threads * ii;
      if (li < totalElements) {
        const IndexType bOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, BDims>::get(li, b);
        b.data[bOffset] = src[ii] * keep[ii] * scale;
        c.data[bOffset] = (mask_t)keep[ii];
      }
    }
    __syncthreads();
  }
}
184 
// Elementwise ret = float(mask) * src * scale, evaluated with a TensorIterator.
// check_all_same_dtype(false) lets src (read as scalar_t) and mask (read as
// mask_t) have different dtypes.
template<typename mask_t, typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tensor& mask, accscalar_t scale){
  auto iter = at::TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(ret)
      .add_const_input(src)
      .add_const_input(mask)
      .build();

  auto scale_by_mask = [=] GPU_LAMBDA(const scalar_t src_val, const mask_t mask_val) -> scalar_t {
    return static_cast<float>(mask_val) * src_val * scale;
  };
  at::native::gpu_kernel(iter, scale_by_mask);
}
200 
// Chooses the vector width for fused_dropout_kernel_vec.
//
// Returns 1 (use the non-vectorized kernel) unless self, ret and mask are all
// non-overlapping and dense. Otherwise it starts from the widest width allowed
// by the alignment of self's data pointer, caps it at 4, then halves it until
// it divides every tensor's element count. A smaller width with no remainder
// is preferred over a larger width with one.
//
// The cap exists because fused_dropout_kernel_vec and launcher only handle
// widths 2 and 4. A wider result would otherwise reach launcher's switch
// without a matching case.
// NOTE(review): only self's pointer alignment is inspected; ret and mask are
// assumed to be suitably aligned fresh allocations -- confirm for all callers.
template <typename scalar_t>
int get_vector_size(const at::Tensor& self, const at::Tensor& ret, const at::Tensor& mask) {
  constexpr int kMaxVecSize = 4;

  if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() ||
      !mask.is_non_overlapping_and_dense()) {
    return 1;
  }

  int vec_size = memory::can_vectorize_up_to<scalar_t>((const char*)self.const_data_ptr());
  if (vec_size > kMaxVecSize) {
    // An alignment good enough for a wider vector is also good enough for 4.
    vec_size = kMaxVecSize;
  }

  while (vec_size > 1 &&
         (self.numel() % vec_size != 0 || ret.numel() % vec_size != 0 ||
          mask.numel() % vec_size != 0)) {
    vec_size /= 2;
  }
  return vec_size;
}
219 
// Dispatches on self's dtype (floating types plus Half and BFloat16) and
// launches the matching dropout kernel on the current CUDA stream.
//
// index_type is the TensorInfo index type (dropout_cuda passes unsigned int
// when 32-bit index math is usable, uint64_t otherwise). p is the keep
// probability forwarded to the kernels, and nelem is the element count.
// rng_engine_inputs must have been reserved for the same grid/dim_block.
//
// Kernel selection:
//   * get_vector_size gives 4 or 2 -> fused_dropout_kernel_vec with that width;
//   * otherwise fused_dropout_kernel, specialized on the collapsed rank of
//     self and on whether ret/mask are contiguous.
template <typename index_type, typename mask_t>
inline void launcher(
    const Tensor& self,
    Tensor& ret,
    Tensor& mask,
    double p,
    const int64_t nelem,
    const PhiloxCudaState rng_engine_inputs,
    dim3 grid,
    dim3 dim_block) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "fused_dropout",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        accscalar_t pa = (accscalar_t)(p);
        auto self_info =
            cuda::detail::getTensorInfo<const scalar_t, index_type>(self);
        auto ret_info =
            cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
        auto mask_info =
            cuda::detail::getTensorInfo<mask_t, index_type>(mask);
        // Merge dimensions where the layout allows. Presumably a contiguous
        // tensor collapses to a single dimension (see TensorInfo.cuh).
        self_info.collapseDims();
        ret_info.collapseDims();
        mask_info.collapseDims();

        int vec_size = get_vector_size<scalar_t>(self, ret, mask);

        if (vec_size > 1) {
          switch (vec_size) {
            case 4:
              fused_dropout_kernel_vec<
                  scalar_t,
                  accscalar_t,
                  index_type,
                  1,
                  4>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
            case 2:
              fused_dropout_kernel_vec<
                  scalar_t,
                  accscalar_t,
                  index_type,
                  1,
                  2>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
            default:
              // Only widths 2 and 4 have a vectorized kernel. Without this
              // check, any other width would launch nothing and leave ret and
              // mask uninitialized.
              TORCH_INTERNAL_ASSERT(
                  false, "fused_dropout: unsupported vector size ", vec_size);
          }
        } else {
          switch (self_info.dims) {
            case 1:
              fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
            default:
              if (!self.is_contiguous() && ret.is_contiguous() &&
                  mask.is_contiguous()) {
                // ret/mask are contiguous: index them as 1-d (BDims = 1).
                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
                    <<<grid,
                        dim_block,
                        0,
                        at::cuda::getCurrentCUDAStream()>>>(
                        self_info,
                        ret_info,
                        mask_info,
                        nelem,
                        pa,
                        rng_engine_inputs);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              } else {
                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
                    <<<grid,
                        dim_block,
                        0,
                        at::cuda::getCurrentCUDAStream()>>>(
                        self_info,
                        ret_info,
                        mask_info,
                        nelem,
                        pa,
                        rng_engine_inputs);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
          }
        }
      });
}
332 
333 } //anonymous namespace
334 
// Runs fused dropout on self with keep probability p.
// Returns (output, mask), where mask has dtype mask_t.
//
// Launch configuration: 256-thread blocks. The grid is capped at
// multiProcessorCount * (maxThreadsPerMultiProcessor / 256) blocks. The Philox
// counter offset reserved from gen covers one curand_uniform4 (UNROLL values)
// per thread for every UNROLL-wide sweep of the grid over the elements.
template <typename mask_t>
std::tuple<Tensor,Tensor>
dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){
  Tensor mask = at::empty_like(self, self.options().dtype(c10::CppTypeToScalarType<mask_t>::value));
  const int64_t nelem = self.numel();
  // Empty tensors should not get here, but guard anyway. With nelem == 0 the
  // grid would have zero blocks and the counter-offset division below would
  // divide by zero.
  if (nelem == 0) {
    return std::tuple<Tensor,Tensor>(self.clone(), mask);
  }

  Tensor ret = at::empty_like(self);

  const int64_t block_size = 256;
  const auto* device_props = at::cuda::getCurrentDeviceProperties();
  const unsigned int blocks_per_sm = device_props->maxThreadsPerMultiProcessor / block_size;
  const unsigned int max_grid_blocks =
      (unsigned int)device_props->multiProcessorCount * blocks_per_sm;

  dim3 dim_block(block_size);
  dim3 grid((nelem + block_size - 1) / block_size);
  grid.x = std::min(max_grid_blocks, grid.x);

  // Number of random values each thread may draw, used to advance the Philox
  // counter past this call's usage.
  const int64_t elems_per_sweep = block_size * grid.x * UNROLL;
  const int64_t counter_offset = ((nelem - 1) / elems_per_sweep + 1) * UNROLL;

  PhiloxCudaState philox_state;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    philox_state = gen->philox_cuda_state(counter_offset);
  }

  if (cuda::detail::canUse32BitIndexMath(self)) {
    launcher<unsigned int, mask_t>(self, ret, mask, p, nelem, philox_state, grid, dim_block);
  } else {
    launcher<uint64_t, mask_t>(self, ret, mask, p, nelem, philox_state, grid, dim_block);
  }
  return std::tuple<Tensor,Tensor>(ret, mask);
}
367 
// CUDA implementation of native_dropout; p is the drop probability.
// Returns (output, mask) with a Bool mask:
//   * train explicitly false -> clone of self, all-true mask;
//   * p == 1                 -> all-zero output and all-false mask, without using the RNG;
//   * otherwise              -> dropout_cuda with keep probability 1 - p, using
//                               the default CUDA generator.
std::tuple<Tensor,Tensor>
native_dropout_cuda(const Tensor& self, double p, std::optional<bool> train){
  const bool is_training = train.value_or(true);
  if (!is_training) {
    return std::make_tuple(self.clone(), at::ones_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value)));
  }

  if (p == 1) {
    // native_dropout_cuda is in derivatives.yaml, so we don't need to add data
    // dependency from output to input for autograd
    Tensor all_dropped = at::zeros_like(self);
    Tensor all_false_mask = at::zeros_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value));
    return std::tuple<Tensor,Tensor>(all_dropped, all_false_mask);
  }

  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(std::nullopt, cuda::detail::getDefaultCUDAGenerator());
  const double keep_prob = 1. - p;
  return dropout_cuda<bool>(gen, self, keep_prob);
}
387 
// TODO: _fused_dropout_cuda is to be removed, see PR #63937
// Legacy entry point with a uint8 mask. Unlike native_dropout_cuda, p is
// forwarded to dropout_cuda unchanged, with no 1 - p conversion.
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, std::optional<Generator> gen_){
  return dropout_cuda<uint8_t>(
      get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator()),
      self,
      p);
}
394 
// Dropout backward: returns grad * mask * scale, computed elementwise by
// masked_scale_kernel for floating dtypes plus Half and BFloat16. The result
// is allocated in grad's suggested memory format.
template <typename mask_t>
Tensor dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  Tensor ret = at::empty_like(grad, grad.suggest_memory_format());
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        const auto acc_scale = static_cast<accscalar_t>(scale);
        masked_scale_kernel<mask_t, scalar_t>(ret, grad, mask, acc_scale);
      });
  return ret;
}
404 
// Backward of native_dropout on CUDA. Requires a Bool mask and returns
// grad * mask * scale.
Tensor native_dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  // The dtype used to be appended directly to the message with no separator
  // ("...Scalar TypeByte"). The original prefix is kept for anyone matching on it.
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool,
              "Mask should be Bool Scalar Type, but got ", mask.scalar_type());
  return dropout_backward_cuda<bool>(grad, mask, scale);
}
409 
// TODO: masked_scale_cuda is to be removed, see PR #63937
// Legacy counterpart of native_dropout_backward_cuda that takes a uint8 mask.
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
  const bool mask_is_uint8 = mask.scalar_type() == at::ScalarType::Byte;
  TORCH_CHECK(mask_is_uint8, "mask should be torch.uint8 dtype");
  return dropout_backward_cuda<uint8_t>(self, mask, scale);
}
415 
416 } // namespace at::native
417