• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Implements CSRTransposeOp, the kernel registered for the
// "SparseMatrixTranspose" op, which transposes the two innermost dimensions
// of a CSRSparseMatrix object stored in a DT_VARIANT.
19 
20 #define EIGEN_USE_THREADS
21 
22 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23 #include "tensorflow/core/util/cuda_sparse.h"
24 #define EIGEN_USE_GPU
25 #endif
26 
27 #include <numeric>
28 
29 #include "third_party/eigen3/Eigen/SparseCore"
30 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/tensor_types.h"
34 #include "tensorflow/core/framework/variant_op_registry.h"
35 #include "tensorflow/core/kernels/cwise_ops.h"
36 #include "tensorflow/core/kernels/cwise_ops_common.h"
37 #include "tensorflow/core/kernels/dense_update_functor.h"
38 #include "tensorflow/core/kernels/fill_functor.h"
39 #include "tensorflow/core/kernels/slice_op.h"
40 #include "tensorflow/core/kernels/sparse/kernels.h"
41 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
42 #include "tensorflow/core/kernels/sparse/transpose_op.h"
43 #include "tensorflow/core/lib/core/threadpool.h"
44 
45 namespace tensorflow {
46 
47 typedef Eigen::ThreadPoolDevice CPUDevice;
48 typedef Eigen::GpuDevice GPUDevice;
49 
50 namespace {
51 
52 template <typename T>
ValidateTransposeInputs(const ConstCSRComponent<T> & input,const CSRComponent<T> & output)53 Status ValidateTransposeInputs(const ConstCSRComponent<T>& input,
54                                const CSRComponent<T>& output) {
55   const int rank = input.dense_shape_host.size();
56   const int64 nnz = input.col_ind.size();
57   const int num_rows = input.row_ptr.size() - 1;
58   const int num_cols = input.dense_shape_host(rank - 1);
59 
60   if (nnz != input.values.size()) {
61     return errors::InvalidArgument(
62         "Input nnz should equal the input values size. Got ", nnz, " vs. ",
63         input.values.size());
64   }
65   if (num_cols + 1 != output.row_ptr.size()) {
66     return errors::InvalidArgument(
67         "Input num_cols should be equal to output num_rows. Got ", num_cols,
68         " vs. ", output.row_ptr.size());
69   }
70   if (rank != output.dense_shape_host.size()) {
71     return errors::InvalidArgument(
72         "Input rank should be equal to the output rank. Got ", rank, " vs. ",
73         output.dense_shape_host.size());
74   }
75   if (num_rows != output.dense_shape_host(rank - 1)) {
76     return errors::InvalidArgument(
77         "Input num_rows should be equal to the output num_cols. Got ", num_rows,
78         " vs. ", output.dense_shape_host(rank - 1));
79   }
80   if (nnz != output.col_ind.size()) {
81     return errors::InvalidArgument(
82         "Input nnz should equal the output col_ind size. Got ", nnz, " vs. ",
83         output.col_ind.size());
84   }
85   if (nnz != output.values.size()) {
86     return errors::InvalidArgument(
87         "Input nnz should equal the output values size. Got ", nnz, " vs. ",
88         output.values.size());
89   }
90   return Status::OK();
91 }
92 }  // namespace
93 
94 template <typename Device, typename T>
95 class CSRTransposeOp : public OpKernel {
96  public:
CSRTransposeOp(OpKernelConstruction * ctx)97   explicit CSRTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
98     OP_REQUIRES_OK(ctx, ctx->GetAttr("conjugate", &conjugate_));
99   }
100 
Compute(OpKernelContext * ctx)101   void Compute(OpKernelContext* ctx) override {
102     const CSRSparseMatrix* input_matrix;
103     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
104     OP_REQUIRES(
105         ctx, input_matrix->dtype() == DataTypeToEnum<T>::value,
106         errors::InvalidArgument("dtype of input is not equal to 'type': ",
107                                 DataTypeString(input_matrix->dtype()), " vs. ",
108                                 DataTypeString(DataTypeToEnum<T>::value)));
109 
110     // Allocate output shapes
111     functor::CSRSparseMatrixTranspose<Device, T> transpose;
112     CSRSparseMatrix output_matrix;
113     OP_REQUIRES_OK(ctx,
114                    transpose(ctx, conjugate_, *input_matrix, &output_matrix));
115     Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
116     output_t.scalar<Variant>()() = std::move(output_matrix);
117     ctx->set_output(0, output_t);
118   }
119 
120  private:
121   bool conjugate_;
122 };
123 
// Registers CSRTransposeOp<DEV##Device, T> as the kernel for the
// "SparseMatrixTranspose" op on DEVICE_##DEV, with the "type" attribute
// constrained to T. The expansion already ends in a semicolon, so the
// invocations below do not add one.
#define REGISTER_TRANSPOSE(DEV, T)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixTranspose")   \
                              .Device(DEVICE_##DEV)       \
                              .TypeConstraint<T>("type"), \
                          CSRTransposeOp<DEV##Device, T>);

// CPU kernels: real and complex types.
REGISTER_TRANSPOSE(CPU, float)
REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128)

