1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/TensorCompare.h>
4
5 namespace at::native {
6
7 namespace {
8
9 // Composite op implementation for simplicity. This materializes the cross product of elements and test elements,
10 // so it is not very memory efficient, but it is fast on CUDA.
isin_default_kernel_gpu(const Tensor & elements,const Tensor & test_elements,bool invert,const Tensor & out)11 void isin_default_kernel_gpu(
12 const Tensor& elements, const Tensor& test_elements, bool invert, const Tensor& out) {
13 std::vector<int64_t> bc_shape(elements.dim(), 1);
14 bc_shape.push_back(-1);
15 out.copy_(invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1)
16 : elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1));
17 }
18
19 } // anonymous namespace
20
// Route isin's default (composite) path to the GPU kernel above when the
// CUDA backend is active.
REGISTER_CUDA_DISPATCH(isin_default_stub, &isin_default_kernel_gpu);
22
23 } // namespace at::native
24