/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/cub/device/device_histogram.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/gpus/cuda/include/cusparse.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
struct StridedDataReader {
  StridedDataReader(const int64* begin, int stride)
      : begin_(begin), stride_(stride) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return static_cast<int>(ldg(begin_ + idx * stride_));
  }

  const int64* begin_;
  const int stride_;
};
}  // namespace

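// Computes the per-batch nonzero counts of a batched SparseTensor by running
// a histogram over the batch column (column 0) of `indices`;
// StridedDataReader presents that column to gpuprim as a flat int sequence.
// As an illustrative example, indices rows {0, r, c}, {0, r, c}, {2, r, c}
// with nnz_per_batch.size() == 3 histogram the batch ids {0, 0, 2} into
// nnz_per_batch = {2, 0, 1}.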
template <>
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
    TTypes<int32>::Vec nnz_per_batch) {
  const auto& cu_stream = GetGpuStream(c);

  const int total_nnz = indices.dimension(0);
  const int size = nnz_per_batch.size();

  DCHECK_EQ(indices.rank(), 2);
  DCHECK_EQ(indices.dimension(1), 3);  // batch, row, col

  const int rank = indices.dimension(1);
  gpuprim::CountingInputIterator<int> row_counter(0);
  gpuprim::TransformInputIterator<int, StridedDataReader,
                                  gpuprim::CountingInputIterator<int>>
      indices_first_column(row_counter,
                           StridedDataReader(indices.data(), rank));

  std::size_t temp_storage_bytes = 0;

  DCHECK_NE(indices.data(), nullptr);
  DCHECK_NE(nnz_per_batch.data(), nullptr);

  auto first_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ nullptr,
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (first_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to calculate temp_storage_bytes, status: ",
        GpuGetErrorString(first_success));
  }

  Tensor temp_storage;
  TF_RETURN_IF_ERROR(c->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
  auto second_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ temp_storage.flat<int8>().data(),
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (second_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to count nnz entries per batch. temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
  }

  return Status::OK();
}

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
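// Expands a CSR row-pointer array into per-nonzero COO row indices using
// GpuSparse::Csr2coo. As an illustrative example, csr_row_ptr = {0, 2, 2, 5}
// describes a 3-row matrix with 5 nonzeros and expands to
// coo_row_ind = {0, 0, 2, 2, 2}.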
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  GpuSparse gpu_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}

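// Copies the (row, col) columns of a SparseTensor index matrix into separate
// COO row and column arrays. With stride == 3 the indices are
// (batch, row, col) and offset == 1 skips the leading batch column; with
// stride == 2 they are already (row, col) and offset == 0.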
template <int stride>
__global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
                                              int* coo_rows_out,
                                              int* coo_cols_out, int size) {
  const int offset = (stride == 3) ? 1 : 0;
  GPU_1D_KERNEL_LOOP(i, size) {
    coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
    coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
  }
}

template <>
void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
    const GPUDevice& d, TTypes<int64>::ConstVec host_dense_shape,
    TTypes<int64>::ConstMatrix indices, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::Vec coo_col_ind) {
  const int stride = host_dense_shape.size();
  DCHECK(stride == 2 || stride == 3);
  DCHECK_EQ(stride, indices.dimension(1));
  const int size = coo_row_ind.dimension(0);
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  if (stride == 2) {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  }
}

__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
                                                const int* coo_cols,
                                                int64* indices_out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) {
    indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
    indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
  }
}

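// Returns the index `mid` of the half-open interval
// [range[mid], range[mid + 1]) containing `x`, assuming `range` holds n + 1
// nondecreasing offsets and range[0] <= x < range[n]. As an illustrative
// example, with range = {0, 2, 5, 9} and n = 3, BinarySearchRange(range, 3, 4)
// returns 1 since range[1] <= 4 < range[2]. The batched kernels below use
// this to map a flattened nonzero index to its batch.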
__device__ inline int BinarySearchRange(int* range, int n, int x) {
  int left = 0;
  int right = n - 1;
  while (left < right) {
    int mid = left + (right - left) / 2;
    if (x < range[mid])
      right = mid - 1;
    else if (range[mid + 1] <= x)
      left = mid + 1;
    else
      return mid;  // range[mid] <= x < range[mid + 1].
  }
  return left;
}

__global__ void COOMatrixToSparseTensorKernel3D(
    const int* coo_rows, const int* coo_cols, int64* indices_out,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int size) {
  // Step 1: access the batch ptrs and copy to shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, size) {
    // TODO(ebrevdo): Consider special casing batch_size <= 3,
    // alternatively doing linear instead of binary search. Requires
    // some benchmarks.
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    indices_out[i * 3] = static_cast<int64>(b);
    indices_out[i * 3 + 1] = static_cast<int64>(ldg(coo_rows + i));
    indices_out[i * 3 + 2] = static_cast<int64>(ldg(coo_cols + i));
  }
}

template <>
Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64>::ConstVec host_dense_shape,
    TTypes<int>::ConstVec host_batch_ptr, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::ConstVec coo_col_ind, TTypes<int64>::Matrix indices) {
  const int ndims = indices.dimension(1);
  DCHECK(ndims == 2 || ndims == 3);
  DCHECK_EQ(ndims, host_dense_shape.size());
  DCHECK_NE(coo_row_ind.data(), nullptr);
  DCHECK_NE(coo_col_ind.data(), nullptr);
  DCHECK_NE(indices.data(), nullptr);
  const GPUDevice& d = c->eigen_device<GPUDevice>();
  const int size = coo_row_ind.size();
  DCHECK_EQ(size, coo_col_ind.size());
  DCHECK_EQ(size, indices.dimension(0));
  if (ndims == 2) {
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), coo_row_ind.data(),
                                coo_col_ind.data(), indices.data(), size));
    return Status::OK();
  } else {
    const int batch_size = host_dense_shape(0);
    GpuDeviceArrayOnHost<int> batch_ptr_copy(c, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < batch_size; ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(
        GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
                        config.thread_per_block, shared_memory_size,
                        d.stream(), coo_row_ind.data(), coo_col_ind.data(),
                        indices.data(), batch_ptr_copy.data(), batch_size,
                        size));
    return Status::OK();
  }
}

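// Scales each stored value of a batched CSR matrix by a per-batch scalar:
// c_values[i] = a_values[i] * b_batch_values[batch(i)], where batch(i) is
// found by binary searching the batch pointers. As an illustrative example,
// with batch_ptr = {0, 2, 5} and b_batch_values = {10, 100},
// a_values = {1, 2, 3, 4, 5} becomes c_values = {10, 20, 300, 400, 500}.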
template <typename T>
__global__ void CSRSparseMatrixBatchMulVecKernel3D(
    const T* a_values, const T* b_batch_values, T* c_values,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int total_nnz) {
  // Step 1: Access the batch ptrs and copy to shared memory.
  // Also copy the per-batch multipliers into shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  T* local_batch_values =
      reinterpret_cast<T*>(local_batch_ptr + batch_size + 1);
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
    if (i < batch_size) {
      local_batch_values[i] = b_batch_values[i];
    }
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, total_nnz) {
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    c_values[i] = ldg(a_values + i) * local_batch_values[b];
  }
}

template <typename T>
Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
                                      const CSRSparseMatrix& a,
                                      typename TTypes<T>::ConstFlat b,
                                      CSRSparseMatrix* c) {
  DCHECK_EQ(a.dims(), 3);
  const int total_nnz = a.total_nnz();
  Tensor c_values_t;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DataTypeToEnum<T>::value, TensorShape({total_nnz}), &c_values_t));
  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
      a.row_pointers(), a.col_indices(), c_values_t, c));

  auto a_values = a.values().flat<T>();
  auto c_values = c_values_t.flat<T>();

  auto host_dense_shape = a.dense_shape().vec<int64>();
  auto host_batch_ptr = a.batch_pointers().vec<int>();

  const GPUDevice& d = ctx->eigen_device<GPUDevice>();

  const int batch_size = host_dense_shape(0);
  DCHECK_EQ(b.size(), batch_size);

  GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
  TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
  for (int i = 0; i < batch_size; ++i) {
    batch_ptr_copy.Set(i, host_batch_ptr(i));
  }
  TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
  GpuLaunchConfig config = GetGpuLaunchConfig(total_nnz, d);
  // shared memory stores the batch pointers.
  const size_t shared_memory_size =
      (sizeof(int) * (batch_size + 1)  // local batch_pointers.
       + sizeof(T) * batch_size);      // local copy of b.
  TF_CHECK_OK(GpuLaunchKernel(
      CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
      config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
      b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));

  return Status::OK();
}

#define DEFINE_SPARSE_MUL_VEC_GPU(T)                                        \
  template <>                                                               \
  CSRSparseMatrixBatchMulVec<GPUDevice, T>::CSRSparseMatrixBatchMulVec() {} \
  template <>                                                               \
  Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute(                 \
      OpKernelContext* ctx, const CSRSparseMatrix& a,                       \
      typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c) {                \
    return CSRSparseMatrixBatchMulVecImpl<T>(ctx, a, b, c);                 \
  }

DEFINE_SPARSE_MUL_VEC_GPU(float);
DEFINE_SPARSE_MUL_VEC_GPU(double);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<float>);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<double>);

#undef DEFINE_SPARSE_MUL_VEC_GPU

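// Computes a numerically stable softmax over the stored entries of one CSR
// row, whose values live in the half-open range [begin, end). Only stored
// (structurally nonzero) entries participate; implicit zeros are ignored.
// As an illustrative example, a row with logits {1, 2} produces
// exp({-1, 0}) / (exp(-1) + exp(0)), i.e. approximately {0.269, 0.731}.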
template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmax(const int begin,
                                                               const int end,
                                                               const T* logits,
                                                               T* softmax) {
  // For each row, calculate the vector:
  //   softmax[row] = exp(shifted_logits[row]) / sum(exp(shifted_logits[row]))
  // where
  //   shifted_logits[row] = logits[row] - max(logits[row])
  // are the logits normalized for stability.
  T row_max = Eigen::NumTraits<T>::lowest();
  for (int r_i = begin; r_i < end; ++r_i) {
    row_max = Eigen::numext::maxi(row_max, ldg(logits + r_i));
  }
  T sum_exp = 0;
  for (int r_i = begin; r_i < end; ++r_i) {
    const T exp_i = Eigen::numext::exp(ldg(logits + r_i) - row_max);
    softmax[r_i] = exp_i;
    sum_exp += exp_i;
  }
  for (int r_i = begin; r_i < end; ++r_i) {
    softmax[r_i] = softmax[r_i] / sum_exp;
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
                                               const int* row_ptr,
                                               const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
                        softmax);
  }
}

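// Cooperatively copies `length` ints from a GpuDeviceArray into a
// caller-provided buffer (the batched kernels below pass a __shared__ array).
// Every thread of the block must reach this call, since it ends with
// __syncthreads(); the body is compiled for device code only.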
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
    GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
  for (int i = threadIdx.x; i < length; i += blockDim.x) {
    local_ptr[i] = cuda_ptr[i];
  }
  __syncthreads();
#endif
}

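// Batched (3D) softmax. Row pointers for all batches are stored back to back,
// (rows + 1) entries per batch, so the pointer for (batch, row) lives at
// row_ptr[batch * (rows + 1) + row], and local_batch_ptr[batch] gives the
// offset of that batch's first value in the flat logits/softmax arrays. As an
// illustrative example, with rows == 2, the row (batch 1, row 0) reads
// row_ptr[3] and row_ptr[4] and shifts both by local_batch_ptr[1].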
template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel3D(
    const int size, const int rows, GpuDeviceArrayStruct<int> batch_ptr_s,
    const int* row_ptr, const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
                                batch_size + 1);

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int batch_offset = local_batch_ptr[batch];
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmax(batch_offset + ldg(row_ptr + row_offset),
                        batch_offset + ldg(row_ptr + row_offset + 1), logits,
                        softmax);
  }
}

template <typename T>
Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
                                     const CSRSparseMatrix& logits,
                                     typename TTypes<T>::Vec softmax_values) {
  auto host_dense_shape = logits.dense_shape().vec<int64>();
  auto host_batch_ptr = logits.batch_pointers().vec<int32>();
  auto row_ptr = logits.row_pointers().vec<int32>();
  auto logits_values = logits.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    const int rows = host_dense_shape(0);
    DCHECK_EQ(rows, row_ptr.size() - 1);
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), rows /*size*/, row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, row_ptr.size());
    const int size = rows * batch_size;

    GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < host_batch_ptr.size(); ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
                                config.block_count, config.thread_per_block,
                                shared_memory_size, d.stream(), size, rows,
                                batch_ptr_copy.data(), row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GPU(T)                                             \
  template <>                                                             \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()(                \
      OpKernelContext* ctx, const CSRSparseMatrix& logits,                \
      typename TTypes<T>::Vec softmax_values) {                           \
    return CSRSparseMatrixSoftmaxGPUImpl<T>(ctx, logits, softmax_values); \
  }

DEFINE_SOFTMAX_GPU(float);
DEFINE_SOFTMAX_GPU(double);

#undef DEFINE_SOFTMAX_GPU

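// Derivation sketch: for y = softmax(x) restricted to one row,
//   dL/dx_i = y_i * (dL/dy_i - sum_j(dL/dy_j * y_j)),
// so the first loop below accumulates sum_prod = sum(grad_softmax * softmax)
// over the columns present in both sparsity patterns, and the second loop
// writes gradient[i] = (grad_softmax - sum_prod) * softmax[i] where the
// columns match, or -sum_prod * softmax[i] where grad_softmax has no
// matching column.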
template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmaxGrad(
    const int softmax_begin, const int softmax_end, const int* softmax_col_ind,
    const T* softmax, const int grad_softmax_begin, const int grad_softmax_end,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // Iterate from
  //   softmax_col_ind[softmax_begin] to
  //   softmax_col_ind[softmax_end]
  // and from
  //   grad_softmax_col_ind[grad_softmax_begin] to
  //   grad_softmax_col_ind[grad_softmax_end]
  //
  // looking for matching indices. In the softmax indices only, perform:
  //
  //   gradient = (grad_softmax - sum(grad_softmax * softmax)) * softmax
  //
  // where the sum is along the given row.
  T sum_prod = 0;
  for (int i = softmax_begin, j = grad_softmax_begin;
       i < softmax_end && j < grad_softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    const int grad_softmax_col = ldg(grad_softmax_col_ind + j);
    if (softmax_col == grad_softmax_col) {
      sum_prod += ldg(softmax + i) * ldg(grad_softmax + j);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      ++j;
    } else {
      ++i;
    }
  }

  // Find an upper bound on the column numbers in this row; for use in
  // the special case of an empty grad_softmax row and a non-empty
  // softmax row.
  const int softmax_col_upper_bound =
      (softmax_begin == softmax_end)
          ? -1
          : ldg(softmax_col_ind + softmax_end - 1) + 1;
  for (int i = softmax_begin, j = grad_softmax_begin; i < softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    // We need to keep a large grad_softmax_col value if we're at the
    // end of the grad_softmax row, so we can fill in the remainder of
    // the gradients row (the last if branch in this loop).
    const int grad_softmax_col = (j == grad_softmax_end)
                                     ? softmax_col_upper_bound
                                     : ldg(grad_softmax_col_ind + j);

    if (softmax_col == grad_softmax_col) {
      gradient[i] = (ldg(grad_softmax + j) - sum_prod) * ldg(softmax + i);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      // grad_softmax is nonzero here, but since softmax is zero, the
      // gradient is 0; so we skip it since the sparsity structure
      // already encodes this zero.
      ++j;
    } else {
      // grad_softmax is zero but softmax is not.
      gradient[i] = -sum_prod * ldg(softmax + i);
      ++i;
    }
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel2D(
    const int rows, const int* softmax_row_ptr, const int* softmax_col_ind,
    const T* softmax, const int* grad_softmax_row_ptr,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmaxGrad(
        ldg(softmax_row_ptr + row) /*softmax_begin*/,
        ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
        softmax, ldg(grad_softmax_row_ptr + row) /*grad_softmax_begin*/,
        ldg(grad_softmax_row_ptr + row + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }
}

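// Batched (3D) softmax gradient. Shared memory holds the two batch-pointer
// arrays back to back: entries [0, batch_size] are the softmax batch pointers
// and entries [batch_size + 1, 2 * batch_size + 1] are the grad_softmax batch
// pointers, matching how softmax_and_grad_batch_ptr_copy is filled on the
// host below.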
template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel3D(
    const int size, const int rows,
    GpuDeviceArrayStruct<int> softmax_and_grad_batch_ptr_s,
    const int* softmax_row_ptr, const int* softmax_col_ind, const T* softmax,
    const int* grad_softmax_row_ptr, const int* grad_softmax_col_ind,
    const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf

  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(softmax_and_grad_batch_ptr_s),
                                local_batch_ptr, 2 * (batch_size + 1));

#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i]
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i]

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
    const int grad_softmax_batch_offset = GRAD_SOFTMAX_BATCH_PTR(batch);
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmaxGrad(
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset) /*softmax_begin*/,
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset + 1) /*softmax_end*/,
        softmax_col_ind, softmax,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset) /*grad_softmax_begin*/,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }

#undef SOFTMAX_BATCH_PTR
#undef GRAD_SOFTMAX_BATCH_PTR
}

template <typename T>
Status CSRSparseMatrixSoftmaxGradGPUImpl(
    OpKernelContext* ctx, const CSRSparseMatrix& softmax,
    const CSRSparseMatrix& grad_softmax,
    typename TTypes<T>::Vec gradient_values) {
  auto host_dense_shape = softmax.dense_shape().vec<int64>();
  auto softmax_host_batch_ptr = softmax.batch_pointers().vec<int32>();
  auto softmax_row_ptr = softmax.row_pointers().vec<int32>();
  auto softmax_col_ind = softmax.col_indices().vec<int32>();
  auto softmax_values = softmax.values().vec<T>();
  auto grad_softmax_host_batch_ptr = grad_softmax.batch_pointers().vec<int32>();
  auto grad_softmax_row_ptr = grad_softmax.row_pointers().vec<int32>();
  auto grad_softmax_col_ind = grad_softmax.col_indices().vec<int32>();
  auto grad_softmax_values = grad_softmax.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const int rows = host_dense_shape(0);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    DCHECK_EQ(rows + 1, softmax_row_ptr.size());
    DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
        config.thread_per_block, 0, d.stream(), rows /*size*/,
        softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ(batch_size, grad_softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, softmax_row_ptr.size());
    DCHECK_EQ((rows + 1) * batch_size, grad_softmax_row_ptr.size());
    const int size = rows * batch_size;
    // The length of softmax_and_grad_batch_ptr_copy is 2 * (batch_size + 1).
    // The first (batch_size + 1) entries contain softmax_batch_ptr and
    // the second (batch_size + 1) entries contain grad_softmax_batch_ptr.
    GpuDeviceArrayOnHost<int> softmax_and_grad_batch_ptr_copy(
        ctx, 2 * softmax_host_batch_ptr.size());
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Init());
    for (int i = 0; i < softmax_host_batch_ptr.size(); ++i) {
      softmax_and_grad_batch_ptr_copy.Set(i, softmax_host_batch_ptr(i));
      softmax_and_grad_batch_ptr_copy.Set(batch_size + 1 + i,
                                          grad_softmax_host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores two copies of batch pointers: one for the
    // softmax CSR matrix, one for the grad_softmax CSR matrix.
    const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
        config.thread_per_block, shared_memory_size, d.stream(), size, rows,
        softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
        softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GRAD_GPU(T)                                          \
  template <>                                                               \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()(              \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax,                 \
      const CSRSparseMatrix& grad_softmax,                                  \
      typename TTypes<T>::Vec gradient_values) {                            \
    return CSRSparseMatrixSoftmaxGradGPUImpl<T>(ctx, softmax, grad_softmax, \
                                                gradient_values);           \
  }

DEFINE_SOFTMAX_GRAD_GPU(float);
DEFINE_SOFTMAX_GRAD_GPU(double);

#undef DEFINE_SOFTMAX_GRAD_GPU

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM