#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Sorting.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort_native.h>
#include <ATen/ops/broadcast_tensors.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/full.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/kthvalue.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/masked_fill.h>
#include <ATen/ops/median.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/msort_native.h>
#include <ATen/ops/nanmedian.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/nanquantile_native.h>
#include <ATen/ops/quantile_native.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sort.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/topk_native.h>
#endif

#include <utility>

namespace at::meta {

using namespace ::at::native;

TORCH_META_FUNC(topk)
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
  TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");

  // Build the output size, which is the dim being selected set to
  // size k
  DimVector topKSize(self.sizes().vec());
  if (!topKSize.empty()) {
    topKSize[dim] = k;
  }
  set_output_raw_strided(0, topKSize, {}, self.options());
  set_output_raw_strided(1, topKSize, {}, self.options().dtype(at::kLong));
}

TORCH_META_FUNC2(sort, stable)
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
  maybe_wrap_dim(dim, self.dim());

  // See issue: https://github.com/pytorch/pytorch/issues/65863
  // Strides should be dense, so as not to allocate too much memory.
  // We either use 'self' strides, or infer dense strides from them.
  std::vector<int64_t> strides = (self.is_non_overlapping_and_dense())
      ? self.strides().vec()
      : at::infer_dense_strides(self.sizes(), self.strides());

  set_output_raw_strided(0, self.sizes(), strides, self.options(), {});
  set_output_raw_strided(1, self.sizes(), strides, self.options().dtype(kLong), {});
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(sort_stub);
DEFINE_DISPATCH(topk_stub);

void _fill_indices(const TensorBase &indices, int64_t dim) {
  auto ndim = indices.dim();
  assert(0 <= dim && dim < ndim);
  auto dim_size = indices.size(dim);
  auto idx_dim = at::arange(0, dim_size, indices.options().dtype(at::kLong));
  auto idx_dim_sizes = std::vector<int64_t>(ndim, 1);
  auto idx_dim_strides = std::vector<int64_t>(ndim, 0);
  idx_dim_sizes[dim] = dim_size;
  idx_dim_strides[dim] = 1;
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  OptionalTensorRef(indices)->copy_(idx_dim_restrided);
}

namespace {

/* Note from TH:
   I cut and pasted (slightly adapted) the quicksort code from
   Sedgewick's 1978 "Implementing Quicksort Programs" article
   http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf

   It is the state of the art existing implementation. The macros
   are here to make as close a match as possible to the pseudocode of
   Program 2 p.851

   Note that other partition schemes exist, and are typically presented
   in textbooks, but those are less efficient. See e.g.
   http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto

   Julien, November 12th 2013
*/
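// quick_select_template partially orders `arr` in place so that, on return,
// arr[k] holds the element that would sit at (0-based) position k if the whole
// slice were sorted under `gt_or_nan`; the elements on either side of k are
// not otherwise ordered. kthvalue() below calls it with k - 1 to implement
// its 1-based `k` argument.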
template <typename scalar_t, typename Comp, typename Fn>
void quick_select_template(
    TensorAccessor<scalar_t, 1> arr,
    int64_t k,
    Comp gt_or_nan,
    Fn swap_fn) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t P, L, R, i, j;
  scalar_t piv;
  L = 0;
  R = arr.size(0) - 1;

  do {
    if (R <= L) // One element only
      return;

    if (R == L + 1) { // Two elements only
      if (gt_or_nan(arr[L], arr[R])) {
        swap_fn(L, R);
      }
      return;
    }

    // Use median of three for pivot choice
    P = L + (R - L) / 2;
    swap_fn(P, L + 1);
    if (gt_or_nan(arr[L + 1], arr[R])) {
      swap_fn(L + 1, R);
    }
    if (gt_or_nan(arr[L], arr[R])) {
      swap_fn(L, R);
    }
    if (gt_or_nan(arr[L + 1], arr[L])) {
      swap_fn(L + 1, L);
    }

    i = L + 1;
    j = R;
    piv = arr[L];
    do {
      do
        i++;
      while (gt_or_nan(piv, arr[i]));
      do
        j--;
      while (gt_or_nan(arr[j], piv));
      if (j < i)
        break;
      swap_fn(i, j);
    } while (true);
    swap_fn(L, j);

    // Re-set active partition
    if (j <= k)
      L = i;
    if (j >= k)
      R = j - 1;
  } while (true);
}

namespace {

QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode(
    const c10::string_view interpolation) {
  if (interpolation == "linear") {
    return QUANTILE_INTERPOLATION_MODE::LINEAR;
  } else if (interpolation == "lower") {
    return QUANTILE_INTERPOLATION_MODE::LOWER;
  } else if (interpolation == "higher") {
    return QUANTILE_INTERPOLATION_MODE::HIGHER;
  } else if (interpolation == "midpoint") {
    return QUANTILE_INTERPOLATION_MODE::MIDPOINT;
  } else if (interpolation == "nearest") {
    return QUANTILE_INTERPOLATION_MODE::NEAREST;
  } else {
    TORCH_CHECK(
        false,
        "quantile() interpolation must be one of linear, lower, higher, midpoint or nearest, but got ",
        interpolation);
  }
}

void quantile_checks(const Tensor& self, const Tensor& q) {
  TORCH_CHECK(self.numel() > 0, "quantile() input tensor must be non-empty");
  TORCH_CHECK(q.dim() <= 1, "quantile() q must be a scalar or 1D tensor");
  TORCH_CHECK(
      self.scalar_type() == kFloat || self.scalar_type() == kDouble,
      "quantile() input tensor must be either float or double dtype");
  TORCH_CHECK(
      self.scalar_type() == q.scalar_type(),
      "quantile() q tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == q.device(),
      "quantile() q tensor must be on the same device as the input tensor");
}

std::vector<int64_t> quantile_output_shape(
    const std::optional<int64_t> original_dim,
    const Tensor& self,
    const Tensor& q,
    const bool keepdim,
    int64_t wrapped_dim) {
  // Compute output shape: q_size + reduced_size
  std::vector<int64_t> out_shape;
  if (original_dim && self.dim() > 0) {
    out_shape = self.sizes().vec();
    if (keepdim) {
      out_shape[wrapped_dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + wrapped_dim);
    }
  } else if (keepdim) {
    out_shape = std::vector<int64_t>(self.dim(), 1);
  }
  if (q.dim() > 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  return out_shape;
}

Tensor quantile_compute(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan,
    int64_t wrapped_dim,
    std::vector<int64_t> out_shape) {
  // Checks that all q values are between 0 and 1, inclusive
  // NOTE: this check is only performed when running on the CPU to avoid
  // synchronizing an accelerator with the CPU
  if (self.device().is_cpu()) {
    auto all_q_in_range = q.ge(0).logical_and_(q.le(1)).all();
    TORCH_CHECK(at::is_scalar_tensor_true(all_q_in_range),
                "quantile() q values must be in the range [0, 1]");
  }

  // Flatten input if no dim provided else move dim to reduce as last dimension.
  // Sort to efficiently query kth values.
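  // For example, with self of shape (3, 4, 5) and dim=1, the length-4 slices
  // end up in the last dimension of `sorted`, so each quantile can be read
  // with a single gather along dim -1 once the ranks are computed.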
  Tensor sorted;
  if (!original_dim) {
    sorted = std::get<0>(self.flatten().sort());
  } else if (wrapped_dim == self.dim() - 1) {
    sorted = std::get<0>(self.sort());
  } else {
    sorted = std::get<0>(self.unsqueeze(-1).transpose(wrapped_dim, -1).sort());
  }

  // Treat q as a 1D tensor for the following computations
  if (q.dim() == 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  // View input as reduced_size + size of dim to reduce
  std::vector<int64_t> in_shape(out_shape.size());
  std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin());
  in_shape[in_shape.size() - 1] = sorted.size(-1);
  sorted = sorted.view(in_shape);

  // Ensure converting from int64_t to double won't overflow
  TORCH_CHECK(
      sorted.size(-1) <= std::pow(2, 24),
      "quantile() input tensor is too large");

  // Convert q in [0, 1] to ranks in [0, reduction_size)
  Tensor ranks;
  if (ignore_nan) {
    // For nanquantile, compute ranks based on number of non-nan values.
    // If all values are nan, set rank to 0 so the quantile computed is nan.
    ranks = q * (sorted.isnan().logical_not_().sum(-1, true) - 1);
    // For Composite Compliance,
    // if `ranks` is a CCT but its tangent is a regular Tensor,
    // then while computing jvp, we end up calling `masked_fill_`
    // on a regular Tensor with CCT args, so we call
    // `masked_fill` instead.
    if (isTensorSubclassLike(ranks) && ranks._fw_grad(/*level=*/0).defined()) {
      ranks = ranks.masked_fill(ranks < 0, 0);
    } else {
      ranks.masked_fill_(ranks < 0, 0);
    }
  } else {
    // For quantile, compute ranks based on reduction size. If there is nan
    // set rank to last index so the quantile computed will be nan.
    int64_t last_index = sorted.size(-1) - 1;
    std::vector<Tensor> tl =
        at::broadcast_tensors({q * last_index, sorted.isnan().any(-1, true)});
    ranks = at::masked_fill(tl[0], tl[1], last_index);
  }

  // adjust ranks based on the interpolation mode
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LOWER) {
    ranks.floor_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::HIGHER) {
    ranks.ceil_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::NEAREST) {
    ranks.round_();
  }

  Tensor ranks_below = ranks.toType(kLong);
  Tensor values_below = sorted.gather(-1, ranks_below);

  // Actual interpolation is only needed for the linear and midpoint modes
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LINEAR ||
      interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT) {
    // calculate weights for linear and midpoint
    Tensor weights = interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT
        ? at::full_like(ranks, 0.5)
        : ranks - ranks_below;

    // Interpolate to compute quantiles and store in values_below
    Tensor ranks_above = ranks.ceil_().toType(kLong);
    Tensor values_above = sorted.gather(-1, ranks_above);
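    // lerp computes values_below + weights * (values_above - values_below):
    // with weights fixed at 0.5 this is the midpoint of the two bracketing
    // values, and with weights = ranks - ranks_below it is the linear
    // interpolation between them.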
    // For Composite Compliance,
    // if either `values_below`, `values_above` or `weights` is a CCT,
    // or the tangents of `values_above` and `weights` are CCTs,
    // but the tangent of `values_below` is a regular Tensor,
    // then while computing jvp we would end up copying a CCT into a
    // regular Tensor. So we use the out-of-place variant of `lerp`.
    auto is_primal_cct =
        areAnyTensorSubclassLike({values_below, values_above, weights});
    auto is_tangent_cct = areAnyTensorSubclassLike(
        {values_above._fw_grad(/*level=*/0), weights._fw_grad(/*level=*/0)});
    if ((is_primal_cct || is_tangent_cct) &&
        values_below._fw_grad(/*level=*/0).defined() &&
        !isTensorSubclassLike(values_below._fw_grad(/*level=*/0))) {
      values_below = values_below.lerp(values_above, weights);
    } else {
      values_below.lerp_(values_above, weights);
    }
  }

  if (q.dim() == 0) {
    // If q is scalar, remove last dim to match out shape
    values_below.squeeze_(-1);
  } else {
    // Move quantiles to first dim to match out shape
    values_below.unsqueeze_(0).transpose_(0, -1).squeeze_(-1);
  }

  return values_below;
}

} // namespace

void quantile_out_impl(
    Tensor& out,
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);
  TORCH_CHECK(
      self.scalar_type() == out.scalar_type(),
      "quantile() out tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == out.device(),
      "quantile() out tensor must be on the same device as the input tensor");
  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());
  auto out_shape =
      quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);
  resize_output(out, out_shape);
  auto quantile = quantile_compute(
      self,
      q,
      original_dim,
      keepdim,
      interpolation,
      ignore_nan,
      wrapped_dim,
      std::move(out_shape));
  out.copy_(quantile);
}

Tensor quantile_impl(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);
  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());
  auto out_shape =
      quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);
  return quantile_compute(
      self,
      q,
      original_dim,
      keepdim,
      interpolation,
      ignore_nan,
      wrapped_dim,
      std::move(out_shape));
}

std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cpu(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
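  // kthvalue() uses a 1-based k: k == 1 selects the smallest element of the
  // slice and k == slicesize the largest. NaNs compare as greater than every
  // other value (matching NumPy), hence the k - 1 passed to
  // quick_select_template below.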
  zero_numel_check_dims(self, dim, "kthvalue()");

  TORCH_CHECK(k >= 1 && k <= slicesize,
              "kthvalue(): selected number k out of range for dimension ", dim);

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim_, keepdim);
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }
  auto tmp_values = self.clone(at::MemoryFormat::Contiguous);
  auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));

  auto tmp_values_stride = tmp_values.strides()[dim];
  auto tmp_indices_stride = tmp_indices.strides()[dim];
  auto sizes = self.sizes();

  TORCH_CHECK(indices.scalar_type() == kLong);

  auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(sizes, /*squash_dims=*/dim)
      .add_output(tmp_values)
      .add_output(tmp_indices)
      .add_output(values)
      .add_output(indices)
      .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "kthvalue_cpu", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        TensorAccessor<scalar_t, 1> tmp_values(
            reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
            &sizes[dim], &tmp_values_stride);
        TensorAccessor<int64_t, 1> tmp_indices(
            reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
            &sizes[dim], &tmp_indices_stride);
        auto mode_value = reinterpret_cast<scalar_t*>(data[2] + i * strides[2]);
        auto mode_index = reinterpret_cast<int64_t*>(data[3] + i * strides[3]);

        for (const auto j : c10::irange(tmp_indices.size(0))) {
          tmp_indices[j] = j;
        }

        // we want NaN to be sorted as top for numpy compatibility
        quick_select_template(
            tmp_values,
            k - 1,
            [](scalar_t x, scalar_t y) -> bool {
              return ((_isnan(x) && !_isnan(y)) || (x > y));
            },
            [&](int64_t i, int64_t j) {
              std::swap(tmp_values[i], tmp_values[j]);
              std::swap(tmp_indices[i], tmp_indices[j]);
            });

        *mode_value = tmp_values[k - 1];
        *mode_index = tmp_indices[k - 1];
      }
    };

    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }

  return std::forward_as_tuple(values, indices);
}

// Computes both the median and its index along dimension dim of the input
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  dim = at::maybe_wrap_dim(dim, self.dim());

  int64_t size = self.dim() > 0 ? self.size(dim) : 1;
  zero_numel_check_dims(self, dim, "median()");
  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  std::vector<int64_t> out_shape = self.sizes().vec();
  if (self.dim() > 0) {
    if (keepdim) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  }

  resize_output(values, out_shape);
  resize_output(indices, out_shape);

  // Ensure #dim is the same for all tensors required for dim_apply
  Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
  Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
  Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);
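  // Note: for an even number of (non-NaN) elements the lower of the two
  // middle values is returned, i.e. the element at (0-based) position
  // (size - 1) / 2 of the sorted slice; ties between equal values resolve
  // to the smaller original index.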
  // Make dim to reduce contiguous (stride=1)
  if (in.stride(dim) > 1) {
    in = in.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim).contiguous();
    vals = vals.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    inds = inds.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    dim = in.dim() - 1;
  }

  auto sizes = in.sizes();
  auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(sizes, /*squash_dims=*/dim)
      .add_output(vals)
      .add_output(inds)
      .add_const_input(in)
      .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_out", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        auto valp = reinterpret_cast<scalar_t*>(data[0] + i * strides[0]);
        auto indp = reinterpret_cast<int64_t*>(data[1] + i * strides[1]);
        auto ip = reinterpret_cast<const scalar_t*>(data[2] + i * strides[2]);

        // For torch.median, search for NaN and return it if found
        if (!ignore_nan) {
          const scalar_t* nanp = std::find_if(ip, ip + size, _isnan<scalar_t>);
          if (nanp != ip + size) {
            *valp = *nanp;
            *indp = nanp - ip;
            continue;
          }
        }

        // Vector of indices for indirectly partitioning input around median
        std::vector<int64_t> idx(size);
        auto first = idx.begin();
        auto last = idx.end();
        std::iota(first, last, 0);

        // We partition the input around the median indirectly using the indices
        // vector so that nth points to the index of the median in the unmodified
        // input tensor.
        auto nth = first;
        if (!ignore_nan) {
          // If we got here, there are no nan values
          nth += (size - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j);
          });
        } else {
          // For torch.nanmedian, compute median of non-nan values only
          int64_t num_nan = std::count_if(ip, ip + size, _isnan<scalar_t>);
          nth += (size - num_nan - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j) ||
                (_isnan(ip[j]) && !_isnan(ip[i]));
          });
        }

        *valp = ip[*nth];
        *indp = *nth;
      }
    };
    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  return std::forward_as_tuple(values, indices);
}

// Computes the median of all values in the input
Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;
  const int64_t size = self.numel();

  // Return nan for empty tensors
  if (size <= 0) {
    return at::full({}, std::numeric_limits<double>::quiet_NaN()).to(self.options());
  }

  // Clone the input tensor so we can partition it around the median value
  Tensor in = self.clone();
  Tensor out = at::empty({}, self.options());

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_cpu", [&] {
    scalar_t* op = out.data_ptr<scalar_t>();
    scalar_t* first = in.data_ptr<scalar_t>();
    scalar_t* last = first + size;

    // For torch.median, if there are nan values return nan
    if (!ignore_nan && std::any_of(first, last, _isnan<scalar_t>)) {
      *op = std::numeric_limits<scalar_t>::quiet_NaN();
      return;
    }

    scalar_t* median = first;
    if (!ignore_nan) {
      // If we got here, there are no nan values
      median += (size - 1) / 2;
      std::nth_element(first, median, last);
    } else {
      // For torch.nanmedian, compute median of non-nan values only
      int64_t num_nan = std::count_if(first, last, _isnan<scalar_t>);
      median += (size - num_nan - 1) / 2;
      std::nth_element(first, median, last, [](scalar_t a, scalar_t b) {
        return a < b || (_isnan(b) && !_isnan(a));
      });
    }

    *op = *median;
  });

  return out;
}

} // namespace
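// Public quantile/nanquantile entry points. These are thin wrappers that map
// the string `interpolation` argument onto QUANTILE_INTERPOLATION_MODE and
// promote a double q to a 0-dim tensor. A rough usage sketch from the Python
// side, for orientation only (not exercised in this file):
//
//   x = torch.randn(3, 5)
//   torch.quantile(x, 0.5, dim=1)                          # 0.5 quantile per row
//   torch.quantile(x, torch.tensor([0.25, 0.75]), dim=1)   # adds a leading q dim
//   torch.nanquantile(x, 0.5, dim=1)                       # ignores NaN values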
Tensor& quantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
  return out;
}

Tensor& quantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile_out(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation, out);
}

Tensor quantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
}

Tensor quantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}

Tensor& nanquantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
  return out;
}

Tensor& nanquantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile_out(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation, out);
}

Tensor nanquantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
}

Tensor nanquantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}

std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return kthvalue_out_impl_cpu(values, indices, self, k, dim, keepdim);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> kthvalue_out(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::kthvalue_out(
      values, indices, self, k, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::kthvalue_out(values, indices, self, k, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim) {
  return at::kthvalue(self, k, dimname_to_position(self, dim), keepdim);
}
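// topk CPU kernel entry point: the structured meta function above has already
// validated k and allocated `values`/`indices` with the selected dim resized
// to k, so this body re-checks k, handles the 0-dim shortcut, and dispatches
// to topk_stub. Roughly, torch.topk(x, k, dim, largest, sorted) lands here
// for CPU tensors.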
TORCH_IMPL_FUNC(topk_out_cpu)
(const Tensor& self,
 int64_t k,
 int64_t dim_,
 bool largest,
 bool sorted,
 const Tensor& values,
 const Tensor& indices) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");

  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
  } else {
    topk_stub(kCPU, values, indices, self, k, dim, largest, sorted);
  }
}

std::tuple<Tensor&, Tensor&> median_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/false);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> median_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::median_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::median_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::median(self, dimname_to_position(self, dim), keepdim);
}

Tensor median_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/true);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> nanmedian_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::nanmedian_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::nanmedian_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::nanmedian(self, dimname_to_position(self, dim), keepdim);
}

Tensor nanmedian_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/true);
}

TORCH_IMPL_FUNC(sort_stable_out)
(const Tensor& self,
 std::optional<bool> stable,
 int64_t dim,
 bool descending,
 const Tensor& values,
 const Tensor& indices) {
  values.copy_(self);
  // check if self is scalar
  if (self.dim() == 0 && self.numel() == 1) {
    indices.zero_();
  } else {
    dim = maybe_wrap_dim(dim, self.dim());
    sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
  }
}

std::tuple<Tensor&, Tensor&> sort_out(
    const Tensor& self,
    int64_t dim,
    bool descending,
    Tensor& values,
    Tensor& indices) {
  return at::sort_out(values, indices, self, false, dim, descending);
}

std::tuple<Tensor, Tensor> sort(
    const Tensor& self,
    int64_t dim,
    bool descending) {
  return at::sort(self, false, dim, descending);
}

Tensor& msort_out(const Tensor& self, Tensor& values) {
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::sort_out(values, indices, self, 0, false);
  return values;
}

Tensor msort(const Tensor& self) {
  return std::get<0>(at::sort(self, 0, false));
}
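// argsort discards the sorted values and returns only the indices (the second
// output of at::sort); gathering self along dim with those indices reproduces
// the sorted values.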
Tensor argsort(const Tensor& self, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, dim, descending));
}

Tensor argsort(const Tensor& self, bool stable, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, stable, dim, descending));
}

Tensor& argsort_out(const Tensor& self, bool stable, int64_t dim, bool descending, Tensor& out) {
  auto values = at::empty({0}, self.options());
  at::sort_outf(self, stable, dim, descending, values, out);
  return out;
}

} // namespace at::native