/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

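// GPU (CUDA/ROCm) specializations of the functors used by the CSR sparse
// matrix ops: SparseTensor <-> COO/CSR index conversions, batched
// multiplication of a CSR matrix by per-batch scalars, and softmax with its
// gradient.
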
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/cub/device/device_histogram.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/gpus/cuda/include/cusparse.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
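// Functor for a cub/hipcub TransformInputIterator: reads column 0 (the batch
// index) of a row-major [nnz, rank] sparse indices matrix, i.e. element idx
// of the iterator is static_cast<int>(begin_[idx * stride_]) with stride_
// equal to the rank.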
struct StridedDataReader {
  StridedDataReader(const int64* begin, int stride)
      : begin_(begin), stride_(stride) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return static_cast<int>(ldg(begin_ + idx * stride_));
  }

  const int64* begin_;
  const int stride_;
};
}  // namespace

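// Computes the number of nonzeros per batch by histogramming the batch
// (first) column of `indices` over the levels [0, size].  As usual with
// cub/hipcub, HistogramEven is called twice: once with a null temp buffer to
// query temp_storage_bytes, then again with the allocated scratch tensor to
// fill in nnz_per_batch.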
template <>
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
    TTypes<int32>::Vec nnz_per_batch) {
  const auto& cu_stream = GetGpuStream(c);

  const int total_nnz = indices.dimension(0);
  const int size = nnz_per_batch.size();

  DCHECK_EQ(indices.rank(), 2);
  DCHECK_EQ(indices.dimension(1), 3);  // batch, row, col

  const int rank = indices.dimension(1);
  gpuprim::CountingInputIterator<int> row_counter(0);
  gpuprim::TransformInputIterator<int, StridedDataReader,
                                  gpuprim::CountingInputIterator<int>>
      indices_first_column(row_counter,
                           StridedDataReader(indices.data(), rank));

  std::size_t temp_storage_bytes = 0;

  DCHECK_NE(indices.data(), nullptr);
  DCHECK_NE(nnz_per_batch.data(), nullptr);

  auto first_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ nullptr,
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (first_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to calculate temp_storage_bytes, status: ",
        GpuGetErrorString(first_success));
  }

  Tensor temp_storage;
  TF_RETURN_IF_ERROR(c->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
  auto second_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ temp_storage.flat<int8>().data(),
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (second_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to count nnz entries per batch.  temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
  }

  return Status::OK();
}

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  GpuSparse gpu_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}

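// Copies the (row, col) coordinates of each sparse index into the COO row and
// column arrays.  For stride == 3 the indices are laid out as
// [batch, row, col], so `offset` skips the leading batch column.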
template <int stride>
__global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
                                              int* coo_rows_out,
                                              int* coo_cols_out, int size) {
  const int offset = (stride == 3) ? 1 : 0;
  GPU_1D_KERNEL_LOOP(i, size) {
    coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
    coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
  }
}

template <>
void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
    const GPUDevice& d, TTypes<int64>::ConstVec host_dense_shape,
    TTypes<int64>::ConstMatrix indices, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::Vec coo_col_ind) {
  const int stride = host_dense_shape.size();
  DCHECK(stride == 2 || stride == 3);
  DCHECK_EQ(stride, indices.dimension(1));
  const int size = coo_row_ind.dimension(0);
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  if (stride == 2) {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  }
}

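// Writes each (row, col) pair of a rank-2 COO matrix into the [nnz, 2]
// indices matrix of the output SparseTensor.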
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
                                                const int* coo_cols,
                                                int64* indices_out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) {
    indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
    indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
  }
}

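// Given `range`, a non-decreasing sequence of n + 1 bucket boundaries,
// returns the bucket b with range[b] <= x < range[b + 1].  For example, with
// range = {0, 3, 7} and n = 2, BinarySearchRange(range, 2, 5) returns 1.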
__device__ inline int BinarySearchRange(int* range, int n, int x) {
  int left = 0;
  int right = n - 1;
  while (left < right) {
    int mid = left + (right - left) / 2;
    if (x < range[mid])
      right = mid - 1;
    else if (range[mid + 1] <= x)
      left = mid + 1;
    else
      return mid;  // range[mid] <= x < range[mid + 1].
  }
  return left;
}

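// Expands a batched COO matrix into [batch, row, col] SparseTensor indices.
// batch_ptr holds batch_size + 1 offsets into the nnz dimension (batch b owns
// entries [batch_ptr[b], batch_ptr[b + 1])), so the batch of entry i is
// recovered with a binary search over the batch pointers.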
__global__ void COOMatrixToSparseTensorKernel3D(
    const int* coo_rows, const int* coo_cols, int64* indices_out,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int size) {
  // Step 1: access the batch ptrs and copy to shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, size) {
    // TODO(ebrevdo): Consider special casing batch_size <= 3,
    // alternatively doing linear instead of binary search.  Requires
    // some benchmarks.
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    indices_out[i * 3] = static_cast<int64>(b);
    indices_out[i * 3 + 1] = static_cast<int64>(ldg(coo_rows + i));
    indices_out[i * 3 + 2] = static_cast<int64>(ldg(coo_cols + i));
  }
}

template <>
Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64>::ConstVec host_dense_shape,
    TTypes<int>::ConstVec host_batch_ptr, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::ConstVec coo_col_ind, TTypes<int64>::Matrix indices) {
  const int ndims = indices.dimension(1);
  DCHECK(ndims == 2 || ndims == 3);
  DCHECK_EQ(ndims, host_dense_shape.size());
  DCHECK_NE(coo_row_ind.data(), nullptr);
  DCHECK_NE(coo_col_ind.data(), nullptr);
  DCHECK_NE(indices.data(), nullptr);
  const GPUDevice& d = c->eigen_device<GPUDevice>();
  const int size = coo_row_ind.size();
  DCHECK_EQ(size, coo_col_ind.size());
  DCHECK_EQ(size, indices.dimension(0));
  if (ndims == 2) {
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), coo_row_ind.data(),
                                coo_col_ind.data(), indices.data(), size));
    return Status::OK();
  } else {
    const int batch_size = host_dense_shape(0);
    GpuDeviceArrayOnHost<int> batch_ptr_copy(c, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < batch_size; ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(
        GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
                        config.thread_per_block, shared_memory_size, d.stream(),
                        coo_row_ind.data(), coo_col_ind.data(), indices.data(),
                        batch_ptr_copy.data(), batch_size, size));
    return Status::OK();
  }
}

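// Multiplies every nonzero of a batched CSR matrix by a per-batch scalar:
// c_values[i] = a_values[i] * b_batch_values[b], where b is the batch owning
// nonzero i according to the batch pointers.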
template <typename T>
__global__ void CSRSparseMatrixBatchMulVecKernel3D(
    const T* a_values, const T* b_batch_values, T* c_values,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int total_nnz) {
  // Step 1: Access the batch ptrs and copy to shared memory.
  //         Also copy the per-batch multipliers into shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  T* local_batch_values =
      reinterpret_cast<T*>(local_batch_ptr + batch_size + 1);
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
    if (i < batch_size) {
      local_batch_values[i] = b_batch_values[i];
    }
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, total_nnz) {
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    c_values[i] = ldg(a_values + i) * local_batch_values[b];
  }
}

template <typename T>
Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
                                      const CSRSparseMatrix& a,
                                      typename TTypes<T>::ConstFlat b,
                                      CSRSparseMatrix* c) {
  DCHECK_EQ(a.dims(), 3);
  const int total_nnz = a.total_nnz();
  Tensor c_values_t;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                        TensorShape({total_nnz}), &c_values_t));
  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
      a.row_pointers(), a.col_indices(), c_values_t, c));

  auto a_values = a.values().flat<T>();
  auto c_values = c_values_t.flat<T>();

  auto host_dense_shape = a.dense_shape().vec<int64>();
  auto host_batch_ptr = a.batch_pointers().vec<int>();

  const GPUDevice& d = ctx->eigen_device<GPUDevice>();

  const int batch_size = host_dense_shape(0);
  DCHECK_EQ(b.size(), batch_size);

  GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
  TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
  for (int i = 0; i < batch_size; ++i) {
    batch_ptr_copy.Set(i, host_batch_ptr(i));
  }
  TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
  GpuLaunchConfig config = GetGpuLaunchConfig(total_nnz, d);
  // shared memory stores the batch pointers.
  const size_t shared_memory_size =
      (sizeof(int) * (batch_size + 1)  // local batch_pointers.
       + sizeof(T) * batch_size);      // local copy of b.
  TF_CHECK_OK(GpuLaunchKernel(
      CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
      config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
      b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));

  return Status::OK();
}

#define DEFINE_SPARSE_MUL_VEC_GPU(T)                                        \
  template <>                                                               \
  CSRSparseMatrixBatchMulVec<GPUDevice, T>::CSRSparseMatrixBatchMulVec() {} \
  template <>                                                               \
  Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute(                 \
      OpKernelContext* ctx, const CSRSparseMatrix& a,                       \
      typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c) {                \
    return CSRSparseMatrixBatchMulVecImpl<T>(ctx, a, b, c);                 \
  }

DEFINE_SPARSE_MUL_VEC_GPU(float);
DEFINE_SPARSE_MUL_VEC_GPU(double);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<float>);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<double>);

#undef DEFINE_SPARSE_MUL_VEC_GPU

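// Computes a numerically stable softmax over the values at positions
// [begin, end) of `logits`, writing the result to the same positions of
// `softmax`.  One call handles a single CSR row.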
template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmax(const int begin,
                                                               const int end,
                                                               const T* logits,
                                                               T* softmax) {
  // For each row, calculate the vector:
  //   softmax[row] = exp(shifted_logits[row]) / sum(exp(shifted_logits[row]))
  // where
  //   shifted_logits[row] = logits[row] - max(logits[row])
  // are the logits normalized for stability.
  T row_max = Eigen::NumTraits<T>::lowest();
  for (int r_i = begin; r_i < end; ++r_i) {
    row_max = Eigen::numext::maxi(row_max, ldg(logits + r_i));
  }
  T sum_exp = 0;
  for (int r_i = begin; r_i < end; ++r_i) {
    const T exp_i = Eigen::numext::exp(ldg(logits + r_i) - row_max);
    softmax[r_i] = exp_i;
    sum_exp += exp_i;
  }
  for (int r_i = begin; r_i < end; ++r_i) {
    softmax[r_i] = softmax[r_i] / sum_exp;
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
                                               const int* row_ptr,
                                               const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
                        softmax);
  }
}

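// Block-cooperatively copies `length` ints from a GpuDeviceArray into the
// caller-provided shared-memory buffer and synchronizes the thread block.
// Compiles to a no-op outside of device code.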
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
    GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
  for (int i = threadIdx.x; i < length; i += blockDim.x) {
    local_ptr[i] = cuda_ptr[i];
  }
  __syncthreads();
#endif
}

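// Batched (rank-3) softmax: work item i covers row (i % rows) of batch
// (i / rows).  Row pointers of all batches are stored back to back with
// (rows + 1) entries per batch, so the pointers for (batch, row) live at
// batch * (rows + 1) + row; adding the batch offset turns the within-batch
// CSR offsets into positions in the flat logits/softmax value arrays.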
template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel3D(
    const int size, const int rows, GpuDeviceArrayStruct<int> batch_ptr_s,
    const int* row_ptr, const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
                                batch_size + 1);

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int batch_offset = local_batch_ptr[batch];
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmax(batch_offset + ldg(row_ptr + row_offset),
                        batch_offset + ldg(row_ptr + row_offset + 1), logits,
                        softmax);
  }
}

template <typename T>
Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
                                     const CSRSparseMatrix& logits,
                                     typename TTypes<T>::Vec softmax_values) {
  auto host_dense_shape = logits.dense_shape().vec<int64>();
  auto host_batch_ptr = logits.batch_pointers().vec<int32>();
  auto row_ptr = logits.row_pointers().vec<int32>();
  auto logits_values = logits.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    const int rows = host_dense_shape(0);
    DCHECK_EQ(rows, row_ptr.size() - 1);
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), rows /*size*/, row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, row_ptr.size());
    const int size = rows * batch_size;

    GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < host_batch_ptr.size(); ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
                                config.block_count, config.thread_per_block,
                                shared_memory_size, d.stream(), size, rows,
                                batch_ptr_copy.data(), row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GPU(T)                                             \
  template <>                                                             \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()(                \
      OpKernelContext* ctx, const CSRSparseMatrix& logits,                \
      typename TTypes<T>::Vec softmax_values) {                           \
    return CSRSparseMatrixSoftmaxGPUImpl<T>(ctx, logits, softmax_values); \
  }

DEFINE_SOFTMAX_GPU(float);
DEFINE_SOFTMAX_GPU(double);

#undef DEFINE_SOFTMAX_GPU

template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmaxGrad(
    const int softmax_begin, const int softmax_end, const int* softmax_col_ind,
    const T* softmax, const int grad_softmax_begin, const int grad_softmax_end,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // Iterate from
  //   softmax_col_ind[softmax_begin] to
  //   softmax_col_ind[softmax_end]
  // and from
  //   grad_softmax_col_ind[grad_softmax_begin] to
  //   grad_softmax_col_ind[grad_softmax_end]
  //
  // looking for matching indices.  In the softmax indices only, perform:
  //
  //   gradient = (grad_softmax - sum(grad_softmax * softmax)) * softmax
  //
  // where the sum is along the given row.
  T sum_prod = 0;
  for (int i = softmax_begin, j = grad_softmax_begin;
       i < softmax_end && j < grad_softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    const int grad_softmax_col = ldg(grad_softmax_col_ind + j);
    if (softmax_col == grad_softmax_col) {
      sum_prod += ldg(softmax + i) * ldg(grad_softmax + j);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      ++j;
    } else {
      ++i;
    }
  }

  // Find an upper bound on the column numbers in this row; for use in
  // the special case of an empty grad_softmax row and a non-empty
  // softmax row.
  const int softmax_col_upper_bound =
      (softmax_begin == softmax_end)
          ? -1
          : ldg(softmax_col_ind + softmax_end - 1) + 1;
  for (int i = softmax_begin, j = grad_softmax_begin; i < softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    // We need to keep a large grad_softmax_col value if we're at the
    // end of the grad_softmax row, so we can fill in the remainder of
    // the gradients row (the last if branch in this loop).
    const int grad_softmax_col = (j == grad_softmax_end)
                                     ? softmax_col_upper_bound
                                     : ldg(grad_softmax_col_ind + j);

    if (softmax_col == grad_softmax_col) {
      gradient[i] = (ldg(grad_softmax + j) - sum_prod) * ldg(softmax + i);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      // grad_softmax is nonzero here, but since softmax is zero, the
      // gradient is 0; so we skip it since the sparsity structure
      // already encodes this zero.
      ++j;
    } else {
      // grad_softmax is zero but softmax is not.
      gradient[i] = -sum_prod * ldg(softmax + i);
      ++i;
    }
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel2D(
    const int rows, const int* softmax_row_ptr, const int* softmax_col_ind,
    const T* softmax, const int* grad_softmax_row_ptr,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmaxGrad(
        ldg(softmax_row_ptr + row) /*softmax_begin*/,
        ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
        softmax, ldg(grad_softmax_row_ptr + row) /*grad_softmax_begin*/,
        ldg(grad_softmax_row_ptr + row + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel3D(
    const int size, const int rows,
    GpuDeviceArrayStruct<int> softmax_and_grad_batch_ptr_s,
    const int* softmax_row_ptr, const int* softmax_col_ind, const T* softmax,
    const int* grad_softmax_row_ptr, const int* grad_softmax_col_ind,
    const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf

  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(softmax_and_grad_batch_ptr_s),
                                local_batch_ptr, 2 * (batch_size + 1));

#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
    const int grad_softmax_batch_offset = GRAD_SOFTMAX_BATCH_PTR(batch);
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmaxGrad(
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset) /*softmax_begin*/,
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset + 1) /*softmax_end*/,
        softmax_col_ind, softmax,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset) /*grad_softmax_begin*/,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }

#undef SOFTMAX_BATCH_PTR
#undef GRAD_SOFTMAX_BATCH_PTR
}

template <typename T>
Status CSRSparseMatrixSoftmaxGradGPUImpl(
    OpKernelContext* ctx, const CSRSparseMatrix& softmax,
    const CSRSparseMatrix& grad_softmax,
    typename TTypes<T>::Vec gradient_values) {
  auto host_dense_shape = softmax.dense_shape().vec<int64>();
  auto softmax_host_batch_ptr = softmax.batch_pointers().vec<int32>();
  auto softmax_row_ptr = softmax.row_pointers().vec<int32>();
  auto softmax_col_ind = softmax.col_indices().vec<int32>();
  auto softmax_values = softmax.values().vec<T>();
  auto grad_softmax_host_batch_ptr = grad_softmax.batch_pointers().vec<int32>();
  auto grad_softmax_row_ptr = grad_softmax.row_pointers().vec<int32>();
  auto grad_softmax_col_ind = grad_softmax.col_indices().vec<int32>();
  auto grad_softmax_values = grad_softmax.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const int rows = host_dense_shape(0);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    DCHECK_EQ(rows + 1, softmax_row_ptr.size());
    DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
        config.thread_per_block, 0, d.stream(), rows /*size*/,
        softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ(batch_size, grad_softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, softmax_row_ptr.size());
    DCHECK_EQ((rows + 1) * batch_size, grad_softmax_row_ptr.size());
    const int size = rows * batch_size;
    // The length of softmax_and_grad_batch_ptr_copy is 2 * (batch_size + 1).
    // The first (batch_size + 1) entries contain softmax_batch_ptr and
    // the second (batch_size + 1) entries contain grad_softmax_batch_ptr.
    GpuDeviceArrayOnHost<int> softmax_and_grad_batch_ptr_copy(
        ctx, 2 * softmax_host_batch_ptr.size());
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Init());
    for (int i = 0; i < softmax_host_batch_ptr.size(); ++i) {
      softmax_and_grad_batch_ptr_copy.Set(i, softmax_host_batch_ptr(i));
      softmax_and_grad_batch_ptr_copy.Set(batch_size + 1 + i,
                                          grad_softmax_host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores two copies of batch pointers: one for the
    // softmax CSR matrix, one for the grad_softmax CSR matrix.
    const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
        config.thread_per_block, shared_memory_size, d.stream(), size, rows,
        softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
        softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GRAD_GPU(T)                                          \
  template <>                                                               \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()(              \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax,                 \
      const CSRSparseMatrix& grad_softmax,                                  \
      typename TTypes<T>::Vec gradient_values) {                            \
    return CSRSparseMatrixSoftmaxGradGPUImpl<T>(ctx, softmax, grad_softmax, \
                                                gradient_values);           \
  }

DEFINE_SOFTMAX_GRAD_GPU(float);
DEFINE_SOFTMAX_GRAD_GPU(double);

#undef DEFINE_SOFTMAX_GRAD_GPU

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM