#include <torch/extension.h>

// Declared in cuda_dlink_extension.cu; launches the CUDA kernel that adds two
// float arrays of length `size` into `output`.
void add_cuda(const float* a, const float* b, float* output, int size);

// Validates the inputs, allocates the result tensor, and dispatches to the
// CUDA implementation.
at::Tensor add(at::Tensor a, at::Tensor b) {
  TORCH_CHECK(a.device().is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.device().is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same size");

  at::Tensor output = at::empty_like(a);
  add_cuda(a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), a.numel());

  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add", &add, "a + b");
}
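The add_cuda declaration above is expected to be defined in the companion cuda_dlink_extension.cu. Below is a minimal sketch of what that file might contain; the kernel name and launch configuration are illustrative assumptions, not the project's actual implementation.

// cuda_dlink_extension.cu (hypothetical sketch): an element-wise add kernel and
// the host-side wrapper declared in the .cpp file above.
#include <cuda_runtime.h>

__global__ void add_kernel(const float* a, const float* b, float* output, int size) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size) {
    output[idx] = a[idx] + b[idx];
  }
}

void add_cuda(const float* a, const float* b, float* output, int size) {
  const int threads = 256;
  const int blocks = (size + threads - 1) / threads;
  add_kernel<<<blocks, threads>>>(a, b, output, size);
}

Once both source files are compiled and linked (for example by passing them as sources to torch.utils.cpp_extension.load), the bound add function is callable from Python as module.add(a, b) on CUDA float tensors.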