#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>


#include <ATen/native/nested/NestedTensorUtils.h>

#define BLOCK_DIM 256

namespace at {
namespace native {


// only for nested [B, *, D], dense [B, 1, D]
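// Applies the binary functor f elementwise between the packed values of a
// nested tensor [B, *, D] and a dense tensor [B, 1, D], broadcasting each
// dense row across the ragged (*) dimension of the corresponding batch entry.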
template <typename T, typename func_t>
__global__ void op_dense_esuhm(
    const T* input,
    const T* dense,
    T* output,
    int64_t embedding_dim,
    const int64_t* offsets,
    const func_t& f)
{
  // each batch is handled by a block
  const int64_t batch_idx  = blockIdx.x;
  const int64_t grain_size = blockDim.x;
  const int64_t tid = threadIdx.x;
  const int64_t range = offsets[batch_idx + 1] - offsets[batch_idx];
  // each thread handles (embedding_dim // grain_size + (tid < embedding_dim % grain_size)) elems
  // of the dense embedding
  for (int64_t idx = tid; idx < embedding_dim; idx += grain_size) {
    const T dense_elem = dense[batch_idx * embedding_dim + idx];
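    // Broadcast dense_elem down the ragged rows of this batch entry: nested_idx
    // visits column idx of every row, stepping one full row (embedding_dim) at a time.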
    for (int64_t nested_idx = idx; nested_idx < range; nested_idx += embedding_dim) {
      output[offsets[batch_idx] + nested_idx] = f(input[offsets[batch_idx] + nested_idx], dense_elem);
    }
  }
}

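// Host-side launcher: one thread block per batch entry applies f between the
// packed nested buffer and the dense tensor, writing the packed result to output.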
template <typename T, typename func_t>
void nested_op_dense_kernelLauncher(
    const T* input, // [sum(*) x embedding_dim]
    const T* dense, // [batch_size x embedding_dim]
    T* output, // [sum(*) x embedding_dim]
    int64_t batch_size,
    int64_t embedding_dim,
    const int64_t* input_offsets,  // [batch_size]
    func_t f)
{
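  // Launch one block per batch entry; BLOCK_DIM threads per block stride over embedding_dim.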
  dim3 grid;
  grid.x = batch_size;
  const auto stream = at::cuda::getCurrentCUDAStream();

  op_dense_esuhm<<<grid, BLOCK_DIM, 0, stream>>>(
      input,
      dense,
      output,
      embedding_dim,
      input_offsets,
      f);
  // Surface any launch-time errors, as is standard after raw kernel launches in ATen.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

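// Pulls the raw buffer, offsets, and sizes out of the nested/dense tensors and
// forwards them to the CUDA launcher above.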
template <typename scalar_t, typename func_t>
void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Tensor& other, func_t f) {
  auto self_ptr = get_nested_tensor_impl(self);
  auto result_ptr = get_nested_tensor_impl(result);

  const auto self_buffer = self_ptr->get_buffer();
  const auto offsets = self_ptr->get_storage_offsets();
  const auto batch_size = other.size(0);
  const auto embedding_size = other.size(2);

  auto result_buffer = result_ptr->get_buffer();
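  // Append the total element count as a final offset so the kernel can compute
  // each batch entry's extent as offsets[b + 1] - offsets[b].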
  auto result_offsets = at::cat({offsets, at::tensor(self_ptr->numel())});
  result_offsets = result_offsets.to(kCUDA);

  const scalar_t* self_data_ptr = self_buffer.const_data_ptr<scalar_t>();
  const scalar_t* other_data_ptr = other.const_data_ptr<scalar_t>();
  scalar_t* result_data_ptr = result_buffer.data_ptr<scalar_t>();
  int64_t* result_offsets_ptr = result_offsets.data_ptr<int64_t>();

  nested_op_dense_kernelLauncher(
    self_data_ptr,
    other_data_ptr,
    result_data_ptr,
    batch_size,
    embedding_size,
    result_offsets_ptr,
    f);
}

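// Entry point registered below: dispatches on dtype and on the requested binary op,
// then runs the templated kernel wrapper above.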
void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op) {
  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "_nested_op_dense_esuhm", [&]() {
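    // Each op is a __host__ __device__ lambda so the functor can be invoked inside the kernel.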
    switch (op) {
      case NESTED_DENSE_OP::ADD :
        _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a + b; });
        break;
      case NESTED_DENSE_OP::MUL :
        _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a * b; });
        break;
    }
  });
}

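// Register this implementation as the CUDA backend of nested_dense_elementwise_stub
// (declared in the NestedTensorBinaryOps.h header included above).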
REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);

} // namespace native
} // namespace at