1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA
17
18 #define EIGEN_USE_GPU
19
20 #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
21
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/kernels/bounds_check.h"
24 #include "tensorflow/core/util/cuda_kernel_helper.h"
25
26 namespace tensorflow {
27
28 typedef Eigen::GpuDevice GPUDevice;
29
// Scatter kernel for out = op(A) * op(B) with a sparse A and a dense B.
//
// A is given in coordinate form: a_indices is an nnz x 2 row-major array of
// index pairs and a_values holds the matching nnz values. With ADJ_A the pair
// is read as (k, i), otherwise as (i, k). b is a dense b_rows x b_cols
// row-major matrix; with ADJ_B it is read transposed. out is an m x p
// row-major matrix.
//
// CUDA_1D_KERNEL_LOOP covers the flat work range [0, nnz * p). Work item
// `index` handles nonzero a_ix = index / p and output column j = index % p,
// adding a_value * b_value into out[i, j] with CudaAtomicAdd. Because
// contributions are accumulated, the caller must zero-fill `out` first.
//
// Invalid indices cannot be reported from device code:
// - entries whose row index i lies outside [0, m) are dropped;
// - entries whose inner index k lies outside [0, n) add NaN into out[i, j],
//   so the error is visible in the result.
//
// The flat work index and the sizes are 32-bit ints. Presumably the calling
// op guarantees that nnz * p and the matrix sizes fit in int — TODO confirm.
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
                                              int b_cols, int p,
                                              const Tindices* a_indices,
                                              const T* a_values, const T* b,
                                              T* out) {
  // out_{ij} = sum_k {a_ik b_kj}
  // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
  const int n = (ADJ_B) ? b_cols : b_rows;
  CUDA_1D_KERNEL_LOOP(index, nnz * p) {
    const int a_ix = index / p;
    const int j = index % p;
    // Widen before doubling: 2 * a_ix overflows int once nnz exceeds 2^30,
    // and that product is not covered by the nnz * p size limit.
    const Tindices* a_entry = a_indices + 2 * static_cast<int64>(a_ix);
    // Keep i and k in their native index type until they are bounds-checked.
    // Narrowing an out-of-range 64-bit index to int first could wrap it into
    // the valid range and silently accumulate into the wrong element.
    const Tindices i = ldg(a_entry + ((ADJ_A) ? 1 : 0));
    const Tindices k = ldg(a_entry + ((ADJ_A) ? 0 : 1));
    if (!FastBoundsCheck(i, m)) {
      continue;  // Nowhere to signal an error :(
    }
    // out[i, j]
    T* out_location = out + i * p + j;
    if (!FastBoundsCheck(k, n)) {
      // Poison the output cell so the bad index does not go unnoticed.
      CudaAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
      continue;
    }

    // a_value == (ADJ_A) ? a[k, i] : a[i, k]
    const T a_value = ldg(a_values + a_ix);

    // b_value == (ADJ_B) ? b[j, k] : b[k, j]
    const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
    CudaAtomicAdd(out_location, a_value * b_value);
  }
}
62
namespace functor {

// GPU specialization of SparseTensorDenseMatMulFunctor.
//
// Compute() zero-fills `out`, then launches SparseTensorDenseMatMulKernel on
// the device's stream. The kernel uses one work item per (nonzero of A,
// output column) pair and accumulates into `out` with atomic adds.
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
  // Computes out = op(A) * op(B).
  //
  // A is given in coordinate form by (a_indices, a_values), and op()
  // transposes its argument when ADJ_A / ADJ_B is set.
  //
  // Sizes are narrowed to int. Presumably the calling op rejects inputs whose
  // sizes, in particular p * nnz, do not fit in 32 bits — TODO confirm.
  //
  // Always returns Status::OK(). The launch is asynchronous and its status is
  // not checked here.
  static EIGEN_ALWAYS_INLINE Status
  Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
          typename TTypes<Tindices>::ConstMatrix a_indices,
          typename TTypes<T>::ConstVec a_values,
          typename TTypes<T>::ConstMatrix b) {
    // The kernel accumulates with atomic adds, so the output must start at 0.
    out.device(d) = out.constant(T(0));
    int nnz = a_values.size();
    // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
    int m = out.dimension(0);
    int p = out.dimension(1);
    int b_rows = b.dimension(0);
    int b_cols = b.dimension(1);

    // With no (nonzero, column) pairs, the zero-filled output is already the
    // answer. A launch config sized for zero work would presumably have zero
    // blocks, which is an invalid CUDA launch configuration, so skip it.
    if (nnz == 0 || p == 0) {
      return Status::OK();
    }

    // One virtual thread per (nonzero of A, output column) pair.
    CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d);

    SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
            b.data(), out.data());

    return Status::OK();
  }
};

}  // namespace functor
94
// Explicit instantiation of the GPU functor for a single combination of
// value type, index type and adjoint flags.
#define DEFINE_SPARSE_DENSE_MATMUL_GPU(T, Tindices, ADJ_A, ADJ_B) \
  template struct functor::SparseTensorDenseMatMulFunctor<        \
      GPUDevice, T, Tindices, ADJ_A, ADJ_B>;

// Instantiates all four (ADJ_A, ADJ_B) flag combinations for one
// (T, Tindices) pair.
#define DEFINE(T, Tindices)                                   \
  DEFINE_SPARSE_DENSE_MATMUL_GPU(T, Tindices, false, false)   \
  DEFINE_SPARSE_DENSE_MATMUL_GPU(T, Tindices, false, true)    \
  DEFINE_SPARSE_DENSE_MATMUL_GPU(T, Tindices, true, false)    \
  DEFINE_SPARSE_DENSE_MATMUL_GPU(T, Tindices, true, true)

DEFINE(float, int32);
DEFINE(float, int64);
#undef DEFINE
#undef DEFINE_SPARSE_DENSE_MATMUL_GPU
108
109 } // end namespace tensorflow
110
111 #endif // GOOGLE_CUDA
112