• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
28 namespace functor {
29 
30 // Calculates number of nonzero entries per batch of a sorted rank-3
31 // SparseTensor's indices.  indices is expected to have columns
32 // corresponding to [batch, row, column],  where indices[:,0] < B.
33 //
34 // REQUIRES:
35 //  indices.dimension(1) == 3
36 //  nnz_per_batch.dimension(0) == B
37 template <typename Device>
38 struct CalculateNNZPerBatchMatrixFromIndices {
39   Status operator()(OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
40                     TTypes<int32>::Vec nnz_per_batch);
41 };
42 
43 // Split a subset of a SparseTensors' indices into two vectors:
44 // COO row inds and COO col inds.  Outputs are:
45 //
46 //   coo_row_ind = indices[:, row_dim]
47 //   coo_col_ind = indices[:, row_dim + 1]
48 //
49 // where n = coo_row_ind.size()
50 // and row_dim = #cols(indices) - 1
51 //
52 // REQUIRES:
53 //   host_dense_shape.size() in [2, 3]
54 //   indices.dim_size(1) == host_dense_shape.size()
55 //   coo_row_ind.size() == coo_col_ind.size()
56 //   coo_row_ind.size() == indices.dim_size(0)
57 template <typename Device>
58 struct SparseTensorToCOOSparseMatrix {
59   void operator()(const Device& d, TTypes<int64>::ConstVec host_dense_shape,
60                   TTypes<int64>::ConstMatrix indices,
61                   TTypes<int32>::Vec coo_row_ind,
62                   TTypes<int32>::Vec coo_col_ind);
63 };
64 
65 // Write coo batch, row, and column vectors to output matrix indices:
66 //
67 //   indices[:, row_dim] = coo_row_ind
68 //   indices[:, col_dim] = coo_col_ind
69 //
70 // where row_dim = #cols(indices) - 1 and n = coo_row_ind.size().
71 // In addition, if #cols(indices) == 3, also store the batch:
72 //
73 //   indices[i, 0] = batch_of(i) where
74 //      host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1)
75 //
76 // REQUIRES:
77 //
78 //   host_dense_shape.size() in [2, 3]
79 //   indices.dim_size(1) == host_dense_shape.size()
80 //   host_batch_ptr.size() ==
81 //   coo_row_ind.size() == coo_col_ind.size()
82 //
83 template <typename Device>
84 struct COOSparseMatrixToSparseTensor {
85   Status operator()(OpKernelContext* c,
86                     TTypes<int64>::ConstVec host_dense_shape,
87                     TTypes<int32>::ConstVec host_batch_ptrs,
88                     TTypes<int32>::Vec coo_row_ind,
89                     TTypes<int32>::ConstVec coo_col_ind,
90                     TTypes<int64>::Matrix indices);
91 };
92 
93 // Convert a vector of coo row indices to csr row pointers.
94 //
95 // REQUIRES:
96 //
97 //   csr_row_ptr.size() == rows + 1.
98 //   max(coo_row_ptr) < rows.
99 //
100 template <typename Device>
101 struct COOSparseMatrixToCSRSparseMatrix {
102   Status operator()(OpKernelContext* c, const int rows, const int cols,
103                     TTypes<int32>::UnalignedVec coo_row_ind,
104                     TTypes<int32>::UnalignedVec csr_row_ptr);
105 };
106 
107 // Convert a matrix of (batched) coo row and column indices to CSR SparseMatrix
108 // batch ptrs, csr row pointers and coo column indices.
109 //
110 // REQUIRES:
111 //   batch_ptr.size() == batch_size + 1
112 //   csr_row_ptr.size() == batch_size * (num_rows + 1)
113 //   csr_col_ind.size() == total_nnz
114 //   batch_size == 1 if rank == 2
115 //
116 //   where
117 //     total_nnz = indices.dim_size(0)
118 //     rank = indices.dim_size(1)
119 //   Also csr_row_ptr should be initially filled with zeros.
120 //
121 struct SparseTensorToCSRSparseMatrixCPUFunctor {
122   Status operator()(const int64 batch_size, const int num_rows,
123                     TTypes<int64>::ConstMatrix indices,
124                     TTypes<int32>::Vec batch_ptr,
125                     TTypes<int32>::Vec csr_row_ptr,
126                     TTypes<int32>::Vec csr_col_ind);
127 };
128 
129 // Convert a vector of csr row pointers to coo row indices.
130 //
131 // REQUIRES:
132 //
133 //   coo_row_ptr.size() == nnz.
134 //   csr_row_ptr[-1] == nnz.
135 //
136 template <typename Device>
137 struct CSRSparseMatrixToCOOSparseMatrix {
138   Status operator()(OpKernelContext* c,
139                     TTypes<int32>::UnalignedConstVec csr_row_ptr,
140                     TTypes<int32>::UnalignedVec coo_row_ind);
141 };
142 
143 // Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format
144 // and B and C are dense.
145 template <typename Device, typename T>
146 struct CSRSparseMatrixMatMul {
147   explicit CSRSparseMatrixMatMul(const bool transpose_output);
148   Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
149                  typename TTypes<T>::ConstMatrix b,
150                  typename TTypes<T>::Matrix c);
151 };
152 
153 // Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format
154 // and x and y are dense vectors.
155 template <typename Device, typename T>
156 class CSRSparseMatrixMatVec {
157   CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a);
158   Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
159                  const T* x, T* y);
160 };
161 
162 // Calculates C = functor(A, B) where A and B are CSR and C is CSR
163 // with a different sparsity pattern.
164 template <typename Device, typename T>
165 struct CSRStructureModifyingFunctor {
~CSRStructureModifyingFunctorCSRStructureModifyingFunctor166   virtual ~CSRStructureModifyingFunctor() {}
167 
168   virtual Status Initialize() = 0;
169 
170   virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
171                                     const ConstCSRComponent<T>& b,
172                                     TTypes<int32>::UnalignedVec c_row_ptr,
173                                     int* output_nnz) = 0;
174 
175   virtual Status Compute(const ConstCSRComponent<T>& a,
176                          const ConstCSRComponent<T>& b, CSRComponent<T>* c) = 0;
177 };
178 
179 // Calculates C = alpha * A + beta * B, where A and B are in CSR
180 // format, and alpha and beta are scalars on the host.
181 template <typename Device, typename T>
182 struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor<Device, T> {
183   explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
184                               const T beta);
185 };
186 
187 // Calculates C = matmul(A, B), where A, B, and C are in CSR format.
188 template <typename Device, typename T>
189 struct CSRSparseSparseMatrixMatMul
190     : public CSRStructureModifyingFunctor<Device, T> {
191   explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
192                                        bool transpose_b);
193 };
194 
195 // Calculates Y = transpose(X) where X and Y are CSR format components.
196 template <typename Device, typename T>
197 struct CSRSparseMatrixTransposeComponent {
198   Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
199                     CSRComponent<T>* y);
200 };
201 
202 // Calculates Y = transpose(X) where X and Y are in CSR format.
203 template <typename Device, typename T>
204 struct CSRSparseMatrixTranspose {
205   Status operator()(OpKernelContext* ctx, bool conjugate,
206                     const CSRSparseMatrix& input_matrix,
207                     CSRSparseMatrix* output_matrix);
208 };
209 
210 // Calculates Y = softmax(X) where X and Y are in CSR format;
211 // missing coefficients in X are treates as -inf (logits of 0 probability).
212 template <typename Device, typename T>
213 struct CSRSparseMatrixSoftmax {
214   Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits,
215                     typename TTypes<T>::Vec softmax_values);
216 };
217 
218 template <typename Device, typename T>
219 struct CSRSparseMatrixSoftmaxGrad {
220   Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax,
221                     const CSRSparseMatrix& grad_softmax,
222                     typename TTypes<T>::Vec gradient_values);
223 };
224 
225 template <typename Device, typename T>
226 class CSRSparseMatrixMulScalar {
227  public:
CSRSparseMatrixMulScalar()228   explicit CSRSparseMatrixMulScalar() {}
229 
230   Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
231                  typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c);
232 };
233 
234 template <typename Device, typename T>
235 class CSRSparseMatrixBatchMulVec {
236  public:
CSRSparseMatrixBatchMulVec()237   explicit CSRSparseMatrixBatchMulVec() {}
238 
239   Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
240                  typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c);
241 };
242 
243 }  // namespace functor
244 
245 }  // namespace tensorflow
246 
247 #endif  // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
248