/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace functor {

// Calculates the number of nonzero entries per batch of a sorted rank-3
// SparseTensor's indices. indices is expected to have columns
// corresponding to [batch, row, column], where indices[:, 0] < B.
//
// REQUIRES:
//   indices.dimension(1) == 3
//   nnz_per_batch.dimension(0) == B
template <typename Device>
struct CalculateNNZPerBatchMatrixFromIndices {
  Status operator()(OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec nnz_per_batch);
};

// Split a subset of a SparseTensor's indices into two vectors:
// COO row indices and COO column indices. Outputs are:
//
//   coo_row_ind = indices[:, row_dim]
//   coo_col_ind = indices[:, row_dim + 1]
//
// where n = coo_row_ind.size()
// and row_dim = #cols(indices) - 2.
//
// REQUIRES:
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   coo_row_ind.size() == coo_col_ind.size()
//   coo_row_ind.size() == indices.dim_size(0)
template <typename Device>
struct SparseTensorToCOOSparseMatrix {
  void operator()(const Device& d, TTypes<int64>::ConstVec host_dense_shape,
                  TTypes<int64>::ConstMatrix indices,
                  TTypes<int32>::Vec coo_row_ind,
                  TTypes<int32>::Vec coo_col_ind);
};

// Write the COO batch, row, and column vectors to the output matrix of
// SparseTensor indices:
//
//   indices[:, row_dim] = coo_row_ind
//   indices[:, col_dim] = coo_col_ind
//
// where row_dim = #cols(indices) - 2, col_dim = row_dim + 1, and
// n = coo_row_ind.size(). In addition, if #cols(indices) == 3, also store
// the batch:
//
//   indices[i, 0] = batch_of(i) where
//     host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1)
//
// REQUIRES:
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   host_batch_ptrs.size() == B + 1, where B is the number of batches
//   coo_row_ind.size() == coo_col_ind.size()
template <typename Device>
struct COOSparseMatrixToSparseTensor {
  Status operator()(OpKernelContext* c,
                    TTypes<int64>::ConstVec host_dense_shape,
                    TTypes<int32>::ConstVec host_batch_ptrs,
                    TTypes<int32>::Vec coo_row_ind,
                    TTypes<int32>::ConstVec coo_col_ind,
                    TTypes<int64>::Matrix indices);
};
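// A small worked example of the layout above (illustrative only, not part of
// the original interface): for host_dense_shape = [2, 4, 5] (so
// #cols(indices) == 3), host_batch_ptrs = [0, 2, 3], coo_row_ind = [0, 2, 1],
// and coo_col_ind = [1, 3, 0], the filled indices matrix is
//
//   [[0, 0, 1],   // i = 0: batch 0, since host_batch_ptrs(0) <= 0 < host_batch_ptrs(1)
//    [0, 2, 3],   // i = 1: batch 0
//    [1, 1, 0]]   // i = 2: batch 1, since host_batch_ptrs(1) <= 2 < host_batch_ptrs(2)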
// Convert a vector of COO row indices to CSR row pointers. For example,
// rows = 4 and coo_row_ind = [0, 0, 2, 3] yield csr_row_ptr = [0, 2, 2, 3, 4].
//
// REQUIRES:
//   csr_row_ptr.size() == rows + 1
//   max(coo_row_ind) < rows
template <typename Device>
struct COOSparseMatrixToCSRSparseMatrix {
  Status operator()(OpKernelContext* c, const int rows, const int cols,
                    TTypes<int32>::UnalignedVec coo_row_ind,
                    TTypes<int32>::UnalignedVec csr_row_ptr);
};

// Convert a matrix of (batched) COO row and column indices to CSR SparseMatrix
// batch pointers, CSR row pointers, and CSR column indices.
//
// REQUIRES:
//   batch_ptr.size() == batch_size + 1
//   csr_row_ptr.size() == batch_size * (num_rows + 1)
//   csr_col_ind.size() == total_nnz
//   batch_size == 1 if rank == 2
//
// where
//   total_nnz = indices.dim_size(0)
//   rank = indices.dim_size(1)
//
// Also, csr_row_ptr should be initially filled with zeros.
struct SparseTensorToCSRSparseMatrixCPUFunctor {
  Status operator()(const int64 batch_size, const int num_rows,
                    TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec batch_ptr,
                    TTypes<int32>::Vec csr_row_ptr,
                    TTypes<int32>::Vec csr_col_ind);
};

// Convert a vector of CSR row pointers to COO row indices.
//
// REQUIRES:
//   coo_row_ind.size() == nnz
//   csr_row_ptr[-1] == nnz
template <typename Device>
struct CSRSparseMatrixToCOOSparseMatrix {
  Status operator()(OpKernelContext* c,
                    TTypes<int32>::UnalignedConstVec csr_row_ptr,
                    TTypes<int32>::UnalignedVec coo_row_ind);
};

// Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format
// and B and C are dense.
template <typename Device, typename T>
struct CSRSparseMatrixMatMul {
  explicit CSRSparseMatrixMatMul(const bool transpose_output);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::ConstMatrix b,
                 typename TTypes<T>::Matrix c);
};

// Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format
// and x and y are dense vectors.
template <typename Device, typename T>
class CSRSparseMatrixMatVec {
  CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y);
};

// Calculates C = functor(A, B), where A and B are in CSR format and C is in
// CSR format with a (possibly) different sparsity pattern.
template <typename Device, typename T>
struct CSRStructureModifyingFunctor {
  virtual ~CSRStructureModifyingFunctor() {}

  virtual Status Initialize() = 0;

  virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
                                    const ConstCSRComponent<T>& b,
                                    TTypes<int32>::UnalignedVec c_row_ptr,
                                    int* output_nnz) = 0;

  virtual Status Compute(const ConstCSRComponent<T>& a,
                         const ConstCSRComponent<T>& b,
                         CSRComponent<T>* c) = 0;
};

// Calculates C = alpha * A + beta * B, where A and B are in CSR format,
// and alpha and beta are scalars on the host.
template <typename Device, typename T>
struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
                              const T beta);
};
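// A usage sketch for the structure-modifying functors above (illustrative
// only; it assumes the caller has already prepared the CSR components `a`,
// `b`, and `c` and allocated `c_row_ptr`, and it uses `CPUDevice` as a
// stand-in for a concrete device type):
//
//   CSRSparseMatrixAdd<CPUDevice, float> adder(ctx, /*alpha=*/1.0f,
//                                              /*beta=*/1.0f);
//   TF_RETURN_IF_ERROR(adder.Initialize());
//   int output_nnz = -1;
//   TF_RETURN_IF_ERROR(adder.GetOutputStructure(a, b, c_row_ptr, &output_nnz));
//   // ... allocate c's column indices and values with output_nnz entries ...
//   TF_RETURN_IF_ERROR(adder.Compute(a, b, &c));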
// Calculates C = matmul(A, B), where A, B, and C are in CSR format.
template <typename Device, typename T>
struct CSRSparseSparseMatrixMatMul
    : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
                                       bool transpose_b);
};

// Calculates Y = transpose(X), where X and Y are CSR format components.
template <typename Device, typename T>
struct CSRSparseMatrixTransposeComponent {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y);
};

// Calculates Y = transpose(X), where X and Y are in CSR format.
template <typename Device, typename T>
struct CSRSparseMatrixTranspose {
  Status operator()(OpKernelContext* ctx, bool conjugate,
                    const CSRSparseMatrix& input_matrix,
                    CSRSparseMatrix* output_matrix);
};

// Calculates Y = softmax(X), where X and Y are in CSR format;
// missing coefficients in X are treated as -inf (logits of 0 probability).
template <typename Device, typename T>
struct CSRSparseMatrixSoftmax {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits,
                    typename TTypes<T>::Vec softmax_values);
};

// Calculates the gradient of CSRSparseMatrixSoftmax: given the CSR softmax
// output and the CSR gradient with respect to that output, computes the
// gradient values with respect to the logits.
template <typename Device, typename T>
struct CSRSparseMatrixSoftmaxGrad {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax,
                    const CSRSparseMatrix& grad_softmax,
                    typename TTypes<T>::Vec gradient_values);
};

// Calculates C = a * b, where a and C are in CSR format and b is a scalar.
template <typename Device, typename T>
class CSRSparseMatrixMulScalar {
 public:
  explicit CSRSparseMatrixMulScalar() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c);
};

// Batched variant of CSRSparseMatrixMulScalar: multiplies each batch of the
// CSR matrix a by the corresponding entry of the dense vector b.
template <typename Device, typename T>
class CSRSparseMatrixBatchMulVec {
 public:
  explicit CSRSparseMatrixBatchMulVec() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c);
};

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_