/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Extracts the CSR components (row pointers, column indices, values) of a
// single sparse matrix stored in a CSRSparseMatrix variant.  For a rank-3
// (batched) input, the scalar `index` input selects which batch to extract.
template <typename Device, typename T>
class CSRSparseMatrixComponentsOp : public OpKernel {
 public:
  explicit CSRSparseMatrixComponentsOp(OpKernelConstruction* c)
      : OpKernel(c) {}

  void Compute(OpKernelContext* c) final {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));

    const Tensor& index_t = c->input(1);
    OP_REQUIRES(c, DataTypeToEnum<T>::value == csr_sparse_matrix->dtype(),
                errors::InvalidArgument(
                    "dtype of input is not equal to 'type': ",
                    DataTypeString(csr_sparse_matrix->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(c, index_t.dims() == 0,
                errors::InvalidArgument("index should be a scalar, but saw: ",
                                        index_t.DebugString()));
    int32 index = index_t.scalar<int32>()();
    OP_REQUIRES(c, index >= 0 && index < csr_sparse_matrix->batch_size(),
                errors::InvalidArgument("index (", index, ") not in [0, ",
                                        csr_sparse_matrix->batch_size(), ")"));

    if (csr_sparse_matrix->dims() == 2) {
      // Rank-2 input: there is only one matrix, so forward its component
      // tensors directly without copying.
      c->set_output(0, csr_sparse_matrix->row_pointers());
      c->set_output(1, csr_sparse_matrix->col_indices());
      c->set_output(2, csr_sparse_matrix->values());
    } else {
      // Rank-3 input: slice the components of batch `index` out of the
      // concatenated per-batch storage.
      auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();
      auto dense_shape = csr_sparse_matrix->dense_shape().vec<int64>();
      int64 rows = dense_shape(1);
      int nnz = batch_ptrs(index + 1) - batch_ptrs(index);
      Tensor* row_ptrs_t;
      Tensor* col_inds_t;
      Tensor* values_t;
      OP_REQUIRES_OK(
          c, c->allocate_output(0, TensorShape({rows + 1}), &row_ptrs_t));
      OP_REQUIRES_OK(c,
                     c->allocate_output(1, TensorShape({nnz}), &col_inds_t));
      OP_REQUIRES_OK(c, c->allocate_output(2, TensorShape({nnz}), &values_t));
      auto row_ptrs = row_ptrs_t->vec<int32>();
      auto col_inds = col_inds_t->vec<int32>();
      auto values = values_t->vec<T>();

      functor::Slice<Device, int32, 1> slice_int;
      functor::Slice<Device, T, 1> slice_t;
      typedef Eigen::DSizes<Eigen::DenseIndex, 1> EVec;
      const Device& d = c->eigen_device<Device>();
      slice_int(d,
                /*output*/ row_ptrs,
                /*input*/ csr_sparse_matrix->row_pointers().vec<int32>(),
                /*slice_indices*/
                EVec{static_cast<Eigen::DenseIndex>(index * (rows + 1))},
                /*slice_sizes*/ EVec{static_cast<Eigen::DenseIndex>(rows + 1)});
      slice_int(d,
                /*output*/ col_inds,
                /*input*/ csr_sparse_matrix->col_indices().vec<int32>(),
                /*slice_indices*/ EVec{batch_ptrs(index)},
                /*slice_sizes*/ EVec{nnz});
      slice_t(d,
              /*output*/ values, /*input*/ csr_sparse_matrix->values().vec<T>(),
              /*slice_indices*/ EVec{batch_ptrs(index)},
              /*slice_sizes*/ EVec{nnz});
    }
  }
};

#define REGISTER(DEV, T)                                      \
  REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixComponents")   \
                              .Device(DEVICE_##DEV)           \
                              .TypeConstraint<T>("type")      \
                              .HostMemory("index"),           \
                          CSRSparseMatrixComponentsOp<DEV##Device, T>);

REGISTER(CPU, float)
REGISTER(CPU, double)
REGISTER(CPU, complex64)
REGISTER(CPU, complex128)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER(GPU, float)
REGISTER(GPU, double)
REGISTER(GPU, complex64)
REGISTER(GPU, complex128)

#undef REGISTER

namespace functor {
// TODO(ebrevdo): This should move to a slice_functor.cc
#define DECLARE_GPU_SPEC(T)                                     \
  template <>                                                   \
  void Slice<GPUDevice, T, 1>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output, \
      typename TTypes<T, 1>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, 1>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, 1>& sizes);        \
  extern template struct Slice<GPUDevice, T, 1>;

DECLARE_GPU_SPEC(int32);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);

#undef DECLARE_GPU_SPEC
}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow