#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Histogram.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/aminmax.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>
#include <numeric>
#include <functional>

namespace at::native {

namespace {

constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;

/* The main algorithm. Expects that the input tensor has shape (N, D).
 * Expects that bin_edges contains D one-dimensional tensors, each specifying
 * an increasing sequence of bin edges.
 *
 * Interprets the input as N different D-dimensional coordinates and maps them
 * into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
 * histogram in the hist tensor.
 *
 * Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
 * the scalars in each dimension should be mapped into the dimension's bins:
 *
 *     - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
 *       Scalars are mapped to bins by computing
 *           (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
 *       and truncating the result to an integer.
 *
 *       This is the fastest option, but its results may not be perfectly consistent
 *       with the boundaries specified in bin_edges due to precision issues.
 *
 *       Used by torch.histc, which doesn't need consistency with bin_edges as it does not
 *       return bin_edges. Additionally, this implementation is identical to the legacy histc
 *       implementation, which was replaced when histogram was implemented.
 *
 *     - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
 *       forms a linear progression. For each scalar, if 'pos' is the bin selected by the
 *       LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
 *       place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
 *       bins allows for correction of misclassifications due to precision issues (a scalar
 *       very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
 *
 *       Should produce the same output as the general case BINARY_SEARCH, but run about
 *       3x faster asymptotically.
 *
 *       Used by torch.histogram for cases in which bin_edges is constructed using
 *       torch.linspace. The behavior of LINEAR_INTERPOLATION may not perfectly align
 *       with linspace bin_edges due to precision issues. torch.histogram returns both
 *       the hist and bin_edges tensors as output, so the "local search" is needed to
 *       keep its output internally consistent.
 *
 *     - BINARY_SEARCH: Handles torch.histogram's general case by searching over the
 *       elements of bin_edges. Implemented using std::upper_bound.
 *
 * See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
 * for further details on relative performance of the bin selection algorithms.
 */
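/* Illustrative example of the LINEAR_INTERPOLATION mapping (hypothetical values,
 * not taken from any particular caller): for a dimension with
 * bin_edges = {0.0, 2.5, 5.0, 7.5, 10.0}, we have bin_ct = 4, leftmost_edge = 0.0
 * and rightmost_edge = 10.0, so an element elt = 6.2 maps to
 *     trunc((6.2 - 0.0) / (10.0 - 0.0) * 4) = trunc(2.48) = 2,
 * i.e. the bin covering [5.0, 7.5). For an element lying (almost) exactly on an
 * interior edge, rounding in this expression may select a bin one too low or one
 * too high; LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH re-checks the element against
 * the actual values stored in bin_edges so that the reported bin stays consistent
 * with the returned edges.
 */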
enum BIN_SELECTION_ALGORITHM {
    LINEAR_INTERPOLATION,
    LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
    BINARY_SEARCH,
};

template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const std::optional<Tensor>& weight) {
    TORCH_INTERNAL_ASSERT(input.dim() == 2);

    const int64_t N = input.size(0);
    if (weight.has_value()) {
        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
    }

    const int64_t D = input.size(1);
    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
    for (const auto dim : c10::irange(D)) {
        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
    }

    if (D == 0) {
        // hist is an empty tensor in this case; nothing to do here
        return;
    }

    TensorAccessor<input_t, 2> accessor_in = input.accessor<input_t, 2>();

    /* Constructs a std::optional<TensorAccessor> containing an accessor if
     * the optional weight tensor has a value.
     */
    const auto accessor_wt = weight.has_value()
            ? std::optional<TensorAccessor<input_t, 1>>(weight.value().accessor<input_t, 1>())
            : std::optional<TensorAccessor<input_t, 1>>();

    std::vector<input_t*> bin_seq(D);
    std::vector<int64_t> num_bin_edges(D);
    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

    for (const auto dim : c10::irange(D)) {
        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
        num_bin_edges[dim] = bin_edges[dim].numel();
        leftmost_edge[dim] = bin_seq[dim][0];
        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
    }

    int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);

    /* Parallelizes processing of input using at::parallel_for.
     * Each thread accumulates a local result into its own slice of
     * thread_histograms, and the slices are summed together at the end.
     */
    const auto num_threads = at::get_num_threads();
    const auto hist_sizes = hist.sizes();
    DimVector thread_hist_sizes(hist_sizes.size() + 1);
    thread_hist_sizes[0] = num_threads;
    std::copy(hist_sizes.begin(), hist_sizes.end(),
            thread_hist_sizes.begin() + 1);
    Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
    TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());

    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
        const auto tid = at::get_thread_num();
        auto hist_strides = thread_histograms.strides();
        input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

        // View only this thread's local results
        hist_local_data += hist_strides[0] * tid;
        hist_strides = hist_strides.slice(1);

        for (const auto i : c10::irange(start, end)) {
            bool skip_elt = false;
            int64_t hist_index = 0;

            for (const auto dim : c10::irange(D)) {
                const input_t elt = accessor_in[i][dim];

                // Skips elements which fall outside the specified bins and NaN elements
                if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
                    skip_elt = true;
                    break;
                }

                int64_t pos = -1;

                if (algorithm == BINARY_SEARCH) {
                    // Handles the general case via binary search on the bin edges.
                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
                            - bin_seq[dim] - 1;
                } else if (algorithm == LINEAR_INTERPOLATION
                        || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                    /* When bin_edges is known to be a linear progression, maps elt to
                     * the appropriate bin via simple division.
                     */
                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
                            * (num_bin_edges[dim] - 1)
                            / (rightmost_edge[dim] - leftmost_edge[dim]));

                    /* Ensures consistency with bin_edges by checking the bins to the left and right
                     * of the selected position. Necessary for cases in which an element very close
                     * to a bin edge may be misclassified by simple division.
                     */
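                    /* Reasoning sketch (illustrative, values not tied to any particular input):
                     * if an element's true bin is 3 but floating-point rounding makes the
                     * product above come out just below 3.0, truncation yields pos = 2. Because
                     * the edges form a (nearly) linear progression, the estimate is expected to
                     * be at most one bin off, so the std::upper_bound call below only needs to
                     * inspect the window [pos - 1, pos + 1].
                     */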
                    if (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                                - bin_seq[dim] - 1;
                    }
                } else {
                    TORCH_INTERNAL_ASSERT(false);
                }

                // Unlike other bins, the rightmost bin includes its right boundary
                if (pos == (num_bin_edges[dim] - 1)) {
                    pos -= 1;
                }

                hist_index += hist_strides[dim] * pos;
            }

            if (!skip_elt) {
                // In the unweighted case, the default weight is 1
                input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);

                hist_local_data[hist_index] += wt;
            }
        }
    });

    at::sum_out(hist, thread_histograms, /*dim=*/{0});
}

/* Some pre- and post-processing steps for the main algorithm.
 * Initializes hist to 0, calls into the main algorithm, and normalizes output if necessary.
 */
template<BIN_SELECTION_ALGORITHM bin_algorithm>
void histogramdd_out_cpu_template(const Tensor& self, const std::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    hist.fill_(0);

    const int64_t N = self.size(-1);
    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
            (int64_t)1, std::multiplies<int64_t>());

    const Tensor reshaped_input = self.reshape({M, N});

    const auto reshaped_weight = weight.has_value()
            ? std::optional<Tensor>(weight.value().reshape({M}))
            : std::optional<Tensor>();

    std::vector<Tensor> bin_edges_contig(bin_edges.size());
    for (const auto dim : c10::irange(bin_edges_contig.size())) {
        bin_edges_contig[dim] = bin_edges[dim].contiguous();
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "histogram_cpu", [&]() {
        histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
                hist, bin_edges_contig, reshaped_input, reshaped_weight);
    });

    /* Divides each bin's value by the total count/weight in all bins,
     * and by the bin's volume.
     */
    if (density) {
        const auto hist_sum = hist.sum().item();
        hist.div_(hist_sum);

        /* For each dimension, divides each bin's value
         * by the bin's length in that dimension.
         */
        for (const auto dim : c10::irange(N)) {
            const auto bin_lengths = bin_edges[dim].diff();

            // Used to reshape bin_lengths to align with the corresponding dimension of hist.
            std::vector<int64_t> shape(N, 1);
            shape[dim] = bin_lengths.numel();

            hist.div_(bin_lengths.reshape(shape));
        }
    }
}
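/* Worked example of the density normalization performed above (illustrative values
 * only): for a one-dimensional histogram with bin_edges = {0., 1., 3.} and
 * accumulated counts hist = {2., 2.}, the total is 4, so dividing by the sum gives
 * {0.5, 0.5}. The bin lengths are diff(bin_edges) = {1., 2.}, so the final
 * densities are {0.5, 0.25}, which integrate to 0.5 * 1 + 0.25 * 2 = 1.
 */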
/* The general implementation of the histogram kernel. Maps each element of the input tensor
 * to its corresponding bin by performing a binary search over the elements of bin_edges.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    histogramdd_out_cpu_template<BINARY_SEARCH>(self, weight, density, hist, bin_edges);
}

/* A faster version of the histogram kernel for cases in which bin_edges are known
 * to form a linear progression.
 *
 * Refer to histogramdd_out_cpu_template for more details.
 */
static void histogramdd_linear_kernel_impl(const Tensor& self, const std::optional<Tensor>& weight,
        bool density, Tensor& hist, const TensorList& bin_edges, bool local_search) {
    if (local_search) {
        // histogramdd codepath: both hist and bin_edges are eventually returned as output,
        // so we'll keep them consistent
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
                self, weight, density, hist, bin_edges);
    } else {
        // histc codepath: bin_edges are not returned to the caller
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
                self, weight, density, hist, bin_edges);
    }
}

template<typename scalar_t>
void infer_bin_edges_from_input(const Tensor& input, const int64_t N,
        std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
    // Calls aminmax on input with dim=0, reducing all but the innermost dimension of input.
    auto [min, max] = aminmax(input, 0);

    TORCH_INTERNAL_ASSERT(min.is_contiguous() && max.is_contiguous());

    const scalar_t *min_data = min.const_data_ptr<scalar_t>();
    std::copy(min_data, min_data + N, leftmost_edges.begin());

    const scalar_t *max_data = max.const_data_ptr<scalar_t>();
    std::copy(max_data, max_data + N, rightmost_edges.begin());
}

static void histogram_select_outer_bin_edges_impl(const Tensor& input, const int64_t N,
        std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
        infer_bin_edges_from_input<scalar_t>(input, N, leftmost_edges, rightmost_edges);
    });
}

} // namespace

REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
REGISTER_DISPATCH(histogram_select_outer_bin_edges_stub, &histogram_select_outer_bin_edges_impl);

} // namespace at::native