• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/TensorCompare.h>
4 
5 namespace at::native {
6 
7 namespace {
8 
9 // Composite op implementation for simplicity. This materializes the cross product of elements and test elements,
10 // so it is not very memory efficient, but it is fast on CUDA.
isin_default_kernel_gpu(const Tensor & elements,const Tensor & test_elements,bool invert,const Tensor & out)11 void isin_default_kernel_gpu(
12     const Tensor& elements, const Tensor& test_elements, bool invert, const Tensor& out) {
13   std::vector<int64_t> bc_shape(elements.dim(), 1);
14   bc_shape.push_back(-1);
15   out.copy_(invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1)
16             : elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1));
17 }
18 
19 } // anonymous namespace
20 
21 REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu);
22 
23 } // namespace at::native
24