/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the kernel for the SparseMatrixTranspose op (class
// CSRTransposeOp), which transposes the two innermost dimensions of a
// CSRSparseMatrix object stored in a DT_VARIANT.
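//
// For illustration, a hypothetical 2x3 example (not drawn from this file):
// the matrix
//   [[1, 0, 2],
//    [0, 3, 0]]
// has CSR components row_ptr = [0, 2, 3], col_ind = [0, 2, 1],
// values = [1, 2, 3]; its transpose (3x2) has row_ptr = [0, 1, 2, 3],
// col_ind = [0, 1, 0], values = [1, 3, 2].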

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T>
Status ValidateTransposeInputs(const ConstCSRComponent<T>& input,
                               const CSRComponent<T>& output) {
  const int rank = input.dense_shape_host.size();
  const int64 nnz = input.col_ind.size();
  const int num_rows = input.row_ptr.size() - 1;
  const int num_cols = input.dense_shape_host(rank - 1);

  if (nnz != input.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the input values size. Got ", nnz, " vs. ",
        input.values.size());
  }
  if (num_cols + 1 != output.row_ptr.size()) {
    return errors::InvalidArgument(
        "Input num_cols should be equal to output num_rows. Got ", num_cols,
        " vs. ", output.row_ptr.size() - 1);
  }
  if (rank != output.dense_shape_host.size()) {
    return errors::InvalidArgument(
        "Input rank should be equal to the output rank. Got ", rank, " vs. ",
        output.dense_shape_host.size());
  }
  if (num_rows != output.dense_shape_host(rank - 1)) {
    return errors::InvalidArgument(
        "Input num_rows should be equal to the output num_cols. Got ",
        num_rows, " vs. ", output.dense_shape_host(rank - 1));
  }
  if (nnz != output.col_ind.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output col_ind size. Got ", nnz, " vs. ",
        output.col_ind.size());
  }
  if (nnz != output.values.size()) {
    return errors::InvalidArgument(
        "Input nnz should equal the output values size. Got ", nnz, " vs. ",
        output.values.size());
  }
  return Status::OK();
}

}  // namespace

template <typename Device, typename T>
class CSRTransposeOp : public OpKernel {
 public:
  explicit CSRTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("conjugate", &conjugate_));
  }

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* input_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
    OP_REQUIRES(
        ctx, input_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of input is not equal to 'type': ",
                                DataTypeString(input_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // Compute the transpose and wrap it in a DT_VARIANT scalar output.
    functor::CSRSparseMatrixTranspose<Device, T> transpose;
    CSRSparseMatrix output_matrix;
    OP_REQUIRES_OK(ctx,
                   transpose(ctx, conjugate_, *input_matrix, &output_matrix));
    Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    output_t.scalar<Variant>()() = std::move(output_matrix);
    ctx->set_output(0, output_t);
  }

 private:
  bool conjugate_;
};
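
// A minimal usage sketch (hypothetical, via the generated raw-op binding;
// the exact Python surface may vary across TF versions):
//
//   # csr is a DT_VARIANT scalar holding a float32 CSRSparseMatrix.
//   # csr_t = tf.raw_ops.SparseMatrixTranspose(input=csr, conjugate=False,
//   #                                          type=tf.float32)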

#define REGISTER_TRANSPOSE(DEV, T)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixTranspose")   \
                              .Device(DEVICE_##DEV)       \
                              .TypeConstraint<T>("type"), \
                          CSRTransposeOp<DEV##Device, T>);

REGISTER_TRANSPOSE(CPU, float)
REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128)
#endif  // GOOGLE_CUDA

#undef REGISTER_TRANSPOSE

namespace functor {

template <typename Device, typename T>
Status CSRSparseMatrixTranspose<Device, T>::operator()(
    OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix,
    CSRSparseMatrix* output_matrix) {
  const int rank = input_matrix.dims();
  Tensor output_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
  const Tensor& input_dense_shape_t = input_matrix.dense_shape();
  auto input_dense_shape = input_dense_shape_t.vec<int64>();
  auto output_dense_shape = output_dense_shape_t.vec<int64>();
  const int64 batch_size = input_matrix.batch_size();
  if (rank == 3) {
    output_dense_shape(0) = batch_size;
  }
  output_dense_shape(rank - 2) = input_dense_shape(rank - 1);
  output_dense_shape(rank - 1) = input_dense_shape(rank - 2);
  const int64 output_rows = output_dense_shape(rank - 2);

  // nnzs per batch do not change with matrix transposition.
  Tensor batch_ptr_t = input_matrix.batch_pointers();
  const int total_nnz = input_matrix.total_nnz();

  Tensor output_row_ptr_t;
  Tensor output_col_ind_t;
  Tensor output_values_t;

  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({batch_size * (output_rows + 1)}),
      &output_row_ptr_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                        &output_col_ind_t));
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DataTypeToEnum<T>::value, TensorShape({total_nnz}), &output_values_t));

  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, output_dense_shape_t, batch_ptr_t,
      output_row_ptr_t, output_col_ind_t, output_values_t, output_matrix));

  // Set the output row pointers to zero, in case we hit any empty
  // input batches.
  functor::SetZeroFunctor<Device, int32> set_zero;
  const Device& d = ctx->eigen_device<Device>();
  set_zero(d, output_row_ptr_t.flat<int32>());

  functor::CSRSparseMatrixTransposeComponent<Device, T> transpose_component;
  for (int i = 0; i < batch_size; ++i) {
    if (output_matrix->nnz(i) == 0) {
      continue;
    }
    ConstCSRComponent<T> input_comp{
        input_matrix.row_pointers_vec(i), input_matrix.col_indices_vec(i),
        input_matrix.values_vec<T>(i), input_dense_shape};
    CSRComponent<T> output_comp{
        output_matrix->row_pointers_vec(i), output_matrix->col_indices_vec(i),
        output_matrix->values_vec<T>(i), output_dense_shape};

    TF_RETURN_IF_ERROR(transpose_component(ctx, input_comp, &output_comp));
  }
  if (conjugate) {
    // conjugate all values with a single kernel launch.
    maybe_conj_inplace<Device, T>::run(d, &output_values_t);
  }

  return Status::OK();
}

// CPU kernel for transposing a single component of a CSR SparseMatrix.
template <typename T>
struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
  using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;

  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& input,
                    CSRComponent<T>* output) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(input, *output));

    const int rank = input.dense_shape_host.size();
    const int num_rows = input.row_ptr.size() - 1;
    const int num_cols = input.dense_shape_host(rank - 1);
    const int64 nnz = input.col_ind.size();

    // Compute the column counts, whose prefix sums make up the output row
    // pointers.
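    //
    // As a concrete (hypothetical) illustration: with num_cols = 3 and
    // input col_ind = [0, 2, 1], the counting pass below leaves
    // output->row_ptr = [0, 1, 1, 1] (the count for column c is stored at
    // index c + 1), and the prefix sum turns that into [0, 1, 2, 3].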
    for (int64 i = 0; i < nnz; ++i) {
      output->row_ptr(input.col_ind(i) + 1) += 1;
    }
    std::partial_sum(output->row_ptr.data(),
                     output->row_ptr.data() + num_cols + 1,
                     output->row_ptr.data());

    // Iterate through each row of the input, and place each non-zero element
    // into the target output row (based on the current column count).
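    //
    // Continuing the hypothetical example above (input row_ptr = [0, 2, 3],
    // values = [1, 2, 3]): row 0 scatters value 1 to offset
    // row_ptr(0) + 0 = 0 and value 2 to offset row_ptr(2) + 0 = 2; row 1
    // scatters value 3 to offset row_ptr(1) + 0 = 1, yielding
    // col_ind = [0, 1, 0] and values = [1, 3, 2].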
    std::vector<int> current_col_count(num_cols);
    for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
      const int64 row_begin = input.row_ptr(row_idx);
      const int64 row_end = input.row_ptr(row_idx + 1);
      for (int64 i = row_begin; i < row_end; ++i) {
        const int col_idx = input.col_ind(i);
        const int64 offset =
            output->row_ptr(col_idx) + current_col_count[col_idx];
        output->col_ind(offset) = row_idx;
        output->values(offset) = input.values(i);
        current_col_count[col_idx] += 1;
      }
    }
    return Status::OK();
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y) {
    TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
    GpuSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    const gpusparseAction_t copyValues = GPUSPARSE(ACTION_NUMERIC);
    const int rank = x.dense_shape_host.size();
    const int m = x.row_ptr.size() - 1;
    const int n = x.dense_shape_host(rank - 1);
    const int nnz = x.col_ind.size();
    DCHECK_EQ(nnz, x.values.size());
    DCHECK_EQ(n, y->row_ptr.size() - 1);
    DCHECK_EQ(rank, y->dense_shape_host.size());
    DCHECK_EQ(m, y->dense_shape_host(rank - 1));
    DCHECK_EQ(nnz, y->col_ind.size());
    DCHECK_EQ(nnz, y->values.size());

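    // The CSC representation of x is exactly the CSR representation of the
    // transpose of x, so a single csr2csc conversion produces the transposed
    // output arrays directly.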
    return cuda_sparse.Csr2csc(
        m, n, nnz, x.values.data() /*csrVal*/, x.row_ptr.data() /*csrRowPtr*/,
        x.col_ind.data() /*csrColInd*/, y->values.data() /*cscVal*/,
        y->col_ind.data() /*cscRowInd*/, y->row_ptr.data() /*cscColPtr*/,
        copyValues);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace functor

}  // namespace tensorflow