#pragma once

#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/thread_constants.h>
#include <c10/macros/Macros.h>

namespace at::native::apply {

using at::cuda::detail::TensorInfo;
using indexT = int64_t;

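// applyOp2/applyOp3 apply `op` to corresponding elements of one value slice per
// operand: slice `idx1` of `values1`, slice `idx2` of `values2` (and slice
// `idx3` of `values3` for applyOp3). A slice holds `blockSize` contiguous
// elements, and the calling kernel's threads grid-stride over it.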
template <typename IndexType, typename Real, typename Op>
__device__ void applyOp2(
    Op op, IndexType blockSize,
    TensorInfo<Real, IndexType> values1, IndexType idx1,
    TensorInfo<Real, IndexType> values2, IndexType idx2) {
  for (IndexType k = blockIdx.x * blockDim.x + threadIdx.x;
       k < blockSize;
       k += gridDim.x * blockDim.x) {
    op(values1.data + idx1 * blockSize + k, values2.data + idx2 * blockSize + k);
  }
}

template <typename IndexType, typename Real, typename Op>
__device__ void applyOp3(
    Op op, IndexType blockSize,
    TensorInfo<Real, IndexType> values1, IndexType idx1,
    TensorInfo<Real, IndexType> values2, IndexType idx2,
    TensorInfo<Real, IndexType> values3, IndexType idx3) {
  for (IndexType k = blockIdx.x * blockDim.x + threadIdx.x;
       k < blockSize;
       k += gridDim.x * blockDim.x) {
    op(values1.data + idx1 * blockSize + k,
       values2.data + idx2 * blockSize + k,
       values3.data + idx3 * blockSize + k);
  }
}

// Assume both dense and values are contiguous.
// Currently only used in add_out_dense_sparse_cuda: add(dense, sparse, scalar).
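// One block per non-zero: block blockIdx.x (grid-strided over nnz) flattens the
// sparse index to locate the destination slice in `dense`, and its threads
// stride over the `value_size` elements of that slice.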
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void sparseElementwiseKernel(
    Op op,
    TensorInfo<Real, IndexType> dense,
    TensorInfo<indexT, IndexType> indices,
    TensorInfo<Real, IndexType> values,
    const IndexType nnz) {
  IndexType ind_skip = indices.strides[0];
  IndexType ind_nnz_skip = indices.strides[1];
  IndexType value_size = values.strides[0];  // numel of each slice in values
  for (IndexType linearId = blockIdx.x;
       linearId < nnz;
       linearId += gridDim.x) {
    IndexType index = 0;
    for (IndexType d = 0; d < indices.sizes[0]; d++) {
      index = dense.sizes[d] * index + indices.data[d * ind_skip + linearId * ind_nnz_skip];
    }
    Real *dst = dense.data + index * value_size;
    Real *src = values.data + linearId * value_size;
    for (IndexType linearId2 = threadIdx.x; linearId2 < value_size; linearId2 += blockDim.x) {
      op(dst + linearId2, src + linearId2);
    }
  }
}

// Assume dense is contiguous.
// Currently only used in add_out_dense_sparse_cuda: add(dense, sparse, scalar).
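// Scalar variant: each non-zero carries a single value, so a grid-strided loop
// assigns one thread per non-zero and applies `op` directly to the matching
// dense element.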
template <typename Op, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void sparseElementwiseKernelScalar(
    Op op,
    TensorInfo<Real, IndexType> dense,
    TensorInfo<indexT, IndexType> indices,
    TensorInfo<Real, IndexType> values,
    const IndexType nnz) {
  IndexType ind_skip = indices.strides[0];
  IndexType ind_nnz_skip = indices.strides[1];
  IndexType value_skip = values.strides[0];
  for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
       linearId < nnz;
       linearId += gridDim.x * blockDim.x) {
    IndexType index = 0;
    for (IndexType d = 0; d < indices.sizes[0]; d++) {
      index = dense.sizes[d] * index + indices.data[d * ind_skip + linearId * ind_nnz_skip];
    }
    op(dense.data + index, values.data + linearId * value_skip);
  }
}

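// Merges the values of two sparse tensors whose indices are sorted
// lexicographically (i.e. coalesced). Every thread walks the same merge over
// t_indices and s_indices, applying opBoth where an index appears in both
// tensors, opLeft where it appears only in t, and opRight where it appears
// only in s; the per-slice elementwise work is grid-strided via
// applyOp2/applyOp3.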
template <typename OpBoth, typename OpLeft, typename OpRight, typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void valueSparseUnionKernel(
    OpBoth opBoth,
    OpLeft opLeft,
    OpRight opRight,
    TensorInfo<indexT, IndexType> r_indices,
    TensorInfo<indexT, IndexType> t_indices,
    TensorInfo<indexT, IndexType> s_indices,
    TensorInfo<Real, IndexType> r_values,
    TensorInfo<Real, IndexType> t_values,
    TensorInfo<Real, IndexType> s_values,
    const IndexType t_nnz, const IndexType s_nnz) {
  IndexType t_indskip = t_indices.strides[0];
  IndexType s_indskip = s_indices.strides[0];
  int64_t cmp, d;
  int64_t nDimI = r_indices.sizes[0];
  IndexType valueSize = r_values.strides[0];
  IndexType r_i = 0, t_i = 0, s_i = 0;
  while (t_i < t_nnz || s_i < s_nnz) {
    if (t_i >= t_nnz) {
      cmp = -1;
    } else if (s_i >= s_nnz) {
      cmp = 1;
    } else {
      cmp = 0;
      for (d = 0; d < nDimI; d++) {
        if (t_indices.data[d * t_indskip + t_i] < s_indices.data[d * s_indskip + s_i]) {
          cmp = 1;
          break;
        }
        if (t_indices.data[d * t_indskip + t_i] > s_indices.data[d * s_indskip + s_i]) {
          cmp = -1;
          break;
        }
      }
    }
    if (cmp == 0) applyOp3(opBoth, valueSize, r_values, r_i, t_values, t_i++, s_values, s_i++);
    else if (cmp > 0) applyOp2(opLeft, valueSize, r_values, r_i, t_values, t_i++);
    else if (cmp < 0) applyOp2(opRight, valueSize, r_values, r_i, s_values, s_i++);
    r_i++;
  }
}

// TODO find a way to parallelize this...
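// Merges the sorted index lists of t and s into r_indices with duplicates
// collapsed and writes the size of the union to *resultNnz. The merge runs
// sequentially (hence the TODO), so the kernel is presumably launched with a
// single thread.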
template <typename IndexType, typename Real>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(cuda::getApplyBlockSize(), cuda::getApplyBlocksPerSM())
#endif
__global__ void indexSparseUnionKernel(
    TensorInfo<indexT, IndexType> r_indices,
    TensorInfo<indexT, IndexType> t_indices,
    TensorInfo<indexT, IndexType> s_indices,
    const IndexType t_nnz, const IndexType s_nnz, IndexType *resultNnz) {
  IndexType r_indskip = r_indices.strides[0];
  IndexType t_indskip = t_indices.strides[0];
  IndexType s_indskip = s_indices.strides[0];
  int64_t cmp, d;
  int64_t nDimI = r_indices.sizes[0];
  IndexType r_i = 0, t_i = 0, s_i = 0;
  while (t_i < t_nnz || s_i < s_nnz) {
    if (t_i >= t_nnz) {
      cmp = -1;
    } else if (s_i >= s_nnz) {
      cmp = 1;
    } else {
      cmp = 0;
      for (d = 0; d < nDimI; d++) {
        if (t_indices.data[d * t_indskip + t_i] < s_indices.data[d * s_indskip + s_i]) {
          cmp = 1;
          break;
        }
        if (t_indices.data[d * t_indskip + t_i] > s_indices.data[d * s_indskip + s_i]) {
          cmp = -1;
          break;
        }
      }
    }
    if (cmp >= 0) {
      for (d = 0; d < nDimI; d++) {
        r_indices.data[d * r_indskip + r_i] = t_indices.data[d * t_indskip + t_i];
      }
      t_i++;
    }
    if (cmp <= 0) {
      for (d = 0; d < nDimI; d++) {
        r_indices.data[d * r_indskip + r_i] = s_indices.data[d * s_indskip + s_i];
      }
      s_i++;
    }
    r_i++;
  }
  *resultNnz = r_i;
}

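// Reduces duplicate entries after sorting: for each unique index `seg`, the
// range segment_offsets[seg] .. (segment_offsets[seg + 1] or nnz) delimits its
// rows in the sorted order, whose original positions are given by
// value_indices. Those value rows (each of `stride` elements) are accumulated
// in Acctype and written to row `seg` of newValues; each thread handles SZ
// features spaced C10_WARP_SIZE apart.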
template <typename Dtype, typename Acctype>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void coalesceValuesKernel(
  int64_t *segment_offsets, int64_t *value_indices,
  Dtype *values, Dtype *newValues,
  int64_t nnz, int64_t newNnz, int64_t stride) {

  int seg = blockIdx.x * 4 + threadIdx.y;

  // Number of values processed by each thread (grain size)
  const int SZ = 4;

  if (seg < newNnz) {
    const int newValueRow = seg * stride;
    const int begin = segment_offsets[seg];
    const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
    Acctype tmp[SZ];
    #pragma unroll
    for (int ii = 0; ii < SZ; ii++) {
      tmp[ii] = 0;
    }
    for (int row = begin; row < end; row++) {
      const int valueRow = ((int) value_indices[row]) * stride;

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * C10_WARP_SIZE;
        if (featureDim < stride)
        {
          tmp[ii] += static_cast<Acctype>(values[valueRow + featureDim]);
        }
      }
    }
    #pragma unroll
    for (int ii = 0; ii < SZ; ii++)
    {
      int featureDim = startFeature + ii * C10_WARP_SIZE;
      if (featureDim < stride)
      {
        newValues[newValueRow + featureDim] = static_cast<Dtype>(tmp[ii]);
      }
    }
  }
}

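// A hedged sketch of the launch shape implied by the indexing above (inferred
// from the kernel, not taken from the actual call site; `ceil_div` stands for
// rounding-up integer division and `scalar_t`/`acc_t` are placeholder types):
//
//   dim3 block(C10_WARP_SIZE, 4);  // blockDim.y == 4 matches seg = blockIdx.x * 4 + threadIdx.y;
//                                  // gap-free feature coverage needs blockDim.x == C10_WARP_SIZE
//   dim3 grid(ceil_div(newNnz, 4), ceil_div(stride, 4 * C10_WARP_SIZE));
//   coalesceValuesKernel<scalar_t, acc_t><<<grid, block, 0, stream>>>(
//       segment_offsets, value_indices, values, newValues, nnz, newNnz, stride);
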
// coalesceValuesKernel when Dtype/Acctype is bool. Can be eliminated using
// `if constexpr` once CUDA code is compiled as C++17; see gh-56055 for
// blockers.
template<typename Dtype>
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
__global__ void coalesceValuesKernel(
  int64_t *segment_offsets, int64_t *value_indices,
  bool *values, bool *newValues,
  int64_t nnz, int64_t newNnz, int64_t stride) {

  int seg = blockIdx.x * 4 + threadIdx.y;

  // Number of values processed by each thread (grain size)
  const int SZ = 4;

  if (seg < newNnz) {
    const int newValueRow = seg * stride;
    const int begin = segment_offsets[seg];
    const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
    bool tmp[SZ];
    #pragma unroll
    for (int ii = 0; ii < SZ; ii++) {
      tmp[ii] = 0;
    }
    for (int row = begin; row < end; row++) {
      const int valueRow = ((int) value_indices[row]) * stride;

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * C10_WARP_SIZE;
        if (featureDim < stride)
        {
          tmp[ii] |= values[valueRow + featureDim];
        }
      }
    }
    #pragma unroll
    for (int ii = 0; ii < SZ; ii++)
    {
      int featureDim = startFeature + ii * C10_WARP_SIZE;
      if (featureDim < stride)
      {
        newValues[newValueRow + featureDim] = tmp[ii];
      }
    }
  }
}

} // namespace at::native::apply