• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
19 #define EIGEN_USE_GPU
20 #endif
21 
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/variant_op_registry.h"
27 #include "tensorflow/core/kernels/dense_update_functor.h"
28 #include "tensorflow/core/kernels/slice_op.h"
29 #include "tensorflow/core/kernels/sparse/kernels.h"
30 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
31 
32 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
33 #include "tensorflow/core/util/cuda_sparse.h"
34 #include "tensorflow/core/util/gpu_solvers.h"
35 #endif
36 
37 namespace tensorflow {
38 
39 typedef Eigen::ThreadPoolDevice CPUDevice;
40 typedef Eigen::GpuDevice GPUDevice;
41 
42 template <typename Device, typename T>
43 class CSRSparseMatrixComponentsOp : public OpKernel {
44  public:
CSRSparseMatrixComponentsOp(OpKernelConstruction * c)45   explicit CSRSparseMatrixComponentsOp(OpKernelConstruction* c) : OpKernel(c) {}
46 
Compute(OpKernelContext * c)47   void Compute(OpKernelContext* c) final {
48     const CSRSparseMatrix* csr_sparse_matrix;
49     OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));
50 
51     const Tensor& index_t = c->input(1);
52     OP_REQUIRES(c, DataTypeToEnum<T>::value == csr_sparse_matrix->dtype(),
53                 errors::InvalidArgument(
54                     "dtype of input is not equal to 'type': ",
55                     DataTypeString(csr_sparse_matrix->dtype()), " vs. ",
56                     DataTypeString(DataTypeToEnum<T>::value)));
57     OP_REQUIRES(c, index_t.dims() == 0,
58                 errors::InvalidArgument("index should be a scalar, but saw: ",
59                                         index_t.DebugString()));
60     int32_t index = index_t.scalar<int32>()();
61     OP_REQUIRES(c, index >= 0 && index < csr_sparse_matrix->batch_size(),
62                 errors::InvalidArgument("index (", index, ") not in [0, ",
63                                         csr_sparse_matrix->batch_size(), ")"));
64 
65     if (csr_sparse_matrix->dims() == 2) {
66       c->set_output(0, csr_sparse_matrix->row_pointers());
67       c->set_output(1, csr_sparse_matrix->col_indices());
68       c->set_output(2, csr_sparse_matrix->values());
69     } else {
70       auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();
71       auto dense_shape = csr_sparse_matrix->dense_shape().vec<int64_t>();
72       int64_t rows = dense_shape(1);
73       int nnz = batch_ptrs(index + 1) - batch_ptrs(index);
74       Tensor* row_ptrs_t;
75       Tensor* col_inds_t;
76       Tensor* values_t;
77       OP_REQUIRES_OK(
78           c, c->allocate_output(0, TensorShape({rows + 1}), &row_ptrs_t));
79       OP_REQUIRES_OK(c, c->allocate_output(1, TensorShape({nnz}), &col_inds_t));
80       OP_REQUIRES_OK(c, c->allocate_output(2, TensorShape({nnz}), &values_t));
81       auto row_ptrs = row_ptrs_t->vec<int32>();
82       auto col_inds = col_inds_t->vec<int32>();
83       auto values = values_t->vec<T>();
84 
85       functor::Slice<Device, int32, 1> slice_int;
86       functor::Slice<Device, T, 1> slice_t;
87       typedef Eigen::DSizes<Eigen::DenseIndex, 1> EVec;
88       const Device& d = c->eigen_device<Device>();
89       slice_int(d,
90                 /*output*/ row_ptrs,
91                 /*input*/ csr_sparse_matrix->row_pointers().vec<int32>(),
92                 /*slice_indices*/
93                 EVec{static_cast<Eigen::DenseIndex>(index * (rows + 1))},
94                 /*slice_sizes*/ EVec{static_cast<Eigen::DenseIndex>(rows + 1)});
95       slice_int(d,
96                 /*output*/ col_inds,
97                 /*input*/ csr_sparse_matrix->col_indices().vec<int32>(),
98                 /*slice_indices*/ EVec{batch_ptrs(index)},
99                 /*slice_sizes*/ EVec{nnz});
100       slice_t(d,
101               /*output*/ values, /*input*/ csr_sparse_matrix->values().vec<T>(),
102               /*slice_indices*/ EVec{batch_ptrs(index)},
103               /*slice_sizes*/ EVec{nnz});
104     }
105   }
106 };
107 
108 #define REGISTER(DEV, T)                                    \
109   REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixComponents") \
110                               .Device(DEVICE_##DEV)         \
111                               .TypeConstraint<T>("type")    \
112                               .HostMemory("index"),         \
113                           CSRSparseMatrixComponentsOp<DEV##Device, T>);
114 
115 REGISTER(CPU, float)
116 REGISTER(CPU, double)
117 REGISTER(CPU, complex64)
118 REGISTER(CPU, complex128)
119 
120 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
121 
122 REGISTER(GPU, float)
123 REGISTER(GPU, double)
124 REGISTER(GPU, complex64)
125 REGISTER(GPU, complex128)
126 
127 #undef REGISTER
128 
129 namespace functor {
130 // TODO(ebrevdo): This should move to a slice_functor.cc
131 #define DECLARE_GPU_SPEC(T)                                     \
132   template <>                                                   \
133   void Slice<GPUDevice, T, 1>::operator()(                      \
134       const GPUDevice& d, typename TTypes<T, 1>::Tensor output, \
135       typename TTypes<T, 1>::ConstTensor input,                 \
136       const Eigen::DSizes<Eigen::DenseIndex, 1>& indices,       \
137       const Eigen::DSizes<Eigen::DenseIndex, 1>& sizes);        \
138   extern template struct Slice<GPUDevice, T, 1>;
139 
140 DECLARE_GPU_SPEC(int32);
141 DECLARE_GPU_SPEC(float);
142 DECLARE_GPU_SPEC(double);
143 DECLARE_GPU_SPEC(complex64);
144 DECLARE_GPU_SPEC(complex128);
145 
146 #undef DECLARE_GPU_SPEC
147 }  // namespace functor
148 
149 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
150 
151 }  // namespace tensorflow
152