#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/ThrustAllocator.h>
#include <ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/SparseTensorUtils.h>
#include <c10/macros/Macros.h>
#include <c10/util/accumulate.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/generate.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/transform.h>
#include <thrust/unique.h>
#include <thrust/binary_search.h>

namespace at::native {

using namespace at::sparse;

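// Coalesce a sparse COO tensor on CUDA:
//   1. Flatten the sparse indices into a single 1-D key per nonzero.
//   2. Sort the keys while tracking each entry's original position, then
//      deduplicate them with thrust::unique_by_key.
//   3. Sum the value rows of duplicate entries with a segmented-reduction
//      kernel (apply::coalesceValuesKernel).
//   4. Unflatten the surviving keys back into per-dimension indices.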
SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
  int64_t nnz = self._nnz();
  TORCH_INTERNAL_ASSERT(!self.is_coalesced());
  // NOTE: Since `coalesce` is not an in-place operation when `is_coalesced` is false,
  // we must keep the original tensor intact and perform the coalesce on a copy.
  if (nnz < 2) {
    SparseTensor dst = self.clone();
    dst._coalesced_(true);
    return dst;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
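  // Thrust algorithms below run on the current CUDA stream and allocate their
  // temporary buffers through PyTorch's CUDA caching allocator.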

  // For indices, a simple sort + unique suffices.
  // For values, we use a custom kernel for segmented reduction (we can't use
  // Thrust due to the indirection).

  Tensor values = self._values();

  int64_t sparse_dim = self.sparse_dim();

  // indices will be modified by Thrust, so we have to clone or use new storage
  // here.
  Tensor indices1D = flatten_indices(self._indices(), self.sizes(), true);
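  // flatten_indices linearizes each sparse_dim-dimensional index into a single
  // int64 key (row-major over self.sizes()), so the entries can be sorted and
  // deduplicated with plain 1-D Thrust calls.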

  Tensor origIndices = at::empty({nnz}, self._indices().options());
  Tensor uniqueOffsets = at::empty({nnz}, self._indices().options());
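  // origIndices will record, for each sorted position, where that entry lived
  // in the original tensor (used to gather value rows); uniqueOffsets will end
  // up holding, for each surviving key, the sorted-order offset at which its
  // run of duplicates starts.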

  typedef thrust::device_ptr<int64_t> thrust_ptr;
  thrust_ptr indicesIter(indices1D.data_ptr<int64_t>());
  thrust_ptr origIndicesIter(origIndices.data_ptr<int64_t>());
  thrust_ptr uniqueOffsetsIter(uniqueOffsets.data_ptr<int64_t>());


  // Fill origIndices and uniqueOffsets with the sequence 0, 1, ..., nnz - 1
  thrust::counting_iterator<int64_t> countIterI(0);
  thrust::counting_iterator<int64_t> countIterO(0);

  thrust::copy(policy, countIterI, countIterI + nnz, origIndicesIter);
  thrust::copy(policy, countIterO, countIterO + nnz, uniqueOffsetsIter);

  thrust::sort_by_key(policy,
    indicesIter, indicesIter + nnz,
    origIndicesIter, LTOp<int64_t>()
  );
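  // After the sort, indices1D is in ascending order and origIndices[i] holds
  // the original position of the entry that now sits at sorted position i.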

  // this forces device-host synchronization!
  thrust::pair<thrust_ptr, thrust_ptr> newEnd = thrust::unique_by_key(policy,
    indicesIter, indicesIter + nnz,
    uniqueOffsetsIter
  );
  int64_t newNnz = newEnd.first - indicesIter;
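  // newNnz is the number of distinct flattened indices. For each distinct key,
  // unique_by_key kept the uniqueOffsets entry of its first occurrence, i.e.
  // the sorted-order offset where that key's run of duplicates begins.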

  indices1D.resize_({1, newNnz});
  auto newValues_size = values.sizes().vec();
  newValues_size[0] = newNnz;
  Tensor newValues = at::empty(newValues_size, values.options());

  // If there are no values to copy, skip launching the kernel.
  if (newValues.numel() > 0) {
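    // The kernel below performs the segmented reduction: for each of the
    // newNnz output entries it sums the value rows of that entry's duplicates,
    // using uniqueOffsets to delimit each run in the sorted order and
    // origIndices to gather the corresponding rows of `values`. The grid
    // covers the newNnz entries in groups of SZ along x and each value row
    // (stride elements) in chunks of warp_size * SZ along y.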
    const int SZ = 4;
    values = values.contiguous();
    int64_t stride = c10::multiply_integers(values.sizes().slice(1));
    int warp_size = at::cuda::warp_size();
    dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ));
    dim3 block(warp_size, SZ);
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
      values.scalar_type(), "coalesce_sparse_cuda", [&] {
        using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
        apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
          uniqueOffsets.data_ptr<int64_t>(),
          origIndices.data_ptr<int64_t>(),
          values.data_ptr<scalar_t>(),
          newValues.data_ptr<scalar_t>(),
          nnz,
          newNnz,
          stride
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  }

  // this grid-strided version is slower but probably more flexible
  // to different sizes
  // int64_t blockX = min(stride, (int64_t) 512);
  // dim3 block(blockX, 512 / blockX);
  // int64_t grid = min((int64_t) 1024, ceil_div((int64_t) newNnz * stride, (int64_t) block.x * block.y));
  // THCSTensor_coalesceValuesKernel_gridStrided<real, accreal><<<grid, block, 0, stream> >>(
  //   THCIndexTensor_(data)(state, uniqueOffsets),
  //   THCIndexTensor_(data)(state, origIndices),
  //   THCTensor_(data)(state, values),
  //   THCTensor_(data)(state, newValues),
  //   nnz,
  //   newNnz,
  //   stride
  // );
  // C10_CUDA_KERNEL_LAUNCH_CHECK();

  ////////////////////////////////////////////////////////////
  // unflatten indices if necessary
  Tensor newIndices;
  if (sparse_dim == 1) {
    newIndices = indices1D;
  } else {
    newIndices = at::empty({sparse_dim, newNnz}, origIndices.options());
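    // Recover the coordinate for each sparse dimension from the flattened key,
    // working from the innermost dimension outwards: the remainder modulo
    // self.size(d) is the coordinate for dimension d, and the truncated
    // quotient carries over to the next (outer) dimension.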
    for (int64_t d = sparse_dim - 1; d >= 0; d--) {
      // NB: Not a select, so I can preserve the outer dimension
      Tensor indicesSlice = newIndices.narrow(0, d, 1);
      indicesSlice.copy_(indices1D);
      indices1D.divide_(self.size(d), "trunc");
      indicesSlice.add_(indices1D, -self.size(d));
    }
  }
  ////////////////////////////////////////////////////////////
  // We can use the unsafe sparse tensor constructor because the indices do not
  // need to be revalidated: we do not add or change indices, we only remove
  // duplicates.
  SparseTensor dst = ::at::native::_sparse_coo_tensor_unsafe(newIndices, newValues, self.sizes())._coalesced_(true);

  AT_CUDA_CHECK(cudaGetLastError());
  return dst;
}

} // namespace at::native