/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the kernel for the CSRSoftmax op, which performs softmax
// along the innermost (col) dimension of a CSRSparseMatrix object
// stored in a DT_VARIANT.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class CSRSoftmaxOp : public OpKernel {
 public:
  explicit CSRSoftmaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* logits_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &logits_matrix));
    OP_REQUIRES(
        ctx, logits_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of logits is not equal to 'type': ",
                                DataTypeString(logits_matrix->dtype()),
                                " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
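    // The output reuses the input's sparsity structure: the batch pointers,
    // row pointers, and column indices of `logits_matrix` are shared below,
    // and only a fresh values buffer is allocated. Entries absent from the
    // CSR structure behave as logits of -inf, so they receive zero
    // probability and each row's softmax normalizes over its stored
    // entries only.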
", 55 DataTypeString(DataTypeToEnum<T>::value))); 56 57 // Allocate output shapes 58 const int total_nnz = logits_matrix->total_nnz(); 59 Tensor output_values_t; 60 OP_REQUIRES_OK( 61 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, 62 TensorShape({total_nnz}), &output_values_t)); 63 64 CSRSparseMatrix output_matrix; 65 66 Tensor dense_shape_t = logits_matrix->dense_shape(); 67 68 OP_REQUIRES_OK( 69 ctx, 70 CSRSparseMatrix::CreateCSRSparseMatrix( 71 DataTypeToEnum<T>::value, dense_shape_t, 72 logits_matrix->batch_pointers(), logits_matrix->row_pointers(), 73 logits_matrix->col_indices(), output_values_t, &output_matrix)); 74 75 if (total_nnz > 0) { 76 functor::CSRSparseMatrixSoftmax<Device, T> softmax; 77 OP_REQUIRES_OK( 78 ctx, softmax(ctx, *logits_matrix, output_matrix.values().vec<T>())); 79 } 80 81 Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({})); 82 output_t.scalar<Variant>()() = std::move(output_matrix); 83 ctx->set_output(0, output_t); 84 } 85 }; 86 87 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 88 #define REGISTER(DEV, T) \ 89 REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \ 90 .Device(DEVICE_##DEV) \ 91 .TypeConstraint<T>("type"), \ 92 CSRSoftmaxOp<DEV##Device, T>); 93 94 REGISTER(GPU, float) 95 REGISTER(GPU, double) 96 97 #undef REGISTER 98 99 namespace functor { 100 #define DECLARE_GPU_SPEC(T) \ 101 template <> \ 102 Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \ 103 OpKernelContext* ctx, const CSRSparseMatrix& logits, \ 104 typename TTypes<T>::Vec softmax_values); \ 105 extern template struct CSRSparseMatrixSoftmax<GPUDevice, T>; 106 107 DECLARE_GPU_SPEC(float); 108 DECLARE_GPU_SPEC(double); 109 110 #undef DECLARE_GPU_SPEC 111 } // namespace functor 112 113 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 114 115 template <typename Device, typename T> 116 class CSRSoftmaxGradOp : public OpKernel { 117 public: CSRSoftmaxGradOp(OpKernelConstruction * ctx)118 explicit CSRSoftmaxGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 119 Compute(OpKernelContext * ctx)120 void Compute(OpKernelContext* ctx) override { 121 const CSRSparseMatrix* softmax_matrix; 122 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &softmax_matrix)); 123 OP_REQUIRES(ctx, softmax_matrix->dtype() == DataTypeToEnum<T>::value, 124 errors::InvalidArgument( 125 "dtype of softmax is not equal to 'type': ", 126 DataTypeString(softmax_matrix->dtype()), " vs. ", 127 DataTypeString(DataTypeToEnum<T>::value))); 128 129 const CSRSparseMatrix* grad_softmax_matrix; 130 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &grad_softmax_matrix)); 131 OP_REQUIRES(ctx, grad_softmax_matrix->dtype() == DataTypeToEnum<T>::value, 132 errors::InvalidArgument( 133 "dtype of grad_softmax is not equal to 'type': ", 134 DataTypeString(grad_softmax_matrix->dtype()), " vs. ", 135 DataTypeString(DataTypeToEnum<T>::value))); 136 137 OP_REQUIRES( 138 ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(), 139 errors::InvalidArgument( 140 "Ranks of softmax and grad_softmax matrices differ: ", 141 softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims())); 142 143 OP_REQUIRES( 144 ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(), 145 errors::InvalidArgument( 146 "Ranks of softmax and grad_softmax matrices differ: ", 147 softmax_matrix->dims(), " vs. 
", grad_softmax_matrix->dims())); 148 149 Tensor dense_shape_t = softmax_matrix->dense_shape(); 150 auto host_dense_shape = 151 static_cast<const Tensor>(dense_shape_t).vec<int64>(); 152 153 auto host_grad_dense_shape = 154 grad_softmax_matrix->dense_shape().vec<int64>(); 155 156 for (int i = 0; i < host_dense_shape.size(); ++i) { 157 OP_REQUIRES(ctx, host_dense_shape(i) == host_grad_dense_shape(i), 158 errors::InvalidArgument( 159 "Shapes of softmax and grad_softmax matrices differ: ", 160 dense_shape_t.SummarizeValue(3), " vs. ", 161 grad_softmax_matrix->dense_shape().SummarizeValue(3))); 162 } 163 164 // Allocate output shapes. Note that since the Softmax Gradient 165 // tensor is the elementwise product of some function with the 166 // softmax value, it will keep the sparsity structure of the softmax. 167 const int total_nnz = softmax_matrix->total_nnz(); 168 PersistentTensor gradient_values_pt; 169 Tensor* gradient_values_t; 170 OP_REQUIRES_OK(ctx, ctx->allocate_persistent( 171 DataTypeToEnum<T>::value, TensorShape({total_nnz}), 172 &gradient_values_pt, &gradient_values_t)); 173 174 CSRSparseMatrix gradient_matrix; 175 176 OP_REQUIRES_OK( 177 ctx, CSRSparseMatrix::CreateCSRSparseMatrix( 178 DataTypeToEnum<T>::value, dense_shape_t, 179 softmax_matrix->batch_pointers(), 180 softmax_matrix->row_pointers(), softmax_matrix->col_indices(), 181 *gradient_values_t, &gradient_matrix)); 182 183 if (total_nnz > 0) { 184 functor::CSRSparseMatrixSoftmaxGrad<Device, T> softmax_grad; 185 OP_REQUIRES_OK(ctx, 186 softmax_grad(ctx, *softmax_matrix, *grad_softmax_matrix, 187 gradient_matrix.values().vec<T>())); 188 } 189 190 Tensor gradient_t(cpu_allocator(), DT_VARIANT, TensorShape({})); 191 gradient_t.scalar<Variant>()() = std::move(gradient_matrix); 192 ctx->set_output(0, gradient_t); 193 } 194 }; 195 196 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 197 #define REGISTER(DEV, T) \ 198 REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \ 199 .Device(DEVICE_##DEV) \ 200 .TypeConstraint<T>("type"), \ 201 CSRSoftmaxGradOp<DEV##Device, T>); 202 203 REGISTER(GPU, float) 204 REGISTER(GPU, double) 205 206 #undef REGISTER 207 208 namespace functor { 209 #define DECLARE_GPU_SPEC(T) \ 210 template <> \ 211 Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \ 212 OpKernelContext* ctx, const CSRSparseMatrix& softmax, \ 213 const CSRSparseMatrix& grad_softmax, \ 214 typename TTypes<T>::Vec gradient_values); \ 215 extern template struct CSRSparseMatrixSoftmaxGrad<GPUDevice, T>; 216 217 DECLARE_GPU_SPEC(float); 218 DECLARE_GPU_SPEC(double); 219 220 #undef DECLARE_GPU_SPEC 221 } // namespace functor 222 223 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 224 225 } // namespace tensorflow 226