#include <ATen/native/nested/NestedTensorBinaryOps.h>

#include <type_traits>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorUtils.h>

#define BLOCK_DIM 256

namespace at {
namespace native {

// only for nested [B, *, D], dense [B, 1, D]
template <typename T, typename func_t>
__global__ void op_dense_esuhm(
    const T* input,
    const T* dense,
    T* output,
    int64_t embedding_dim,
    const int64_t* offsets,
    const func_t& f) {
  // each batch is handled by a block
  const int64_t batch_idx = blockIdx.x;
  const int64_t grain_size = blockDim.x;
  const int64_t tid = threadIdx.x;
  const int64_t range = offsets[batch_idx + 1] - offsets[batch_idx];
  // each thread handles (embedding_dim // grain_size + (embedding_dim % grain_size > tid)) elems
  // of the dense embedding
  for (int64_t idx = tid; idx < embedding_dim; idx += grain_size) {
    const T dense_elem = dense[batch_idx * embedding_dim + idx];
    // the dense element at position idx is broadcast over every row of this
    // batch's nested entry, which is laid out contiguously in the buffer
    for (int64_t nested_idx = idx; nested_idx < range; nested_idx += embedding_dim) {
      output[offsets[batch_idx] + nested_idx] =
          f(input[offsets[batch_idx] + nested_idx], dense_elem);
    }
  }
}

template <typename T, typename func_t>
void nested_op_dense_kernelLauncher(
    const T* input,  // [sum(*) x embedding_dim]
    const T* dense,  // [batch_size x embedding_dim]
    T* output,       // [sum(*) x embedding_dim]
    int64_t batch_size,
    int64_t embedding_dim,
    const int64_t* input_offsets, // [batch_size + 1]
    func_t f) {
  // one block per batch entry; BLOCK_DIM threads stride across the embedding dim
  dim3 grid;
  grid.x = batch_size;
  const auto stream = at::cuda::getCurrentCUDAStream();

  op_dense_esuhm<<<grid, BLOCK_DIM, 0, stream>>>(
      input,
      dense,
      output,
      embedding_dim,
      input_offsets,
      f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, typename func_t>
void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Tensor& other, func_t f) {
  auto self_ptr = get_nested_tensor_impl(self);
  auto result_ptr = get_nested_tensor_impl(result);

  const auto self_buffer = self_ptr->get_buffer();
  const auto offsets = self_ptr->get_storage_offsets();
  const auto batch_size = other.size(0);
  const auto embedding_size = other.size(2);

  auto result_buffer = result_ptr->get_buffer();
  // append the total element count so offsets[b + 1] - offsets[b] is valid for the last batch
  auto result_offsets = at::cat({offsets, at::tensor(self_ptr->numel())});
  result_offsets = result_offsets.to(kCUDA);

  const scalar_t* self_data_ptr = self_buffer.const_data_ptr<scalar_t>();
  const scalar_t* other_data_ptr = other.const_data_ptr<scalar_t>();
  scalar_t* result_data_ptr = result_buffer.data_ptr<scalar_t>();
  int64_t* result_offsets_ptr = result_offsets.data_ptr<int64_t>();

  nested_op_dense_kernelLauncher(
      self_data_ptr,
      other_data_ptr,
      result_data_ptr,
      batch_size,
      embedding_size,
      result_offsets_ptr,
      f);
}

void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op) {
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "_nested_op_dense_esuhm", [&]() {
        switch (op) {
          case NESTED_DENSE_OP::ADD :
            _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a + b; });
            break;
          case NESTED_DENSE_OP::MUL :
            _nested_op_dense_esuhm_kernel<scalar_t>(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a * b; });
            break;
        }
      });
}

REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda);

} // namespace native
} // namespace at
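
// Illustrative usage sketch (not part of the original file): how the stub
// registered above is typically reached from ATen. It assumes the
// at::_nested_tensor_from_tensor_list factory op and a dense operand shaped
// [B, 1, D] to match the nested [B, *, D] contract noted on op_dense_esuhm;
// treat the exact factory call as an assumption, not the canonical entry point.
//
//   auto a = at::randn({3, 8}, at::kCUDA);
//   auto b = at::randn({5, 8}, at::kCUDA);
//   auto nested = at::_nested_tensor_from_tensor_list({a, b}); // nested [2, *, 8]
//   auto dense  = at::randn({2, 1, 8}, at::kCUDA);             // dense  [2, 1, 8]
//   auto sum  = nested.add(dense); // NESTED_DENSE_OP::ADD -> op_dense_esuhm
//   auto prod = nested.mul(dense); // NESTED_DENSE_OP::MUL -> op_dense_esuhm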