#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include // For at::native::index_out #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #include #else #include #include #include #include #endif namespace at::native { static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self, const Tensor & mask) { NoNamesGuard guard; TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "masked_select: expected BoolTensor for mask"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "masked_select(): self and result must have the same scalar type"); auto mask_temp = (mask.dim() == 0) ? c10::MaybeOwned::owned(mask.unsqueeze(0)) : c10::MaybeOwned::borrowed(mask); auto self_temp = (self.dim() == 0) ? c10::MaybeOwned::owned(self.unsqueeze(0)) : c10::MaybeOwned::borrowed(self); // Cannot reassign to mask_temp and self_temp here! if they are // owning and expand_outplace returns a borrow, the returned borrow // would dangle. auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp); at::cuda::index_out( result, *std::get<1>(mask_self_expanded), c10::List>({*std::move(std::get<0>(mask_self_expanded))})); return result; } Tensor masked_select_cuda(const Tensor & self, const Tensor & mask) { namedinference::compute_broadcast_outnames(self, mask); Tensor result = at::empty({0}, self.options()); return masked_select_out_cuda_impl(result, self, mask); } Tensor & masked_select_out_cuda(const Tensor & self, const Tensor & mask, Tensor & result) { namedinference::compute_broadcast_outnames(self, mask); return masked_select_out_cuda_impl(result, self, mask); } Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& source) { at::assert_no_internal_overlap(self); TORCH_CHECK( self.scalar_type() == source.scalar_type(), "masked_scatter_: expected self and source to have same dtypes but got ", self.scalar_type(), " and ", source.scalar_type()); TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, " "but got mask with dtype ", mask.dtype()); c10::MaybeOwned b_mask = expand_inplace(self, mask, "masked_scatter_"); if (self.numel() == 0) { return self; } auto maskPrefixSum = at::empty(self.sizes(), mask.options().dtype(kLong)); launch_masked_scatter_kernel(self, *b_mask, maskPrefixSum, source); return self; } } // namespace at::native