// Indexing tensors by tensors
//
// This corresponds to "advanced indexing" in NumPy. The two operations are:
//
//  index(Tensor self, indices) -> Tensor
//  index_put_(Tensor self, indices, value, accumulate=false)
//
// The index is a TensorList containing kLong, kBool or kByte tensors or nulls.
// Byte and bool tensors (boolean masks) are expanded to long tensors via
// nonzero(). Null tensors signify that the dimension is not indexed.
//
// All indexes are broadcast together and iterated as *one*. From NumPy:
//
// result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
//                            ..., ind_N[i_1, ..., i_M]]
//
// Note 1: ByteTensors expand to index as many dimensions as there are in the
// mask.
//
// Note 2: The behavior is more complicated when the index tensors are not all
// adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
// tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
//
// The code contains two implementations of indexing. The more efficient
// implementation treats indexing like an elementwise operation over the
// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
// not work for index_put_ with accumulate=True. The other implementation
// combines the indexed tensors into a single linear index that is used
// with Tensor.put_. This is used for index_put_ with accumulate=True.
//
// The more efficient implementation takes the following steps for the
// above operation:
//
// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
// 2) Record x.stride(i) for each indexed dimension `i`
// 3) Replace the indexed subspace of `x` with the shape of the corresponding
//    subspace of `result` but with stride 0
// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
//    that their shape is compatible with the result shape
//
// The CPU or CUDA kernel then computes element-wise over the broadcasted
// and restrided result, x, ind_1, ind_2, etc.:
//
//   result[...] = *(&x[...] +
//                   ind_1[...] * x.stride(1) +
//                   ind_2[...] * x.stride(2) +
//                   ...)
//
// where & and * represent the C-style address-of and indirection operations.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include

#ifndef AT_PER_OPERATOR_HEADERS
#include #include
#else
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include #include #include #include
#include #include
#endif

#ifdef USE_FBGEMM
#include
#endif

#include #include #include #include #include #include

namespace at::native {

std::string shapes_as_str(TensorList tensors);
AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);

} // namespace at::native

namespace at::meta {

TORCH_META_FUNC(gather)
(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
  const Tensor& result = maybe_get_output(0);
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
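  // As a quick reminder of the shape contract enforced below (a sketch of
  // the documented gather semantics, not code from this file): the output
  // takes the shape of `index`, and along `dim` the index values select
  // elements of `self`. For dim == 1 on a 2-D tensor:
  //
  //   auto t = torch::tensor({{1, 2}, {3, 4}});      // shape [2, 2]
  //   auto idx = torch::tensor({{0, 0}, {1, 0}});    // kLong, shape [2, 2]
  //   auto out = at::gather(t, /*dim=*/1, idx);
  //   // out[i][j] == t[i][idx[i][j]]  ->  {{1, 1}, {4, 3}}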
  // Memory overlap checks need to be done after resizing (if required) is done.
  // But it only makes sense to do these checks when result was defined, hence
  // the boolean variable `check_result` here.
  // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832
  // and https://github.com/pytorch/pytorch/issues/63837
  bool check_result = result.defined();
  set_output_raw_strided(0, index.sizes(), {}, self.options());
  if (check_result) {
    at::assert_no_internal_overlap(result);
    at::assert_no_overlap(result, self);
    at::assert_no_partial_overlap(result, index);
  }

  auto is_index_empty = index.numel() == 0;
  if (!is_index_empty) {
    TORCH_CHECK(
        index.scalar_type() == at::ScalarType::Long,
        "gather", "(): Expected dtype int64 for index"
    );
  }
  if (is_index_empty) return;
  at::native::gather_shape_check(self, wrapped_dim, index);
}

template <bool use_new_options = false, typename Meta>
void scatter_meta_impl(
    Meta& meta,
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const std::optional<Tensor>& src = std::nullopt,
    const std::optional<c10::string_view> reduce = std::nullopt) {
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
  at::native::scatter_gather_dtype_check("scatter", self, index, src);
  at::native::scatter_shape_check(self, wrapped_dim, index, src);

  auto output = meta.maybe_get_output(0);
  if (output.defined()) {
    at::assert_no_internal_overlap(output);
    at::assert_no_overlap(output, index);
    if (src.has_value()) {
      at::assert_no_overlap(output, src.value());
    }
  }

  meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
  if (reduce.has_value()) {
    // Check if we have a valid reduce operator.
    at::native::get_operator_enum(reduce.value(), use_new_options);
  }
}

TORCH_META_FUNC2(scatter, src)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  scatter_meta_impl(*this, self, dim, index, src);
}

TORCH_META_FUNC2(scatter, value)
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) {
  scatter_meta_impl(*this, self, dim, index);
}

TORCH_META_FUNC2(scatter, reduce)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce) {
  TORCH_WARN_ONCE(
      "The reduce argument of torch.scatter with Tensor src is deprecated and will be removed ",
      "in a future PyTorch release. Use torch.scatter_reduce instead for more reduction options."
  );
  scatter_meta_impl(*this, self, dim, index, src, reduce);
}

TORCH_META_FUNC2(scatter, value_reduce)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Scalar& src,
 const c10::string_view reduce) {
  scatter_meta_impl(*this, self, dim, index, std::nullopt, reduce);
}

TORCH_META_FUNC(scatter_add)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  scatter_meta_impl(*this, self, dim, index, src, "add");
}

TORCH_META_FUNC2(scatter_reduce, two)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce,
 bool include_self) {
  (void) include_self;
  scatter_meta_impl</*use_new_options=*/true>(*this, self, dim, index, src, reduce);
}

TORCH_PRECOMPUTE_META_FUNC(index_copy)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source) {
  dim = maybe_wrap_dim(dim, self.dim());
  const Tensor& result = maybe_get_output(0);

  // Memory overlap checks need to be done after resizing (if required) is done.
  // But it only makes sense to do these checks when result was defined, hence
  // the boolean variable `check_result` here.
// For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832 // and https://github.com/pytorch/pytorch/issues/63837 bool check_result = result.defined(); set_output_raw_strided(0, self.sizes(), {}, self.options()); if (check_result) { at::assert_no_internal_overlap(result); at::assert_no_overlap(result, index); at::assert_no_overlap(result, source); } TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")"); int64_t numIndices = index.numel(); if (source.dim() == 0 && numIndices != 1) { TORCH_CHECK_INDEX(false, "index_copy_(): When source is scalar, index should have one element (got ", numIndices, ")"); } else if ((source.dim() != self.dim()) && (source.dim() != 0 && self.dim() != 0)) { TORCH_CHECK_INDEX(false, "index_copy_(): When source and destination are not scalars, their dimensionality must match. Source dimensionality (", source.dim(), "), destination dimensionality (", self.dim(), ")"); } TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_copy_(): Expected a long tensor for index, but got ", index.scalar_type()); TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_copy_(): self and source expected to have the same dtype, but got (self) ", self.scalar_type(), " and (source) ", source.scalar_type()); TORCH_CHECK(self.device() == source.device() && self.device() == index.device(), "index_copy_(): self, index and source expected to be in the same device, but got (self) ", self.device(), ", (index) ", index.device(), ", and (source) ", source.device()); // Check that source and destination slices have the same size auto selfSlicedSizes = self.sizes().vec(); if (!selfSlicedSizes.empty()) { selfSlicedSizes.erase(selfSlicedSizes.begin() + dim); } auto sourceSlicedSizes = source.sizes().vec(); if (!sourceSlicedSizes.empty()) { sourceSlicedSizes.erase(sourceSlicedSizes.begin() + dim); } if (selfSlicedSizes.size() != sourceSlicedSizes.size() || !std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(), sourceSlicedSizes.begin())) { std::stringstream ss; ss << "index_copy_(): Source/destination tensor must have same slice shapes. "; ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim; ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0."; TORCH_CHECK(false, ss.str()); } TORCH_CHECK_INDEX(source.dim() == 0 || numIndices == source.size(dim), "index_copy_(): Number of indices (", numIndices, ") should be equal to source.size(dim) (", source.size(dim), ")"); return TORCH_PRECOMPUTE_STRUCT(index_copy)().set_dim(dim); } template void index_func_meta_impl( Meta& meta, const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, c10::string_view func) { auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, func, "_(): Index is supposed to be a vector, but got dim: ", index.dim(), " with type: ", index.scalar_type(), " and size: ", index.sizes()); TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, func, "_(): Expected dtype int32/int64 for index but got: ", index.scalar_type()); TORCH_CHECK(self.scalar_type() == source.scalar_type(), func, "_(): self (", self.scalar_type(), ") and source (", source.scalar_type(), ") must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < source.dim(), func, "_(): Indexing dim ", dim, " is out of bounds of the source tensor with dim ", source.dim()); TORCH_CHECK(numel == (source.dim() == 0 ? 
          1 : source.size(dim)),
      func, "_(): Number of indices (", numel,
      ") should be equal to source.size(dim): (", source.size(dim),
      "), for dim: ", dim);

  auto self_sizes = self.sizes().vec();
  auto source_sizes = source.sizes().vec();
  if (source.dim() != 0 && self.dim() != 0) {
    self_sizes.erase(self_sizes.begin() + dim);
    source_sizes.erase(source_sizes.begin() + dim);
  }
  TORCH_CHECK(
      self_sizes == source_sizes,
      "source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
      self.sizes(), " source.shape = ", source.sizes());

  auto& result = meta.maybe_get_output(0);
  bool is_defined = result.defined();
  meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
  if (is_defined) {
    at::assert_no_internal_overlap(result);
    at::assert_no_overlap(result, index);
    at::assert_no_overlap(result, source);
  }

  // A hack to run TensorIterator checks in the meta function.
  // See comment: https://github.com/pytorch/pytorch/pull/65993#discussion_r760307417
  // TODO: (@krshrimali) Try inheriting from TensorIteratorBase instead.
  if (result.device() == kMeta && result.dim() > 0) {
    auto selfSlice = result.select(dim, 0);
    auto sourceSlice = source.select(dim, 0);
    auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
  }
}

TORCH_PRECOMPUTE_META_FUNC(index_add)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha) {
  dim = maybe_wrap_dim(dim, self.dim());
  index_func_meta_impl(*this, self, dim, index, source, "index_add");
  return TORCH_PRECOMPUTE_STRUCT(index_add)().set_dim(dim);
}

TORCH_PRECOMPUTE_META_FUNC(index_reduce)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& source,
 const c10::string_view reduce,
 bool include_self) {
  (void)include_self;
  TORCH_CHECK(reduce == "prod" || reduce == "mean" || reduce == "amax" || reduce == "amin",
              "index_reduce(): Expected reduce to be one of prod, mean, amax or amin but got ", reduce, ".");
  dim = maybe_wrap_dim(dim, self.dim());
  index_func_meta_impl(*this, self, dim, index, source, "index_reduce");
  return TORCH_PRECOMPUTE_STRUCT(index_reduce)().set_dim(dim);
}

static void build_index_op(
    TensorIteratorBase& iter,
    const at::native::AdvancedIndex& info,
    const Tensor& result) {
  // 'TensorIterator' needs to own the things coming from 'info', since
  // 'info' will be destroyed after the META function.
  TensorIteratorConfig config;
  // info.src is a restrided view of result
  config.set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .add_output(result)
      .add_owned_const_input(info.src);
  for (auto& index : info.indices) {
    config.add_owned_const_input(index);
  }
  if (!result.defined()) {
    config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
  }
  iter.build(config);
}

static void check_indices_on_cpu_or_selfdevice(
    const Tensor& self,
    const at::MaterializedIOptTensorListRef& indices) {
  auto dev = self.device();
  bool indices_on_cpu_or_dev = std::all_of(
      indices.begin(), indices.end(), [=](const at::OptionalTensorRef& opt) {
        return opt.has_value() ?
(opt->is_cpu() || opt->device() == dev) : true; }); TORCH_CHECK( indices_on_cpu_or_dev, "indices should be either on ", kCPU, " or on the same device as the indexed tensor (", dev, ")"); } TORCH_PRECOMPUTE_META_FUNC2(index, Tensor) (const Tensor& self, at::IOptTensorListRef indices) { auto materialized = indices.materialize(); TORCH_CHECK_INDEX( materialized.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", materialized.size(), ")"); // Only allow: `dev_tensor[{cpu,dev}_tensor]`. // See: https://github.com/pytorch/pytorch/pull/69607 check_indices_on_cpu_or_selfdevice(self, materialized); const auto& result = maybe_get_output(); if (result.defined()) { TORCH_CHECK(self.scalar_type() == result.scalar_type(), "index_out: self (", self.scalar_type(), ") and result (", result.scalar_type(), ") must have the same scalar type"); at::assert_no_internal_overlap(result); at::assert_no_overlap(result, self); for (const at::OptionalTensorRef& index : materialized) { if (index.has_value()) { at::assert_no_overlap(result, *index); } } } auto info = at::native::make_info(self, std::move(indices)); build_index_op(*this, info, result); return TORCH_PRECOMPUTE_STRUCT2(index, Tensor)() .set_sizes(std::move(info.indexed_sizes)) .set_strides(std::move(info.indexed_strides)); } } // namespace at::meta namespace at::native { DEFINE_DISPATCH(index_stub); DEFINE_DISPATCH(index_fill_stub); DEFINE_DISPATCH(index_copy_stub); DEFINE_DISPATCH(index_put_stub); DEFINE_DISPATCH(index_put_with_sort_stub); DEFINE_DISPATCH(put_stub); DEFINE_DISPATCH(take_stub); DEFINE_DISPATCH(masked_fill_stub); REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub); REGISTER_NO_CPU_DISPATCH(index_put_with_sort_quantized_stub); DEFINE_DISPATCH(masked_select_serial_stub); DEFINE_DISPATCH(masked_select_stub); DEFINE_DISPATCH(masked_scatter_stub); DEFINE_DISPATCH(gather_stub); DEFINE_DISPATCH(scatter_stub); DEFINE_DISPATCH(scatter_fill_stub); DEFINE_DISPATCH(scatter_add_stub); DEFINE_DISPATCH(scatter_reduce_stub); DEFINE_DISPATCH(scatter_scalar_reduce_stub); DEFINE_DISPATCH(scatter_reduce_two_stub); DEFINE_DISPATCH(scatter_add_expanded_index_stub); DEFINE_DISPATCH(scatter_reduce_expanded_index_stub); DEFINE_DISPATCH(gather_expanded_index_stub); static bool all_strides_match(TensorList tensors) { TORCH_CHECK(!tensors.empty()); auto strides = tensors[0].strides(); for (auto& tensor : tensors.slice(1)) { if (!strides.equals(tensor.strides())) { return false; } } return true; } inline std::string shapes_as_str(TensorList tensors) { std::ostringstream os; bool first = true; for (auto& tensor : tensors) { if (tensor.defined()) { if (!first) { os << ", "; } os << tensor.sizes(); first = false; } } return os.str(); } // Replace indexed dimensions in src with stride 0 and the size of the result tensor. // The offset in these dimensions is computed by the kernel using the index tensor's // values and the stride of src. The new shape is not meaningful. It's used to make // the shape compatible with the result tensor. 
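//
// For example (illustrative numbers, not taken from a particular caller):
// indexing dim 1 of a contiguous `src` of shape [5, 6, 7] (strides
// [42, 7, 1]) with an index that broadcasts to shape [2, 3] gives
// dims_before = 1, dims_indexed = 1 and replacement_shape = [2, 3], so the
// returned view has
//
//   shape   [5, 2, 3, 7]
//   strides [42, 0, 0, 1]
//
// Iterating the view only advances through the non-indexed dimensions; the
// kernel adds index[...] * src.stride(1) to reach the selected element.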
static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed, IntArrayRef replacement_shape) { auto shape = DimVector(src.sizes()); auto strides = DimVector(src.strides()); int64_t end = dims_before + dims_indexed; shape.erase(shape.begin() + dims_before, shape.begin() + end); strides.erase(strides.begin() + dims_before, strides.begin() + end); shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end()); strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0); return src.as_strided(shape, strides); } // Add dimensions of size 1 to an index tensor so that it can be broadcast to the result // shape and iterated over element-wise like the result tensor and the restrided src. static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) { auto orig_shape = index.sizes(); auto shape = DimVector(); shape.append(dims_before, 1); shape.append(orig_shape.begin(), orig_shape.end()); shape.append(dims_after, 1); return index.reshape(shape); } AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) { int64_t element_size_bytes = src.element_size(); int64_t dims_before = 0, dims_after = 0, dims_indexed = 0; IntArrayRef replacement_shape; for (const auto dim : c10::irange(indices_list.size())) { if (!indices_list[dim].defined()) { if (dims_indexed == 0) { dims_before++; } else { dims_after++; } } else { dims_indexed++; replacement_shape = indices_list[dim].sizes(); indexed_sizes.push_back(src.size(dim)); indexed_strides.push_back(src.stride(dim) * element_size_bytes); } } // Check if the indexed subspace contains a dim of size 0, but the replacement // shape does not. This implies that an index is out of bounds, because there // is no number that's a valid index for an empty tensor. Normally, out of // bounds is handled in the indexing kernel, but this case fails earlier in // restride_src with an unhelpful error message. if (std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end() && std::find(replacement_shape.begin(), replacement_shape.end(), 0) == replacement_shape.end()) { TORCH_CHECK_INDEX(false, "index is out of bounds for dimension with size 0"); } this->dims_before = dims_before; this->dims_after = dims_after; this->src = restride_src(src, dims_before, dims_indexed, replacement_shape); for (auto& index : indices_list) { if (index.defined()) { indices.push_back(reshape_indexer(index, dims_before, dims_after)); } } // For CUDA/MPS/XPU tensors, force all index tensors to have the same striding to // simplify the CUDA/MPS/XPU kernel. 
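  // For example (hypothetical strides): after broadcasting, one index tensor
  // may be materialized with strides [3, 1] while another is an expanded
  // view with strides [0, 1]. Making the non-matching tensors contiguous
  // lets the device kernel use one shared set of strides for every index
  // tensor instead of per-tensor stride arithmetic.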
  if (indices.size() >= 2 &&
      (this->src.device().type() == kCUDA ||
       this->src.device().type() == kMPS ||
       this->src.device().type() == kXPU)) {
    if (!all_strides_match(indices)) {
      for (auto & indice : indices) {
        indice = indice.contiguous();
      }
    }
  }
}

static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
  TORCH_CHECK(is_expandable_to(value.sizes(), info.src.sizes()),
              "shape mismatch: value tensor of shape ", value.sizes(),
              " cannot be broadcast to indexing result of shape ", info.src.sizes());
  TORCH_CHECK(value.scalar_type() == info.src.scalar_type(),
              "Index put requires the source and destination dtypes match, "
              "got ", info.src.scalar_type(), " for the destination "
              "and ", value.scalar_type(), " for the source.");
  TensorIteratorConfig config;
  // info.src is restrided by restride_src with 0 strided dimensions
  config.set_check_mem_overlap(false);
  config.resize_outputs(false);
  config.check_all_same_dtype(false);
  config.add_output(info.src);
  config.add_const_input(value);
  for (auto& index : info.indices) {
    config.add_const_input(index);
  }
  return config.build();
}

TORCH_IMPL_FUNC(index_out)
(const Tensor& self, DimVector sizes, DimVector strides, const Tensor& result) {
  index_stub(device_type(), *this, sizes, strides);
}

Tensor quantized_index(const Tensor & self, const torch::List<std::optional<Tensor>>& indices) {
  TORCH_INTERNAL_ASSERT(
      self.qscheme() == c10::kPerTensorAffine ||
      self.qscheme() == c10::kPerTensorSymmetric,
      "Indexing is only supported for per-Tensor quantized Tensors.");

  // For now, this is a naive implementation which does dq -> index -> q.
  // TODO(future PR): improve performance by removing the copies.
  const auto& self_dq = self.dequantize();
  auto result = at::index(self_dq, indices);
  return at::quantize_per_tensor(
      result, self.q_scale(), self.q_zero_point(), self.scalar_type());
}

Tensor _unsafe_index(const Tensor& self, const torch::List<std::optional<Tensor>>& indices) {
  // Disallow boolean indexing since it leads to dynamic output shapes
  for (auto i : c10::irange(indices.size())) {
    auto index = indices.get(i);
    if (index.has_value()) {
      auto dtype = index->scalar_type();
      TORCH_CHECK(dtype == kLong || dtype == kInt,
                  "_unsafe_index found unexpected index type ", dtype);
    }
  }
  return at::index(self, indices);
}

Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch::List<std::optional<Tensor>>& indices, const Scalar& fill) {
  // Unsafe masked index is equivalent to
  //   where(mask, self[indices], fill)
  // with the main difference being that when `mask` is false, the tensor
  // `self` is not indexed using `indices`. This allows `indices` to be
  // out-of-bounds when `mask` is false. When `mask` is true, the `indices`
  // are expected to be in bounds and this is not checked.
  //
  // This function is not meant to be executed in eager mode. An unoptimized
  // version is provided here.
  //
  // Compiler backends should implement this op such that `self[indices]` is
  // not loaded when `mask` is false. See inductor for a reference.
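  //
  // A sketch of the intended contract (the variable names below are
  // illustrative only):
  //
  //   auto x = torch::arange(4);                    // [0, 1, 2, 3]
  //   auto idx = torch::tensor({0, 99});            // 99 is out of bounds
  //   auto mask = torch::tensor({true, false});
  //   torch::List<std::optional<Tensor>> ind;
  //   ind.push_back(idx);
  //   auto y = at::_unsafe_masked_index(x, mask, ind, /*fill=*/-1);
  //   // y == [0, -1]; x[99] never has to be read because mask[1] is false
  //   // (this eager fallback clamps idx before indexing, then masks).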
auto clamp = [](const std::optional& index, auto size) -> std::optional { if (!index) { return index; } // Disallow bool auto dtype = index->scalar_type(); TORCH_CHECK(dtype == kLong || dtype == kInt, "_unsafe_masked_index found unexpected index type ", dtype); return at::clamp(*index, -size, size - 1); }; torch::List> clamped_indices(indices); std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); if (self.numel() == 0) { // Returns a tensor filled with `fill` value // We use a hack here since we do not have a method to get the // correct size of the tensor. (except with meta impl which is // not available on mobile builds) std::vector new_sizes(self.dim()); auto compute_new_size = [](const std::optional& index, auto size) -> int64_t { if (index && size == 0) { return 1; } else { return size; } }; std::transform(indices.begin(), indices.end(), self.sizes().begin(), new_sizes.begin(), compute_new_size); auto result = self.new_full(new_sizes, fill); return at::_unsafe_index(result, clamped_indices); } auto result = at::_unsafe_index(self, clamped_indices); return result.masked_fill(at::logical_not(mask), fill); } Tensor _unsafe_masked_index_put_accumulate(const Tensor& self, const Tensor& mask, const torch::List>& indices, const Tensor& values) { // This is the backward of _unsafe_masked_index. // This function is not meant to be executed on eager mode. if (self.numel() == 0) { return self.clone(); } // We recompute the clamped indices and rely on inductor to CSE the computation auto clamp = [](const std::optional& index, auto size) -> std::optional { if (!index) { return index; } // Disallow bool auto dtype = index->scalar_type(); TORCH_CHECK(dtype == kLong || dtype == kInt, "_unsafe_masked_index found unexpected index type ", dtype); return at::clamp(*index, -size, size - 1); }; torch::List> clamped_indices(indices); std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp); auto masked_value = values.masked_fill(at::logical_not(mask), 0); return at::_unsafe_index_put(self, clamped_indices, masked_value, true); } Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { // See note [Writing Nondeterministic Operations] // Nondeterministic when index contains duplicate entries and we do not accumulate // If we accumulate on GPU, we use atomicGPUAdd, which is non-deterministic if (!accumulate || (accumulate && self.device().type() == DeviceType::CUDA)) { at::globalContext().alertNotDeterministic("put_"); } // Type and device checks TORCH_CHECK(index.scalar_type() == ScalarType::Long, "put_(): Expected a long tensor for index, but got ", index.scalar_type()) TORCH_CHECK(self.scalar_type() == source.scalar_type(), "put_(): self and source expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and source.dtype = ", source.scalar_type()); TORCH_CHECK(self.device() == source.device() && self.device() == index.device(), "put_(): self, index and source expected to be in the same device, but got self.device = ", self.device(), ", index.device = ", index.device(), ", and source.device = ", source.device()); // index checks TORCH_CHECK_INDEX(source.numel() == index.numel(), "put_(): Expected source and index to have the same number of elements, but got source.numel() = ", source.numel(), ", index.numel() = ", index.numel()); TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "put_(): Tried to put elements into an empty tensor"); 
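  // For the purpose of indexing, `self` is treated as if it were a 1-D
  // tensor (a sketch of the documented put_ semantics, illustrative values):
  //
  //   auto t = torch::zeros({2, 3});
  //   t.put_(torch::tensor({0, 4}), torch::tensor({10.f, 20.f}));
  //   // t.view(-1)[0] == 10 and t.view(-1)[4] == 20,
  //   // so t == {{10, 0, 0}, {0, 20, 0}}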
at::assert_no_internal_overlap(self); at::assert_no_overlap(self, index); at::assert_no_overlap(self, source); // Early return if (index.numel() == 0) { return self; } auto index_reshaped = index.reshape(source.sizes()); // Do not iterate over self, we will compute the offsets manually auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) .check_all_same_dtype(false) .add_const_input(source) .add_const_input(index_reshaped) .build(); put_stub(iter.device_type(), iter, self, accumulate); return self; } Tensor put(const Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) { return self.clone(at::MemoryFormat::Preserve).put_(index, source, accumulate); } Tensor index_put(const Tensor & self, const torch::List>& indices, const Tensor & value, bool accumulate) { return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate); } Tensor _unsafe_index_put(const Tensor& self, const torch::List>& indices, const Tensor& value, bool accumulate) { return at::index_put(self, indices, value, accumulate); } Tensor & _index_put_impl_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate, const bool unsafe) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); if (at::has_internal_overlap(self) == MemOverlap::Yes) { TORCH_WARN( "Use of index_put_ on expanded tensors is deprecated. " "Please clone() the tensor before performing this operation. " "This also applies to advanced indexing e.g. tensor[indices] = tensor"); } if (!accumulate) { auto masked_fill_dispatch = canDispatchToMaskedFill(self, indices, value); if (std::get<0>(masked_fill_dispatch)) { return self.masked_fill_(std::get<1>(masked_fill_dispatch), value.item()); } } auto value_ = value; if (value.device() != self.device() && value.numel() == 1 && value.dim() == 0) { value_ = value.to(self.device()); } at::assert_no_overlap(self, value); // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) for (const std::optional& index: indices) { if (index.has_value()) { at::assert_no_overlap(self, *index); } } if ((self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU) && (accumulate || globalContext().deterministicAlgorithms())) { TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ", value_.device(), " for value tensor"); index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe); return self; } auto info = make_info(self, indices); auto iter = make_index_put_iterator(info, value_); index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate); return self; } Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) { // Type and device checks TORCH_CHECK(index.scalar_type() == ScalarType::Long, "take(): Expected a long tensor for index, but got ", index.scalar_type()) TORCH_CHECK(self.scalar_type() == out.scalar_type(), "take(): self and out expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and out.dtype = ", out.scalar_type()); TORCH_CHECK(self.device() == out.device() && self.device() == index.device(), "take(): self, index and out expected to be in the same device, but got self.device = ", self.device(), ", index.device = ", index.device(), ", and out.device = ", out.device()); // index checks TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "take(): tried to take from an 
empty tensor"); at::assert_no_internal_overlap(out); at::assert_no_overlap(out, index); at::assert_no_overlap(out, self); // Do not iterate over self, we will compute the offsets manually // out is resized inside tensor_iterator auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) .check_all_same_dtype(false) .add_output(out) .add_const_input(index) .build(); // Early return after out has been resized if (index.numel() == 0) { return out; } take_stub(iter.device_type(), iter, self); return out; } Tensor take(const Tensor& self, const Tensor& index) { auto out = at::empty(index.sizes(), self.options()); at::native::take_out(self, index, out); return out; } Tensor & index_put_(Tensor & self, const torch::List>& indices, const Tensor & value, const bool accumulate) { return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false); } TORCH_IMPL_FUNC(index_copy_out) (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Tensor& result) { if (!result.is_same(self)) result.copy_(self); // See Note [Enabling Deterministic Operations] if (result.is_cuda() && globalContext().deterministicAlgorithms()){ torch::List> indices; indices.reserve(dim + 1); for (const auto i: c10::irange(dim)) { (void)i; indices.emplace_back(); } indices.emplace_back(index); result.index_put_(indices, source, false); return; } // Handle the case when self / source is 0-dim Tensor result_nonzero = result.dim() == 0 ? result.unsqueeze(0) : result; Tensor source_nonzero = source.dim() == 0 ? source.unsqueeze(0) : source; // The only difference between the following tensor iterator and that of index_fill_ is that // this one has also source as an input. We should refactor it when if constexpr is available (C++17) // Prepare `index` for TensorIterator. // It is restrided to be broadcastable over `self` in TensorIterator. auto index_sizes = std::vector(result_nonzero.dim(), 1); auto index_strides = std::vector(result_nonzero.dim(), 0); index_sizes[dim] = index.numel(); index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar auto index_restrided = index.as_strided( index_sizes, index_strides); // Prepare `result` for TensorIterator. // Restride `result` to not advance in dimension `dim`. // We do not use squash_dim here because `index` will // need to advance in this dimension. // Note that self_sizes[dim] is set to index.numel(). // This is done so that self_sizes[dim] and index_sizes[dim] // match as required by TensorIterator (input shape should // strictly broadcast over output shape, i.e. // output.shape[i] >= input.shape[i] for i in range(dims)). auto result_sizes = result_nonzero.sizes().vec(); auto result_strides = result_nonzero.strides().vec(); result_sizes[dim] = index.numel(); result_strides[dim] = 0; auto result_restrided = result_nonzero.as_strided(result_sizes, result_strides); auto iter = TensorIteratorConfig() // We do not check for overlap because `result` is restrided // with zero stride. Zero strides trigger memory overlap assert // within TensorIterator. 
.set_check_mem_overlap(false) .check_all_same_dtype(false) .resize_outputs(false) .add_output(result_restrided) .add_const_input(index_restrided) .add_const_input(source_nonzero) .build(); auto result_dim_size = result_nonzero.size(dim); auto result_dim_stride = result_nonzero.stride(dim); index_copy_stub( iter.device_type(), iter, dim, result_dim_size, result_dim_stride); } // Not calling into index_reduce_func_impl because of a different dtype dispatch TORCH_IMPL_FUNC(index_add_cpu_out) (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) { if (!result.is_same(self)) { result.copy_(self); } auto numel = index.numel(); auto index_contig = index.contiguous(); if (result.dim() > 1) { // Equivalent to: // for (const auto i : c10::irange(numel)) { // auto selfSlice = self.select(dim, index_data[i]); // auto sourceSlice = source.select(dim, i); // selfSlice.add_(sourceSlice); // } // But much faster as this reuses the iterator from add_ if (numel == 0 || self.numel() == 0) { return; } dim = maybe_wrap_dim(dim, self.dim()); // When the slice of source or result is noncontiguous, // original index_add is slow as it uses add for the sliced tensor, // which is serial on index and parallel on sliced tensor to avoid write conflict. // Doing parallel on the sliced tensor is not optimal as the size of sliced tensor // may be not big enough to parallel and also causes multiple parallelizations. // scatter_add is used to speedup for this case as scatter_add parallels on // the outer dimension of input and is serial on the inner dimension to // avoid write conflict. scatter_add only need one parallel and the size of // outer dimensions is bigger to do parallel. if ((dim == 0 || dim == self.dim() - 1) && // Data type of index should be long and alpha should be 1 to use scatter_add. alpha.equal(1.0) && index_contig.scalar_type() == ScalarType::Long && // scatter_add does not support ComplexHalf source.scalar_type() != ScalarType::ComplexHalf && result.scalar_type() != ScalarType::ComplexHalf) { std::vector ep_sizes(result.sizes().size()); std::vector ep_strides(source.sizes().size()); // Check whether result and source are matched apart from the dimension dim. 
// Note that the broadcast case: // source.select(dim, i) is broadcast for result.select(dim, index_data[i]) // The broadcast case is not applicable for scatter_add auto check_sizes = [&ep_sizes, &ep_strides, &numel](IntArrayRef a, IntArrayRef b, int64_t dim) -> bool { ep_sizes[dim] = numel; ep_strides[dim] = 1; for (const int64_t i : c10::irange(a.size())) { if (i == dim) { continue; } if (a[i] != b[i]) { return false; } ep_sizes[i] = a[i]; ep_strides[i] = 0; } return true; }; if (check_sizes(result.sizes(), source.sizes(), dim)) { auto ep_index = index_contig.as_strided(ep_sizes, ep_strides); result.scatter_add_(dim, ep_index, source); return; } } auto selfSlice = result.select(dim, 0); auto sourceSlice = source.select(dim, 0); auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type()); auto self_dim_size = result.size(dim); auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice); AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; auto source_data = static_cast(sourceSlice.const_data_ptr()) + i * source_stride_bytes; iter.unsafe_replace_operand(0, self_data); iter.unsafe_replace_operand(1, self_data); iter.unsafe_replace_operand(2, const_cast(source_data)); add_stub(iter.device_type(), iter, alpha); } }); } else { TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, ScalarType::ComplexHalf, result.scalar_type(), "index_add_", [&result, &source, &dim, &index_contig, &numel, &alpha] { auto alpha_value = alpha.to(); auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); // TODO: Maybe TensorAccessor can be used here? auto* result_ptr = result.data_ptr(); auto* source_ptr = source.const_data_ptr(); AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_", [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self"); scalar_t *self_ip = result_ptr + self_i * result_stride; *self_ip += *(source_ptr + i * source_stride) * alpha_value; } }); }); } } static void index_reduce_func_impl( const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, bool include_self, const Tensor& result, const ReductionType& op) { if (!result.is_same(self)) result.copy_(self); if (!include_self) { AT_DISPATCH_ALL_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce_func_exclude_input_init", [&] { scalar_t init_val; switch (op) { case ReductionType::PROD: init_val = (scalar_t)1; break; case ReductionType::MAX: init_val = std::numeric_limits::has_infinity ? 
-std::numeric_limits::infinity() : std::numeric_limits::lowest(); break; case ReductionType::MIN: init_val = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : std::numeric_limits::max(); break; default: init_val = (scalar_t)0; break; } // index_fill_ requires index to be a LongTensor result.index_fill_(dim, index.to(at::ScalarType::Long), init_val); }); } auto numel = index.numel(); auto index_contig = index.contiguous(); if (result.dim() > 1) { // Equivalent to: // for (const auto i : c10::irange(numel)) { // auto selfSlice = self.select(dim, index_data[i]); // auto sourceSlice = source.select(dim, i); // selfSlice.op_(sourceSlice); // } // But much faster as this reuses the iterator from the binary op if (numel == 0) { return; } auto selfSlice = result.select(dim, 0); auto sourceSlice = source.select(dim, 0); auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type()); auto self_dim_size = result.size(dim); auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice); AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); auto self_data = static_cast(selfSlice.data_ptr()) + self_i * self_stride_bytes; auto source_data = static_cast(sourceSlice.const_data_ptr()) + i * source_stride_bytes; iter.unsafe_replace_operand(0, self_data); iter.unsafe_replace_operand(1, self_data); iter.unsafe_replace_operand(2, const_cast(source_data)); switch (op) { case ReductionType::PROD : mul_stub(iter.device_type(), iter); break; case ReductionType::MIN : minimum_stub(iter.device_type(), iter); break; case ReductionType::MAX : maximum_stub(iter.device_type(), iter); break; default : add_stub(iter.device_type(), iter, 1); break; } } }); if (op == ReductionType::MEAN) { auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); counts.index_add_(dim, index, at::ones_like(source)); counts.masked_fill_(counts == 0, 1); if (result.is_floating_point() || result.is_complex()) { result.div_(counts); } else { result.div_(counts, "floor"); } } } else { TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); auto counts = include_self ? at::ones_like(result) : at::zeros_like(result); // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, result.scalar_type(), "index_func_", [&result, &source, &dim, &index_contig, &numel, &op, &counts] { auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); auto source_stride = source.dim() == 0 ? 1 : source.stride(dim); auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim); // TODO: Maybe TensorAccessor can be used here? 
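        // The scalar path below folds each source element directly into
        // result[index[i]]. For op == ReductionType::MEAN it only
        // accumulates the sum and bumps counts[index[i]] here; the division
        // by counts (with zero-count slots clamped to 1) happens after the
        // dispatch, mirroring the multi-dimensional path above.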
auto* result_ptr = result.data_ptr(); auto* source_ptr = source.const_data_ptr(); auto counts_ptr = counts.data_ptr(); AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_", [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self"); scalar_t *self_ip = result_ptr + self_i * result_stride; scalar_t *count_ip; scalar_t val; switch (op) { case ReductionType::MEAN : *self_ip += *(source_ptr + i * source_stride); count_ip = counts_ptr + self_i * counts_stride; *count_ip += 1; break; case ReductionType::PROD : *self_ip *= *(source_ptr + i * source_stride); break; case ReductionType::MIN : val = *(source_ptr + i * source_stride); *self_ip = at::_isnan(val) ? val : std::min(*self_ip, val); break; case ReductionType::MAX : val = *(source_ptr + i * source_stride); *self_ip = at::_isnan(val) ? val : std::max(*self_ip, val); break; default: break; } } }); }); if (op == ReductionType::MEAN) { counts.masked_fill_(counts == 0, 1); if (result.is_floating_point() || result.is_complex()) { result.div_(counts); } else { result.div_(counts, "floor"); } } } } TORCH_IMPL_FUNC(index_reduce_cpu_out) (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const c10::string_view reduce, bool include_input, const Tensor& result) { TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time."); auto op = get_operator_enum(reduce, true); index_reduce_func_impl(self, dim, index, source, include_input, result, op); } // Check that indices fall within dimension array size // Avoid redispatch call to min/max template static void check_indexarray_range( const IndexType* indices, int64_t n, IndexType indexing_axis_dim) { for (const auto i : c10::irange(n)) { auto idx = indices[i]; TORCH_CHECK( 0 <= idx && idx < indexing_axis_dim, "INDICES element is out of DATA bounds, id=", idx, " axis_dim=", indexing_axis_dim); } } static Tensor & index_select_out_cpu_dim1_( Tensor & result_contig, const Tensor & self, const Tensor & index_contig) { auto self_contig = self.contiguous(); const caffe2::TypeMeta dataType = self_contig.dtype(); size_t item_bytesize = dataType.itemsize(); auto out = static_cast(result_contig.data_ptr()); auto src_base = static_cast(self_contig.const_data_ptr()); auto self_sizes = self_contig.sizes(); auto outer_dims_product = c10::size_to_dim_(1, self_sizes); auto block_size = c10::size_from_dim_(2, self_sizes); auto block_bytesize = block_size * item_bytesize; auto src_indexing_axis_dim = self_sizes[1]; auto src_batch_bytesize = self_sizes[1] * block_bytesize; auto N = index_contig.numel(); auto gathered_batch_bytesize = N * block_bytesize; AT_DISPATCH_INDEX_TYPES( index_contig.scalar_type(), "batch_index_select_compute", [&]() { const auto* idxs = index_contig.const_data_ptr(); check_indexarray_range(idxs, N, src_indexing_axis_dim); // Special-case single-float copy for efficiency if (self.scalar_type() == ScalarType::Float && block_size == 1) { for (const auto batch : c10::irange(outer_dims_product)) { const float* src_floats = (const float*)(src_base + batch * src_batch_bytesize); float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); for (const auto i : c10::irange(N)) { auto idx = idxs[i]; dst_floats[i] = src_floats[idx]; } } } else { // outer_dims_product 
specifies how many times we repeat inner dimensions, // so we just iterate over it to cover all outer dimensions. for (const auto batch : c10::irange(outer_dims_product)) { for (const auto i : c10::irange(N)) { auto idx = idxs[i]; auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; memcpy(dst, src, block_bytesize); } } } }); return result_contig; } Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor & index, Tensor & result) { if (self.is_quantized()) { TORCH_CHECK( self.qscheme() == kPerTensorAffine, "Only per_tensor quantized quantized tensors are supported by index_select.") } dim = maybe_wrap_dim(dim, self.dim()); auto numel = index.numel(); TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); TORCH_CHECK(!(self.dim() == 0 && numel != 1), "index_select(): Index to scalar can have only 1 value, got ", numel, " value(s)"); TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == result.scalar_type(), "index_select(): self and result must have the same scalar type"); at::assert_no_internal_overlap(result); at::assert_no_overlap(result, self); at::assert_no_overlap(result, index); auto result_size = self.sizes().vec(); if (self.dim() > 0) { result_size[dim] = numel; } at::native::resize_output(result, result_size); auto index_contig = index.contiguous(); if (self.dim() > 1) { if (numel == 0) { return result; } if (self.numel() == 0) { auto src_indexing_axis_dim = self.size(dim); TORCH_CHECK(src_indexing_axis_dim > 0, "index_select(): self indexing axis dim should be positive"); AT_DISPATCH_INDEX_TYPES( index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() { const auto* idxs = index_contig.const_data_ptr(); check_indexarray_range(idxs, numel, src_indexing_axis_dim); }); return result; } if (dim == 1 && result.is_contiguous()) { // fast pass return index_select_out_cpu_dim1_(result, self, index_contig); } auto selfSlice = self.select(dim, 0); auto resultSlice = result.select(dim, 0); auto selfSlice_data = selfSlice.const_data_ptr(); auto resultSlice_data = resultSlice.data_ptr(); auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type()); auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); auto self_dim_size = self.size(dim); auto slice_size = selfSlice.numel(); auto iter = TensorIteratorConfig() .check_all_same_dtype(false) .resize_outputs(false) .add_output(resultSlice) .add_const_input(selfSlice) .build(); auto grain_size = at::internal::GRAIN_SIZE; auto outer_loop = // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) { auto sub_iter = TensorIterator(iter); AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, &result_stride_bytes] () { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(start, end)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); auto self_data = 
static_cast(selfSlice_data) + self_i * self_stride_bytes; auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; sub_iter.unsafe_replace_operand(0, result_data); sub_iter.unsafe_replace_operand(1, const_cast(self_data)); copy_stub(sub_iter.device_type(), sub_iter, false); }; }); }; // parallel on inner loop in case the slice is large enough; // otherwise parallel on outer loop if (slice_size >= grain_size) { outer_loop(0, numel); } else { // use a fast loop when self and result are contiguous and of the same data type if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) { auto slice_size_bytes = slice_size * elementSize(self.scalar_type()); // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda at::parallel_for(0, numel, grain_size / slice_size, [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) { AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(start, end)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); auto self_data = static_cast(selfSlice_data) + self_i * self_stride_bytes; auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; memcpy(result_data, self_data, slice_size_bytes); } }); }); } else { at::parallel_for(0, numel, grain_size / slice_size, outer_loop); } } } else { TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")"); // explicitly capture all required variables to work around windows build // TODO: fix this when windows can correctly capture variables in nested lambda if(self.is_quantized()){ AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto result_stride = result.dim() == 0 ? 1 : result.stride(dim); auto self_data_ptr = self.const_data_ptr(); auto result_data_ptr = result.data_ptr(); auto self_numel = self.numel(); AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_", [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); const scalar_t *self_ip = self_data_ptr + self_i * self_stride; *(result_data_ptr + i * result_stride) = *self_ip; } }); }); } else { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, self.scalar_type(), "index_select", [&index_contig, &self, &result, &dim, &numel] { auto self_stride = self.dim() == 0 ? 1 : self.stride(dim); auto result_stride = result.dim() == 0 ? 
1 : result.stride(dim); auto self_data_ptr = self.const_data_ptr(); auto result_data_ptr = result.data_ptr(); auto self_numel = self.numel(); AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_", [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] { auto index_data = index_contig.const_data_ptr(); for (const auto i : c10::irange(numel)) { auto self_i = index_data[i]; TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self"); const scalar_t *self_ip = self_data_ptr + self_i * self_stride; *(result_data_ptr + i * result_stride) = *self_ip; } }); }); } } return result; } Tensor index_select_cpu_(const Tensor & self, int64_t dim, const Tensor & index) { Tensor result = at::empty({0}, self.options()); return at::native::index_select_out_cpu_(self, dim, index, result); } Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tensor & index) { TORCH_CHECK(self.qscheme() == kPerTensorAffine, "Only per_tensor quantized quantized tensors are supported by index_select.") Tensor result = at::empty_quantized({0}, self); return at::native::index_select_out_cpu_(self, dim, index, result); } Tensor index_select_backward_symint(const Tensor& grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor& index) { // for composite compliance, use out-of-place variant of // `index_add` if index tensor is a Tensor Subclass. if (isTensorSubclassLike(index)) { return grad.new_zeros_symint(self_sizes, grad.options()).index_add(dim, index, grad); } return grad.new_zeros_symint(self_sizes, grad.options()).index_add_(dim, index, grad); } Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) { at::NoNamesGuard guard; TORCH_CHECK_INDEX( index.scalar_type() == ScalarType::Long, "index_fill_(): Expected dtype int64 for index."); at::assert_no_overlap(self, index); if (at::has_internal_overlap(self) == at::MemOverlap::Yes) { TORCH_WARN( "Use of index_fill_ on expanded tensors is deprecated. " "Please clone() the tensor before performing this operation. " "This also applies to advanced indexing e.g. tensor[mask] = scalar"); } if (!self.is_complex() && source.isComplex()) { TORCH_CHECK(false, "index_fill_(): Converting complex Scalar to non-complex type is not supported"); } // Handle the case when `self` is 0-dim Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; dim = at::maybe_wrap_dim(dim, self_nonzero_dim); TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar"); // Prepare `index` for TensorIterator. // It is restrided to be broadcastable over `self` in TensorIterator. auto index_sizes = std::vector(self_nonzero_dim.dim(), 1); auto index_strides = std::vector(self_nonzero_dim.dim(), 0); index_sizes[dim] = index.numel(); index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar auto index_restrided = index.as_strided( index_sizes, index_strides); // Prepare `self` for TensorIterator. // Restride `self` to not advance in dimension `dim`. // We do not use squash_dim here because `index` will // need to advance in this dimension. // Note that self_sizes[dim] is set to index.numel(). // This is done so that self_sizes[dim] and index_sizes[dim] // match as required by TensorIterator (input shape should // strictly broadcast over output shape, i.e. // output.shape[i] >= input.shape[i] for i in range(dims)). 
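  // Concretely (illustrative numbers): for a contiguous `self` of shape
  // [3, 4] (strides [4, 1]), dim == 1 and an index with 2 elements, the
  // iterator below sees
  //
  //   index restrided to sizes [1, 2], strides [0, 1]
  //   self  restrided to sizes [3, 2], strides [4, 0]
  //
  // so each iterator element pairs one row of `self` with one index value,
  // and index_fill_stub applies the real dim-1 offset itself using
  // self_dim_stride.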
auto self_sizes = self_nonzero_dim.sizes().vec(); auto self_strides = self_nonzero_dim.strides().vec(); self_sizes[dim] = index.numel(); self_strides[dim] = 0; auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides); auto iter = TensorIteratorConfig() // We do not check for overlap because `self` is restrided // with zero stride. Zero strides trigger memory overlap assert // within TensorIterator. .set_check_mem_overlap(false) .check_all_same_dtype(false) .resize_outputs(false) .add_output(self_restrided) .add_const_input(index_restrided) .build(); auto self_dim_size = (self_nonzero_dim.sizes())[dim]; auto self_dim_stride = (self_nonzero_dim.strides())[dim]; index_fill_stub( iter.device_type(), iter, dim, self_dim_size, self_dim_stride, source); return self; } Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { TORCH_CHECK(source.dim() == 0, "index_fill_ only supports a 0-dimensional value tensor, but got tensor " "with ", source.dim(), " dimension(s)."); return self.index_fill_(dim, index, source.item()); } Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) { return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source); } Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source); } // fast paths for GNN usage static bool can_use_expanded_index_path( const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, bool is_scatter_like) { #ifdef USE_FBGEMM if (!fbgemm::is_radix_sort_accelerated_with_openmp()) { return false; } #else return false; #endif if (!self.device().is_cpu()) { return false; } const auto st = self.scalar_type(); if (!(c10::isFloatingType(st))) { return false; } // skip when having empty tensor if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) { return false; } // skip when having scalar tensor if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) { return false; } // allow only different size on dim 0 for src and index // https://github.com/pytorch/pytorch/issues/99595 for (const auto dim : c10::irange(1, index.dim())) { if (src.size(dim) != index.size(dim)) { return false; } } if (is_scatter_like) { // using `spmm` for scatter would require sorting on index, // this is only perf beneficial when the inner dimension, aka, `channels` // is big enough. constexpr int64_t threshold = 16; if (index.numel() / index.size(0) < threshold) { return false; } } // usually the expanded index has stride on the first dimension to be 1, // and strides on other dims to be 0 or 1, e.g. 
// fast paths for GNN usage
static bool can_use_expanded_index_path(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    bool is_scatter_like) {
#ifdef USE_FBGEMM
  if (!fbgemm::is_radix_sort_accelerated_with_openmp()) {
    return false;
  }
#else
  return false;
#endif

  if (!self.device().is_cpu()) {
    return false;
  }

  const auto st = self.scalar_type();
  if (!(c10::isFloatingType(st))) {
    return false;
  }

  // skip when having empty tensor
  if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
    return false;
  }

  // skip when having scalar tensor
  if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) {
    return false;
  }

  // allow only different size on dim 0 for src and index
  // https://github.com/pytorch/pytorch/issues/99595
  for (const auto dim : c10::irange(1, index.dim())) {
    if (src.size(dim) != index.size(dim)) {
      return false;
    }
  }

  if (is_scatter_like) {
    // using `spmm` for scatter would require sorting on index,
    // this is only perf beneficial when the inner dimension, aka, `channels`
    // is big enough.
    constexpr int64_t threshold = 16;
    if (index.numel() / index.size(0) < threshold) {
      return false;
    }
  }

  // usually the expanded index has stride on the first dimension to be 1,
  // and strides on other dims to be 0 or 1, e.g.
  //   shape [108365, 16]; strides [1, 0]
  //   shape [13264, 1, 7]; strides [1, 1, 0]
  auto index_strides = index.strides().vec();
  bool is_index_expanded = index_strides[0] == 1;
  for (const auto dim : c10::irange(1, index_strides.size())) {
    if (index_strides[dim] > 1) {
      is_index_expanded = false;
    }
  }

  // index is expanded
  return dim == 0 && is_index_expanded && src.is_contiguous() && self.is_contiguous();
}

// gather_out_cpu_cuda
TORCH_IMPL_FUNC(gather_out)
(const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) {
  if (index.numel() == 0) return;
  dim = at::maybe_wrap_dim(dim, self.dim());
  if (can_use_expanded_index_path(result, dim, index, self, /*is_scatter_like=*/false)) {
    gather_expanded_index_stub(result.device().type(), result, self, index);
  } else {
    gather_stub(result.device().type(), result, self, dim, index);
  }
}

Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
  if (sparse_grad) {
    return at::_gather_sparse_backward(self, dim, index, grad);
  }
  auto result = grad.new_zeros_symint(self.sym_sizes());
  // for composite, vmap and inductor compliance, use the out-of-place variant
  // of `scatter_add` if the index or grad tensor is a Tensor subclass.
  if (areAnyTensorSubclassLike({index, grad})) {
    return result.scatter_add(dim, index, grad);
  }
  result.scatter_add_(dim, index, grad);
  return result;
}
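// Illustrative sketch (comment only): for a 2-D tensor, gather along dim 0
// computes
//
//   result[i][j] = self[index[i][j]][j]
//
// and gather_backward above is its adjoint: scatter_add-ing `grad` back to
// those same locations reverses the addressing.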
static void scatter_reduce_exclude_self_helper(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const ReductionType& op) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
      self.scalar_type(), "scatter_reduce_exclude_input_init", [&] {
    scalar_t init_val;
    switch (op) {
      case ReductionType::SUM:
        init_val = (scalar_t)0;
        break;
      case ReductionType::PROD:
        init_val = (scalar_t)1;
        break;
      case ReductionType::MAX:
        init_val = std::numeric_limits<scalar_t>::has_infinity
                       ? -std::numeric_limits<scalar_t>::infinity()
                       : std::numeric_limits<scalar_t>::lowest();
        break;
      case ReductionType::MIN:
        init_val = std::numeric_limits<scalar_t>::has_infinity
                       ? std::numeric_limits<scalar_t>::infinity()
                       : std::numeric_limits<scalar_t>::max();
        break;
      case ReductionType::MEAN:
        init_val = (scalar_t)0;
        break;
    }
    self.scatter_(dim, index, init_val);
  });
}

static void _scatter_via_index_put(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    const Tensor& mut_out,
    bool accumulate) {
  if (self.dim() == 1) {
    torch::List<std::optional<Tensor>> indices;
    indices.reserve(1);
    indices.push_back(index);
    mut_out.index_put_(indices, src, accumulate);
  } else {
    Tensor mut_out_contig = mut_out.contiguous();

    auto index_coords_sizes = index.sizes().vec();
    index_coords_sizes.push_back(self.dim());
    auto index_coords = at::empty(
        index_coords_sizes,
        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));

    for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
      if (dim_other == dim) {
        continue;
      }
      auto dim_coord_vals = at::arange(
          index.size(dim_other),
          at::TensorOptions().device(self.device()));

      for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; dim_unsqueeze++) {
        dim_coord_vals = dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
      }

      auto view_sizes = index.sizes().vec();
      view_sizes.push_back(1);
      auto view_strides = index_coords.strides().vec();
      view_strides[self.dim()] = self.dim();

      at::as_strided(
          index_coords,
          view_sizes,
          view_strides,
          dim_other
      ).copy_(dim_coord_vals.unsqueeze(-1));
    }

    auto view_sizes = index.sizes().vec();
    view_sizes.push_back(1);
    auto view_strides = index_coords.strides().vec();
    view_strides[self.dim()] = self.dim();

    at::as_strided(
        index_coords,
        view_sizes,
        view_strides,
        dim
    ).copy_(index.unsqueeze(-1));

    Tensor index_coords_flat = index_coords.flatten(0, -2);

    // Copy mut_out_contig's strides into a tensor
    // TODO: Is there a utility function that already does this?
    IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
    Tensor coord_strides = at::empty(
        {mut_out_contig.dim()},
        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
    std::memcpy(
        coord_strides.mutable_data_ptr(),
        mut_out_contig_strides.data(),
        coord_strides.nbytes());
    coord_strides = coord_strides.to(mut_out_contig.device());

    // `index_flat` contains the 1-D indices corresponding with the
    // flattened `mut_out`
    Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
    Tensor mut_out_flat = mut_out_contig.flatten();
    Tensor src_flat = at::as_strided(
        src,
        index.sizes(),
        src.strides()
    ).flatten();

    torch::List<std::optional<Tensor>> indices;
    indices.reserve(1);
    indices.push_back(index_flat);

    mut_out_flat.index_put_(indices, src_flat, accumulate);

    if (!mut_out.is_contiguous()) {
      mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
    }
  }
}

template <bool use_new_options = false, typename T, typename ReduceStub, typename FillStub>
void scatter_impl(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const T& src,
    const Tensor& out,
    ReduceStub& reduce_stub,
    FillStub& fill_stub,
    const std::optional<c10::string_view> reduce = std::nullopt,
    bool reduce_includes_self = true) {
  dim = at::maybe_wrap_dim(dim, self.dim());
  auto mut_out = const_cast<Tensor&>(out);

  if (!self.is_same(mut_out)) {
    mut_out.copy_(self);
  }

  if (index.numel() == 0) return;

  auto op = ReductionType::SUM;
  bool deterministic = globalContext().deterministicAlgorithms()
      && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU);

  if (reduce.has_value()) {
    op = get_operator_enum(reduce.value(), use_new_options);
    if (!reduce_includes_self) {
      // scatter inits for reduction to appropriate indices (used by scatter_reduce.two)
      scatter_reduce_exclude_self_helper(mut_out, dim, index, op);
    }
    // _scatter_via_index_put can only handle sum and mean reduction type
    deterministic = deterministic && (op == ReductionType::SUM || op == ReductionType::MEAN);
  }

  // Scalar src should already be deterministic
  if (deterministic && std::is_same_v<T, Tensor>) {
    // both runtime and compile check are required
    if constexpr (std::is_same_v<T, Tensor>) {
      bool accumulate = reduce.has_value();
      _scatter_via_index_put(self, dim, index, src, mut_out, accumulate);
      return;
    }
  }

  if (reduce.has_value()) {
    reduce_stub(self.device().type(), mut_out, dim, index, src, op);
  } else {
    fill_stub(self.device().type(), mut_out, dim, index, src);
  }
}

TORCH_IMPL_FUNC(scatter_src_out)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& out) {
  scatter_impl(self, dim, index, src, out, scatter_reduce_stub, scatter_stub);
}

TORCH_IMPL_FUNC(scatter_value_out)
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& out) {
  scatter_impl(self, dim, index, value, out, scatter_scalar_reduce_stub, scatter_fill_stub);
}
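// Illustrative sketch (comment only): for a 2-D tensor, scatter along dim 0
// performs
//
//   out[index[i][j]][j] = src[i][j]
//
// _scatter_via_index_put above rewrites this as a single index_put_ on the
// flattened output: it materializes the full coordinate of every scattered
// element, folds the coordinates into linear offsets using the contiguous
// output strides, and then relies on index_put_'s deterministic accumulate
// path instead of device atomics.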
TORCH_IMPL_FUNC(scatter_reduce_out)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce,
 const Tensor& out) {
  scatter_impl(self, dim, index, src, out,
               scatter_reduce_stub,
               scatter_stub,
               reduce);
}

TORCH_IMPL_FUNC(scatter_value_reduce_out)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Scalar& value,
 const c10::string_view reduce,
 const Tensor& out) {
  scatter_impl(self, dim, index, value, out,
               scatter_scalar_reduce_stub,
               scatter_fill_stub,
               reduce);
}

TORCH_IMPL_FUNC(scatter_add)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& out) {
  auto mut_out = const_cast<Tensor&>(out);
  dim = maybe_wrap_dim(dim, self.dim());

  if (!self.is_same(mut_out)) {
    mut_out.copy_(self);
  }

  if (index.numel() == 0) return;

  // See Note [Enabling Deterministic Operations]
  // Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on
  if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) {
    _scatter_via_index_put(self, dim, index, src, mut_out, /*accumulate*/true);
  } else {
    if (can_use_expanded_index_path(mut_out, dim, index, src, /*is_scatter_like*/true)) {
      scatter_add_expanded_index_stub(self.device().type(), mut_out, index, src);
    } else {
      scatter_add_stub(self.device().type(), mut_out, dim, index, src);
    }
  }
}

TORCH_IMPL_FUNC(scatter_reduce_two)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce,
 bool include_self,
 const Tensor& out) {
  dim = at::maybe_wrap_dim(dim, self.dim());

  if (!self.is_same(out)) {
    out.copy_(self);
  }

  const auto op = get_operator_enum(reduce, true);

  if (can_use_expanded_index_path(out, dim, index, src, /*is_scatter_like*/true)) {
    scatter_reduce_expanded_index_stub(self.device().type(), out, index, src, op, include_self);
    return;
  }

  scatter_impl</*use_new_options=*/true>(self, dim, index, src, out,
                                         scatter_reduce_two_stub,
                                         scatter_stub,
                                         reduce,
                                         include_self);

  if (op == ReductionType::MEAN) {
    auto ones = at::ones_like(src);
    auto count = include_self ? at::ones_like(out) : at::zeros_like(out);
    count.scatter_add_(dim, index, ones);
    count.masked_fill_(count == 0, 1);

    if (out.is_floating_point() || out.is_complex()) {
      out.div_(count);
    } else {
      out.div_(count, "floor");
    }
  }
}

Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
  auto [_mask, _self] = expand_outplace(mask, self);
  return _self->clone(at::MemoryFormat::Contiguous).masked_scatter_(*_mask, source);
}

Tensor masked_scatter_backward_symint(
    const Tensor& grad,
    const Tensor& mask,
    c10::SymIntArrayRef sizes) {
  c10::SymInt numel = 1;
  for (const auto& size : sizes) {
    numel *= size;
  }
  auto mask_selected = grad.masked_select(mask);
  auto diff_nelem = numel - mask_selected.sym_numel();
  if (diff_nelem > 0) {
    // masked_select returns a 1-d tensor with one element per true mask
    // entry, so pad the rest with zeros and reshape back to the source size.
    auto zeros_fillin = at::zeros_symint({std::move(diff_nelem)}, grad.options());
    mask_selected = at::cat({mask_selected, std::move(zeros_fillin)}, 0);
  }
  return mask_selected.view_symint(sizes);
}
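// Illustrative sketch (comment only): masked_scatter copies elements of
// `source`, in order, into the positions where `mask` is true, e.g.
//
//   Tensor t = at::zeros({2, 2});
//   Tensor mask = t == 0;                        // all true here
//   Tensor src = at::arange(4).to(at::kFloat);
//   t.masked_scatter_(mask, src);                // t is now [[0, 1], [2, 3]]
//
// Its backward pads grad.masked_select(mask) with zeros up to the source
// numel, matching the padding logic above.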
static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, const Scalar& value) {
  NoNamesGuard guard;
  TORCH_CHECK(mask.dtype() == ScalarType::Bool,
              "masked_fill_ only supports boolean masks, but got mask "
              "with dtype ", mask.dtype());

  if (at::has_internal_overlap(self) == MemOverlap::Yes) {
    TORCH_WARN(
        "Use of masked_fill_ on expanded tensors is deprecated. "
        "Please clone() the tensor before performing this operation. "
        "This also applies to advanced indexing e.g. tensor[mask] = scalar");
  }
  at::assert_no_partial_overlap(self, mask);

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)  // deprecated, but not a hard error
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self)
      .add_const_input(mask)
      .build();

  masked_fill_stub(iter.device_type(), iter, value);
  return self;
}

Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Scalar& value) {
  auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");

  masked_fill_impl_cpu(self, mask, value);
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & value) {
  auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
  TORCH_CHECK(value.dim() == 0,
      "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
      "with ", value.dim(), " dimension(s).");

  masked_fill_impl_cpu(self, mask, value.item());
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

Tensor masked_fill(const Tensor & self, const Tensor & mask, const Scalar& source) {
  Tensor result;
  auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
  {
    NoNamesGuard guard;
    auto [_mask, _self] = expand_outplace(mask, self);
    result = _self->clone(at::MemoryFormat::Contiguous);
    result.masked_fill_(mask, source);
  }
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
  Tensor result;
  auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
  {
    NoNamesGuard guard;
    auto [_mask, _self] = expand_outplace(mask, self);
    result = _self->clone(at::MemoryFormat::Contiguous);
    result.masked_fill_(mask, source);
  }
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}
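// Illustrative sketch (comment only): masked_fill writes one scalar wherever
// the boolean mask is true, broadcasting mask and self together, e.g.
//
//   Tensor t = at::arange(4).to(at::kFloat);
//   t.masked_fill_(t > 1, -1.0);   // t is now [0, 1, -1, -1]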
static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
  NoNamesGuard guard;

  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
              "masked_select: expected BoolTensor for mask");
  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
              "masked_select(): self and result must have the same scalar type");

  at::assert_no_internal_overlap(result);
  at::assert_no_overlap(result, self);
  at::assert_no_overlap(result, mask);

  auto [_mask, _self] = expand_outplace(mask, self);

  auto shape = _self->sizes();
  int64_t numel = _mask->sum().item().toLong();
  at::native::resize_output(result, {numel});
  if (numel == 0) {
    return result;
  }

  // Create strided view of result before feeding into TensorIterator
  auto strides = DimVector(shape.size(), 0);
  auto orig_stride = result.strides()[0];
  auto result_strided = result.as_strided(shape, strides);

  // serial kernel
  // The serial kernel requires that src is traversed in its logical order.
  // However, TensorIterator might have reordered dimensions so that src would
  // be traversed in its physical order, producing wrong answers. A sufficient
  // condition that no reorder happened is that both _self and _mask are
  // contiguous. If that does not hold, use the parallel kernel, which handles
  // permutations correctly.
  bool use_serial_kernel = (self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1) &&
      _self->is_contiguous() && _mask->is_contiguous();

  if (use_serial_kernel) {
    auto iter = TensorIteratorConfig()
        .set_check_mem_overlap(false)  // result is intentionally zero-strided above
        .check_all_same_dtype(false)
        .resize_outputs(false)
        .add_output(result_strided)
        .add_const_input(*_self)
        .add_const_input(*_mask)
        .build();

    masked_select_serial_stub(iter.device_type(), iter, orig_stride);
    return result;
  }

  // Use a prefix sum to record the output locations of the masked elements,
  // so as to parallelize with TensorIterator.
  auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(*_mask);
  auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong));
  auto mask_long_data = mask_long.data_ptr<int64_t>();
  auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();

  // TODO: std::partial_sum is used here for C++14 compatibility; switch to
  // std::exclusive_scan, which performs better, once PyTorch moves to C++17.
  // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
  std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)  // result is intentionally zero-strided above
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(result_strided)
      .add_const_input(*_self)
      .add_const_input(*_mask)
      .add_const_input(mask_prefix_sum)
      .build();

  masked_select_stub(iter.device_type(), iter, orig_stride);
  return result;
}

Tensor & masked_select_out_cpu(const Tensor & self, const Tensor & mask, Tensor & result) {
  namedinference::compute_broadcast_outnames(self, mask);
  return masked_select_out_impl_cpu(result, self, mask);
}

Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
  Tensor result = at::empty({0}, self.options());
  return at::native::masked_select_out_cpu(self, mask, result);
}
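// Illustrative sketch (comment only): an inclusive prefix sum over the mask
// gives each true element its (1-based) slot in the packed output, so threads
// can write independently, e.g.
//
//   mask        = [0, 1, 1, 0, 1]
//   prefix sum  = [0, 1, 2, 2, 3]   // true element at i goes to slot prefix[i] - 1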
Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Tensor& mask) {
  // The following could just be written as `zeros_like(input).masked_scatter(mask, grad)`.
  // However, as an optimization, we call the in-place variant of masked_scatter.
  // Unfortunately, that doesn't allow for the broadcasting of the LHS, so we need
  // to explicitly broadcast here (the out-of-place variant of masked_scatter
  // implicitly handles broadcasting).
  auto result = at::zeros_like(
      input.expand(at::infer_size(input.sizes(), mask.sizes())), at::MemoryFormat::Preserve);

  // for composite compliance, use the out-of-place variant
  // of `masked_scatter`.
  if (areAnyTensorSubclassLike({grad, mask})) {
    return result.masked_scatter(mask, grad);
  }
  result.masked_scatter_(mask, grad);
  return result;
}

namespace {

inline std::tuple<Tensor, Tensor, int64_t> _take_along_dim_helper(
    const Tensor& self,
    const Tensor& indices,
    int64_t dim) {
  TORCH_CHECK(
      self.dim() == indices.dim(),
      "torch.take_along_dim(): input and indices should have the same number of dimensions, ",
      "but got ", self.dim(), " dimensions for input, and ", indices.dim(), " dimensions for indices")
  TORCH_CHECK(
      indices.scalar_type() == ScalarType::Long,
      "torch.take_along_dim(): dtype of indices should be Long but got ", indices.scalar_type())

  dim = at::maybe_wrap_dim(dim, self.dim());

  SymDimVector self_sizes{self.sym_sizes()};
  // update number of elements at dim as per indices
  self_sizes[dim] = indices.sym_size(dim);
  auto broadcast_shape = infer_size_symint(self_sizes, indices.sym_sizes());
  auto indices_broadcasted = at::broadcast_to_symint(indices, broadcast_shape);

  SymDimVector indices_sizes{indices.sym_sizes()};
  // update number of elements at dim as per self
  indices_sizes[dim] = self.sym_size(dim);
  broadcast_shape = infer_size_symint(indices_sizes, self.sym_sizes());
  auto self_broadcasted = at::broadcast_to_symint(self, broadcast_shape);

  return std::make_tuple(std::move(self_broadcasted),
                         std::move(indices_broadcasted),
                         std::move(dim));
}

static inline void checkDevice(CheckedFrom c, const Tensor& t, Device device) {
  TORCH_CHECK(
      !t.defined() || t.device() == device,
      "Expected tensor to have ", device,
      " Device, but got tensor with ", t.device(), " Device ",
      "(while checking arguments for ", c, ")");
}

static inline void checkDevice(CheckedFrom c, at::ArrayRef<Tensor> tensors, Device device) {
  for (auto &t : tensors) {
    checkDevice(c, t, device);
  }
}

} // anonymous namespace

Tensor take_along_dim(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim) {
  checkDevice("torch.take_along_dim():", {self, indices}, self.device());
  if (opt_dim.has_value()) {
    auto [self_broadcasted, indices_broadcasted, dim] =
        _take_along_dim_helper(self, indices, opt_dim.value());
    return self_broadcasted.gather(dim, indices_broadcasted);
  }

  // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
  return self.view(-1).gather(0, indices.view(-1));
}
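// Illustrative sketch (comment only): take_along_dim broadcasts `self` and
// `indices` against each other and then defers to gather, e.g. picking the
// per-row maximum:
//
//   Tensor t = at::rand({4, 5});
//   Tensor idx = t.argmax(1, /*keepdim=*/true);   // shape [4, 1]
//   Tensor m = at::take_along_dim(t, idx, 1);     // shape [4, 1]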
Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim, Tensor& result) {
  checkDevice("torch.take_along_dim():", {self, indices, result}, self.device());
  if (opt_dim.has_value()) {
    auto [self_broadcasted, indices_broadcasted, dim] =
        _take_along_dim_helper(self, indices, opt_dim.value());
    return at::gather_out(result, self_broadcasted, dim, indices_broadcasted);
  }

  // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
  return at::gather_out(result, self.view(-1), 0, indices.view(-1));
}

Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
  // special case scalar input and/or index
  if (self.ndimension() == 0)
    return at::_sparse_coo_tensor_unsafe_symint(
        at::empty_symint({0, grad.sym_numel()}, index.options()), grad, self.sym_sizes());
  if (grad.ndimension() == 0)
    return at::_sparse_coo_tensor_unsafe_symint(index.view({1, 1}), grad, self.sym_sizes());
  Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong));
  SymInt grad_numel = grad.sym_numel();
  if (grad_numel > 0) {
    SymInt n_above = grad_numel;
    SymInt n_below = 1;
    if (dim < 0) dim += self.ndimension();
    for (const auto i : c10::irange(self.ndimension())) {
      n_above /= grad.sym_size(i);
      if (i == dim) {
        sparse_ind[i] = index.reshape(-1);
      } else {
        sparse_ind[i] = at::arange(grad.sym_size(i), self.options().dtype(at::kLong))
            .unsqueeze(1)
            .expand_symint({grad.sym_size(i), n_above})
            .reshape(-1)
            .repeat_symint(n_below);
      }
      n_below *= grad.sym_size(i);
    }
  }
  return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes());
}

template <typename scalar_t>
int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
  int64_t num_nonzero = 0;

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    constexpr int ilp_factor = 4;
    const char* ptr = data[0];
    const auto stride = strides[0];
    int64_t nonzero[ilp_factor] = {0};

    int64_t i = 0;
    for (; i + (ilp_factor - 1) < n; i += ilp_factor) {
      c10::ForcedUnroll<ilp_factor>{}([&](int k) {
        const auto& val = c10::load<scalar_t>(ptr + k * stride);
        if (val != scalar_t(0)) {
          ++nonzero[k];
        }
      });
      ptr += ilp_factor * stride;
    }
    for (; i < n; ++i) {
      const auto& val = c10::load<scalar_t>(ptr);
      if (val != scalar_t(0)) {
        ++nonzero[0];
      }
      ptr += stride;
    }
    for (const auto k : c10::irange(1, ilp_factor)) {
      nonzero[0] += nonzero[k];
    }
    num_nonzero += nonzero[0];
  };
  iter.serial_for_each(loop, range);

  return num_nonzero;
}

Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
  auto reduce = self;
  if (reduce.scalar_type() != kBool) {
    reduce = reduce != 0;
  }
  return reduce.sum(dims);
}

Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
  if (!dims.empty()) {
    auto reduce = self;
    if (reduce.scalar_type() != kBool) {
      reduce = reduce != 0;
    }
    return reduce.sum(dims);
  }

  // Optimized all-reduce
  auto iter = TensorIteratorConfig()
      .add_const_input(self)
      .build();

  const auto num_threads = at::get_num_threads();
  DimVector thread_count_nonzero(num_threads);

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
    at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
      const auto tid = at::get_thread_num();
      thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end});
    });
  });

  for (const auto i : c10::irange(1, num_threads)) {
    thread_count_nonzero[0] += thread_count_nonzero[i];
  }
  auto out = at::empty({}, self.options().dtype(kLong));
  *out.mutable_data_ptr<int64_t>() = thread_count_nonzero[0];
  return out;
}

Tensor count_nonzero(const Tensor& self, std::optional<int64_t> dim) {
  if (dim) {
    return at::count_nonzero(self, IntArrayRef{*dim});
  }
  return at::count_nonzero(self, IntArrayRef{});
}
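// Illustrative sketch (comment only): the ilp_factor counters in
// count_nonzero_impl above keep four independent accumulators so the loop
// body has no serial dependency chain; a scalar equivalent of one unrolled
// step is
//
//   nonzero[0] += (load(ptr + 0 * stride) != 0);
//   nonzero[1] += (load(ptr + 1 * stride) != 0);
//   nonzero[2] += (load(ptr + 2 * stride) != 0);
//   nonzero[3] += (load(ptr + 3 * stride) != 0);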
Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
  TORCH_CHECK(result.scalar_type() == kLong,
              "nonzero: Expected out tensor to have scalar type Long "
              "but got scalar type ", result.scalar_type());
  at::assert_no_internal_overlap(result);
  at::assert_no_overlap(result, self);

  auto iter = TensorIteratorConfig()
      .add_const_input(self)
      .enforce_linear_iteration()
      .build();

  const auto numel = iter.numel();
  const auto num_threads = at::get_num_threads();
  DimVector thread_begin(num_threads, -1);
  DimVector thread_count_nonzero(num_threads + 1);

  // Pass 1: Count nonzero elements per-thread
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
    at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
      const auto tid = at::get_thread_num();
      thread_begin[tid] = begin;
      thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
    });
  });

  // Convert thread-local counts to cumulative sum
  for (const auto i : c10::irange(1, thread_count_nonzero.size())) {
    thread_count_nonzero[i] += thread_count_nonzero[i - 1];
  }

  const auto self_sizes = self.sizes();
  const auto total_nonzero = thread_count_nonzero.back();
  const int64_t ndim = self_sizes.size();
  if (resize_output(result, {total_nonzero, ndim})) {
    // Default to fortran-contiguous output (see gh-46224)
    result.as_strided_({total_nonzero, ndim}, {1, total_nonzero});
  }

  if (result.numel() == 0) {
    return result;
  }

  auto out_accessor = result.accessor<int64_t, 2>();

  // Pass 2: Write indices
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
    at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
      auto tid = at::get_thread_num();
      // Work needs to be distributed the same on both passes
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]);

      // +1 faster than additional condition check inside loop
      c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1);
      std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1);
      c10::SmallVector<int64_t, 33> current_idx(ndim + 1);
      if (begin > 0) {
        auto idx = begin;
        for (int64_t k = ndim; idx > 0 && k > 0; --k) {
          current_idx[k] = idx % sizes[k];
          idx /= sizes[k];
        }
      }

      auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();

      auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
        // Copy into local variables to improve compiler alias analysis
        int64_t* C10_RESTRICT local_idx = current_idx.data() + 1;
        const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1;
        const auto in_stride = strides[0];
        const auto out_stride1 = out_accessor.stride(1);
        const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1;
        const auto ndim = out_accessor.size(1);
        int64_t* out = out_ptr;

        for (const auto i : c10::irange(n2)) {
          const char* ptr = data[0] + i * strides[1];
          for (C10_UNUSED const auto j : c10::irange(n1)) {
            const auto& val = c10::load<scalar_t>(ptr);
            // If nonzero, write index
            if (val != scalar_t(0)) {
              for (const auto k : c10::irange(ndim)) {
                *out = local_idx[k];
                out += out_stride1;
              }
              out += out_stride0;
            }
            ptr += in_stride;

            // Advance current index
            int64_t k = ndim - 1;
            ++local_idx[k];
            while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) {
              local_idx[k] = 0;
              --k;
              ++local_idx[k];
            }
          }
        }
        out_ptr = out;
      };
      iter.serial_for_each(loop, {begin, end});
      TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data());
    });
  });
  return result;
}

Tensor nonzero_cpu(const Tensor& self) {
  auto result = at::empty({0}, self.options().dtype(kLong));
  nonzero_out_cpu(self, result);
  return result;
}
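// Illustrative sketch (comment only): the two-pass scheme above gives each
// thread a contiguous output block. With per-thread counts [3, 1, 2], the
// cumulative sum [0, 3, 4, 6] means thread 0 writes rows [0, 3), thread 1
// writes row 3, and thread 2 writes rows [4, 6), for 6 total nonzeros.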
Tensor& nonzero_static_out_cpu(
    const Tensor& self,
    int64_t size,
    int64_t fill_value,
    Tensor& result) {
  // Check that `size` is not negative
  TORCH_CHECK(
      size >= 0, "nonzero_static: 'size' must be a non-negative integer");
  TORCH_CHECK(
      result.scalar_type() == kLong,
      "nonzero_static: Expected out tensor to have scalar type Long "
      "but got scalar type ", result.scalar_type());

  int64_t ndim = self.dim();
  if (result.dim() != 2 || result.size(0) != size || result.size(1) != ndim) {
    at::native::resize_output(result, {size, ndim});
  }
  // Verify that the output tensor is resized to expected size=(size, ndim)
  TORCH_CHECK(
      result.dim() == 2,
      "nonzero_static: Expected out tensor to be a 2D tensor but got a ",
      result.dim(), "D tensor");
  TORCH_CHECK(
      result.size(0) == size && result.size(1) == ndim,
      "nonzero_static: Expected out tensor to have Size([", size, ", ", ndim,
      "]) but got Size([", result.size(0), ", ", result.size(1), "]) ");
  at::assert_no_internal_overlap(result);
  at::assert_no_overlap(result, self);

  // Return early if either dim is 0
  if (result.size(0) == 0 || result.size(1) == 0) {
    return result;
  }

  // Delegate call to regular nonzero to get a data-dependent output
  auto dyn_result = nonzero_cpu(self);
  int64_t num_nonzeros = dyn_result.size(0);
  int64_t copy_len = std::min(size, num_nonzeros);
  // Copy the dynamic result to the fixed-size tensor
  result.narrow(0, 0, copy_len).copy_(dyn_result.narrow(0, 0, copy_len));
  if (size > copy_len) {
    // Pad result with `fill_value`
    result.narrow(0, copy_len, size - copy_len).fill_(fill_value);
  }
  return result;
}

Tensor nonzero_static_cpu(
    const Tensor& self,
    int64_t size,
    int64_t fill_value) {
  // Check that `size` is not negative
  TORCH_CHECK(
      size >= 0, "nonzero_static: 'size' must be a non-negative integer");
  // Allocate fixed-size out tensor
  int64_t ndim = self.dim();
  auto result = at::empty(
      {size, ndim},
      at::TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
  nonzero_static_out_cpu(self, size, fill_value, result);
  return result;
}

std::vector<Tensor> nonzero_numpy(const Tensor& self) {
  // special case scalar for compatibility with numpy:
  //
  // >>> np.array(5).nonzero()
  // (array([0]),)
  // >>> np.array(0).nonzero()
  // (array([], dtype=int64),)
  if (self.dim() == 0) {
    return self.unsqueeze(0).nonzero().unbind(1);
  }
  return self.nonzero().unbind(1);
}

Tensor argwhere(const Tensor& self) {
  return self.nonzero();
}

Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) {
  at::assert_no_internal_overlap(self);
  TORCH_CHECK(
      self.scalar_type() == source.scalar_type(),
      "masked_scatter: expected self and source to have same dtypes but got ",
      self.scalar_type(), " and ", source.scalar_type());
  TORCH_CHECK(self.device().type() == at::kCPU, "device type of self (", self.device().type(), ") is not CPU");
  TORCH_CHECK(mask.device().type() == at::kCPU, "device type of mask (", mask.device().type(), ") is not CPU");
  TORCH_CHECK(source.device().type() == at::kCPU, "device type of source (", source.device().type(), ") is not CPU");

  c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");

  if (b_mask->dtype() == ScalarType::Byte) {
    TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated, "
               "please use a mask with dtype torch.bool instead.");
  }

  auto src_cont = source.contiguous();

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      // order of indexing matters
      .enforce_linear_iteration()
      .add_output(self)
      .add_const_input(*b_mask)
      .build();

  masked_scatter_stub(iter.device_type(), iter, src_cont);
  return self;
}

} // namespace at::native