• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements the kernel for the CSRSoftmax op, which performs softmax
17 // along the innermost (col) dimension of a CSRSparseMatrix object
18 // stored in a DT_VARIANT.
19 
20 #define EIGEN_USE_THREADS
21 
22 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23 #include "tensorflow/core/kernels/cuda_sparse.h"
24 #define EIGEN_USE_GPU
25 #endif
26 
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/tensor_types.h"
31 #include "tensorflow/core/framework/variant_op_registry.h"
32 #include "tensorflow/core/kernels/dense_update_functor.h"
33 #include "tensorflow/core/kernels/fill_functor.h"
34 #include "tensorflow/core/kernels/slice_op.h"
35 #include "tensorflow/core/kernels/sparse/kernels.h"
36 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 typedef Eigen::GpuDevice GPUDevice;
42 
43 template <typename Device, typename T>
44 class CSRSoftmaxOp : public OpKernel {
45  public:
CSRSoftmaxOp(OpKernelConstruction * ctx)46   explicit CSRSoftmaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
47 
Compute(OpKernelContext * ctx)48   void Compute(OpKernelContext* ctx) override {
49     const CSRSparseMatrix* logits_matrix;
50     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &logits_matrix));
51     OP_REQUIRES(
52         ctx, logits_matrix->dtype() == DataTypeToEnum<T>::value,
53         errors::InvalidArgument("dtype of logits is not equal to 'type': ",
54                                 DataTypeString(logits_matrix->dtype()), " vs. ",
55                                 DataTypeString(DataTypeToEnum<T>::value)));
56 
57     // Allocate output shapes
58     const int total_nnz = logits_matrix->total_nnz();
59     Tensor output_values_t;
60     OP_REQUIRES_OK(
61         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
62                                 TensorShape({total_nnz}), &output_values_t));
63 
64     CSRSparseMatrix output_matrix;
65 
66     Tensor dense_shape_t = logits_matrix->dense_shape();
67 
68     OP_REQUIRES_OK(
69         ctx,
70         CSRSparseMatrix::CreateCSRSparseMatrix(
71             DataTypeToEnum<T>::value, dense_shape_t,
72             logits_matrix->batch_pointers(), logits_matrix->row_pointers(),
73             logits_matrix->col_indices(), output_values_t, &output_matrix));
74 
75     if (total_nnz > 0) {
76       functor::CSRSparseMatrixSoftmax<Device, T> softmax;
77       OP_REQUIRES_OK(
78           ctx, softmax(ctx, *logits_matrix, output_matrix.values().vec<T>()));
79     }
80 
81     Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
82     output_t.scalar<Variant>()() = std::move(output_matrix);
83     ctx->set_output(0, output_t);
84   }
85 };
86 
87 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
88 #define REGISTER(DEV, T)                                  \
89   REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax")     \
90                               .Device(DEVICE_##DEV)       \
91                               .TypeConstraint<T>("type"), \
92                           CSRSoftmaxOp<DEV##Device, T>);
93 
94 REGISTER(GPU, float)
95 REGISTER(GPU, double)
96 
97 #undef REGISTER
98 
99 namespace functor {
100 #define DECLARE_GPU_SPEC(T)                                \
101   template <>                                              \
102   Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \
103       OpKernelContext* ctx, const CSRSparseMatrix& logits, \
104       typename TTypes<T>::Vec softmax_values);             \
105   extern template struct CSRSparseMatrixSoftmax<GPUDevice, T>;
106 
107 DECLARE_GPU_SPEC(float);
108 DECLARE_GPU_SPEC(double);
109 
110 #undef DECLARE_GPU_SPEC
111 }  // namespace functor
112 
113 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
114 
115 template <typename Device, typename T>
116 class CSRSoftmaxGradOp : public OpKernel {
117  public:
CSRSoftmaxGradOp(OpKernelConstruction * ctx)118   explicit CSRSoftmaxGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
119 
Compute(OpKernelContext * ctx)120   void Compute(OpKernelContext* ctx) override {
121     const CSRSparseMatrix* softmax_matrix;
122     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &softmax_matrix));
123     OP_REQUIRES(ctx, softmax_matrix->dtype() == DataTypeToEnum<T>::value,
124                 errors::InvalidArgument(
125                     "dtype of softmax is not equal to 'type': ",
126                     DataTypeString(softmax_matrix->dtype()), " vs. ",
127                     DataTypeString(DataTypeToEnum<T>::value)));
128 
129     const CSRSparseMatrix* grad_softmax_matrix;
130     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &grad_softmax_matrix));
131     OP_REQUIRES(ctx, grad_softmax_matrix->dtype() == DataTypeToEnum<T>::value,
132                 errors::InvalidArgument(
133                     "dtype of grad_softmax is not equal to 'type': ",
134                     DataTypeString(grad_softmax_matrix->dtype()), " vs. ",
135                     DataTypeString(DataTypeToEnum<T>::value)));
136 
137     OP_REQUIRES(
138         ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(),
139         errors::InvalidArgument(
140             "Ranks of softmax and grad_softmax matrices differ: ",
141             softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims()));
142 
143     OP_REQUIRES(
144         ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(),
145         errors::InvalidArgument(
146             "Ranks of softmax and grad_softmax matrices differ: ",
147             softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims()));
148 
149     Tensor dense_shape_t = softmax_matrix->dense_shape();
150     auto host_dense_shape =
151         static_cast<const Tensor>(dense_shape_t).vec<int64>();
152 
153     auto host_grad_dense_shape =
154         grad_softmax_matrix->dense_shape().vec<int64>();
155 
156     for (int i = 0; i < host_dense_shape.size(); ++i) {
157       OP_REQUIRES(ctx, host_dense_shape(i) == host_grad_dense_shape(i),
158                   errors::InvalidArgument(
159                       "Shapes of softmax and grad_softmax matrices differ: ",
160                       dense_shape_t.SummarizeValue(3), " vs. ",
161                       grad_softmax_matrix->dense_shape().SummarizeValue(3)));
162     }
163 
164     // Allocate output shapes.  Note that since the Softmax Gradient
165     // tensor is the elementwise product of some function with the
166     // softmax value, it will keep the sparsity structure of the softmax.
167     const int total_nnz = softmax_matrix->total_nnz();
168     PersistentTensor gradient_values_pt;
169     Tensor* gradient_values_t;
170     OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
171                             DataTypeToEnum<T>::value, TensorShape({total_nnz}),
172                             &gradient_values_pt, &gradient_values_t));
173 
174     CSRSparseMatrix gradient_matrix;
175 
176     OP_REQUIRES_OK(
177         ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
178                  DataTypeToEnum<T>::value, dense_shape_t,
179                  softmax_matrix->batch_pointers(),
180                  softmax_matrix->row_pointers(), softmax_matrix->col_indices(),
181                  *gradient_values_t, &gradient_matrix));
182 
183     if (total_nnz > 0) {
184       functor::CSRSparseMatrixSoftmaxGrad<Device, T> softmax_grad;
185       OP_REQUIRES_OK(ctx,
186                      softmax_grad(ctx, *softmax_matrix, *grad_softmax_matrix,
187                                   gradient_matrix.values().vec<T>()));
188     }
189 
190     Tensor gradient_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
191     gradient_t.scalar<Variant>()() = std::move(gradient_matrix);
192     ctx->set_output(0, gradient_t);
193   }
194 };
195 
196 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
197 #define REGISTER(DEV, T)                                  \
198   REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
199                               .Device(DEVICE_##DEV)       \
200                               .TypeConstraint<T>("type"), \
201                           CSRSoftmaxGradOp<DEV##Device, T>);
202 
203 REGISTER(GPU, float)
204 REGISTER(GPU, double)
205 
206 #undef REGISTER
207 
208 namespace functor {
209 #define DECLARE_GPU_SPEC(T)                                    \
210   template <>                                                  \
211   Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \
212       OpKernelContext* ctx, const CSRSparseMatrix& softmax,    \
213       const CSRSparseMatrix& grad_softmax,                     \
214       typename TTypes<T>::Vec gradient_values);                \
215   extern template struct CSRSparseMatrixSoftmaxGrad<GPUDevice, T>;
216 
217 DECLARE_GPU_SPEC(float);
218 DECLARE_GPU_SPEC(double);
219 
220 #undef DECLARE_GPU_SPEC
221 }  // namespace functor
222 
223 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
224 
225 }  // namespace tensorflow
226