// GPU kernels for real types: registered for both CUDA and ROCm builds.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// GPU kernels for complex types: registered only when GOOGLE_CUDA is set.
#if GOOGLE_CUDA
REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128)
#endif  // GOOGLE_CUDA

#undef REGISTER_TRANSPOSE
146 
147 namespace functor {
148 
149 template <typename Device, typename T>
operator ()(OpKernelContext * ctx,bool conjugate,const CSRSparseMatrix & input_matrix,CSRSparseMatrix * output_matrix)150 Status CSRSparseMatrixTranspose<Device, T>::operator()(
151     OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix,
152     CSRSparseMatrix* output_matrix) {
153   const int rank = input_matrix.dims();
154   Tensor output_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
155   const Tensor& input_dense_shape_t = input_matrix.dense_shape();
156   auto input_dense_shape = input_dense_shape_t.vec<int64>();
157   auto output_dense_shape = output_dense_shape_t.vec<int64>();
158   const int64 batch_size = input_matrix.batch_size();
159   if (rank == 3) {
160     output_dense_shape(0) = batch_size;
161   }
162   output_dense_shape(rank - 2) = input_dense_shape(rank - 1);
163   output_dense_shape(rank - 1) = input_dense_shape(rank - 2);
164   const int64 output_rows = output_dense_shape(rank - 2);
165 
166   // nnzs per batch do not change with matrix transposition.
167   Tensor batch_ptr_t = input_matrix.batch_pointers();
168   const int total_nnz = input_matrix.total_nnz();
169 
170   Tensor output_row_ptr_t;
171   Tensor output_col_ind_t;
172   Tensor output_values_t;
173 
174   TF_RETURN_IF_ERROR(ctx->allocate_temp(
175       DT_INT32, TensorShape({batch_size * (output_rows + 1)}),
176       &output_row_ptr_t));
177   TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
178                                         &output_col_ind_t));
179   TF_RETURN_IF_ERROR(ctx->allocate_temp(
180       DataTypeToEnum<T>::value, TensorShape({total_nnz}), &output_values_t));
181 
182   TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
183       DataTypeToEnum<T>::value, output_dense_shape_t, batch_ptr_t,
184       output_row_ptr_t, output_col_ind_t, output_values_t, output_matrix));
185 
186   // Set the output row pointers to zero, in case we hit any empty
187   // input batches.
188   functor::SetZeroFunctor<Device, int32> set_zero;
189   const Device& d = ctx->eigen_device<Device>();
190   set_zero(d, output_row_ptr_t.flat<int32>());
191 
192   functor::CSRSparseMatrixTransposeComponent<Device, T> transpose_component;
193   for (int i = 0; i < batch_size; ++i) {
194     if (output_matrix->nnz(i) == 0) {
195       continue;
196     }
197     ConstCSRComponent<T> input_comp{
198         input_matrix.row_pointers_vec(i), input_matrix.col_indices_vec(i),
199         input_matrix.values_vec<T>(i), input_dense_shape};
200     CSRComponent<T> output_comp{
201         output_matrix->row_pointers_vec(i), output_matrix->col_indices_vec(i),
202         output_matrix->values_vec<T>(i), output_dense_shape};
203 
204     TF_RETURN_IF_ERROR(transpose_component(ctx, input_comp, &output_comp));
205   }
206   if (conjugate) {
207     // conjugate all values with a single kernel launch.
208     maybe_conj_inplace<Device, T>::run(d, &output_values_t);
209   }
210 
211   return Status::OK();
212 }
213 
214 // CPU kernel for transposing a single component of a CSR SparseMatrix.
215 template <typename T>
216 struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
217   using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
218 
operator ()tensorflow::functor::CSRSparseMatrixTransposeComponent219   Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& input,
220                     CSRComponent<T>* output) {
221     TF_RETURN_IF_ERROR(ValidateTransposeInputs(input, *output));
222 
223     const int rank = input.dense_shape_host.size();
224     const int num_rows = input.row_ptr.size() - 1;
225     const int num_cols = input.dense_shape_host(rank - 1);
226     const int64 nnz = input.col_ind.size();
227 
228     // Compute the column counts; whose prefix sums make up the output row
229     // pointers.
230     for (int64 i = 0; i < nnz; ++i) {
231       output->row_ptr(input.col_ind(i) + 1) += 1;
232     }
233     std::partial_sum(output->row_ptr.data(),
234                      output->row_ptr.data() + num_cols + 1,
235                      output->row_ptr.data());
236 
237     // Iterate through each row of the input, and place each non-zero element
238     // into the target output row (based on the current column count).
239     std::vector<int> current_col_count(num_cols);
240     for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
241       const int64 row_begin = input.row_ptr(row_idx);
242       const int64 row_end = input.row_ptr(row_idx + 1);
243       for (int64 i = row_begin; i < row_end; ++i) {
244         const int col_idx = input.col_ind(i);
245         const int64 offset =
246             output->row_ptr(col_idx) + current_col_count[col_idx];
247         output->col_ind(offset) = row_idx;
248         output->values(offset) = input.values(i);
249         current_col_count[col_idx] += 1;
250       }
251     }
252     return Status::OK();
253   }
254 };
255 
256 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
257 
258 template <typename T>
259 struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
operator ()tensorflow::functor::CSRSparseMatrixTransposeComponent260   Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
261                     CSRComponent<T>* y) {
262     TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
263     GpuSparse cuda_sparse(ctx);
264     TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
265     const gpusparseAction_t copyValues = GPUSPARSE(ACTION_NUMERIC);
266     const int rank = x.dense_shape_host.size();
267     const int m = x.row_ptr.size() - 1;
268     const int n = x.dense_shape_host(rank - 1);
269     const int nnz = x.col_ind.size();
270     DCHECK_EQ(nnz, x.values.size());
271     DCHECK_EQ(n, y->row_ptr.size() - 1);
272     DCHECK_EQ(rank, y->dense_shape_host.size());
273     DCHECK_EQ(m, y->dense_shape_host(rank - 1));
274     DCHECK_EQ(nnz, y->col_ind.size());
275     DCHECK_EQ(nnz, y->values.size());
276 
277     return cuda_sparse.Csr2csc(
278         m, n, nnz, x.values.data() /*csrVal*/, x.row_ptr.data() /*csrRowPtr*/,
279         x.col_ind.data() /*csrColInd*/, y->values.data() /*cscVal*/,
280         y->col_ind.data() /*cscRowInd*/, y->row_ptr.data() /*cscColPtr*/,
281         copyValues);
282     return Status::OK();
283   }
284 };
285 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
286 }  // namespace functor
287 
288 }  // namespace tensorflow
289