/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#ifdef GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

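// Solves systems of size 1 or 2 directly. `diags` points at the compact
// [3, m] diagonals block in row-major order (row 0: superdiagonal, row 1:
// main diagonal, row 2: subdiagonal); `rhs` and `x` are [m, num_rhs] blocks,
// also row-major. For m == 1 the only diagonal element is therefore diags[1],
// and for m == 2 the i-th system is
//
//   | diags[2]  diags[0] | | x[i]           |   | rhs[i]           |
//   | diags[5]  diags[3] | | x[i + num_rhs] | = | rhs[i + num_rhs] |
//
// which is solved below by Cramer's rule.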
template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m,
                                           const Scalar* __restrict__ diags,
                                           const Scalar* __restrict__ rhs,
                                           const int num_rhs,
                                           Scalar* __restrict__ x,
                                           bool* __restrict__ not_invertible) {
  if (m == 1) {
    if (diags[1] == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = rhs[i] / diags[1];
    }
  } else {
    Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
    if (det == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
      x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
    }
  }
}

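// Wraps a raw device pointer in a non-owning StreamExecutor DeviceMemory
// handle so that it can be passed to Stream::ThenMemcpyD2D below.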
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

template <typename Scalar>
void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
                        Scalar* dst, const int num_elements) {
  auto src_device_mem = AsDeviceMemory(src);
  auto dst_device_mem = AsDeviceMemory(dst);
  auto* stream = context->op_device_context()->stream();
  bool copy_status = stream
                         ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                         sizeof(Scalar) * num_elements)
                         .ok();

  if (!copy_status) {
    context->SetStatus(errors::Internal("Copying device-to-device failed."));
  }
}

// This implementation is used in cases when the batching mechanism of
// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
template <class Scalar>
class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
      : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_rows1 = input_matrix_shapes[0].dim_size(1);
    auto num_rows2 = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(context, num_rows1 == num_rows2,
                errors::InvalidArgument("Expected same number of rows in both "
                                        "arguments, got ",
                                        num_rows1, " and ", num_rows2, "."));
  }

  bool EnableInputForwarding() const final { return false; }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

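  // Dispatches on problem size: an empty system is a no-op, sizes 1 and 2 go
  // to the custom kernel above, and everything else is handled by cuSPARSE
  // gtsv. With a single right-hand side gtsv solves in place on a copy of
  // rhs; with multiple right-hand sides rhs is first transposed to
  // column-major form (see TransposeWithGeam), solved, and transposed back.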
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];
    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];
    MatrixMap& x = outputs->at(0);
    const int m = diag.size();
    const int k = rhs.cols();

    if (m == 0) {
      return;
    }
    if (m < 3) {
      // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3.
      SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
                           k);
      return;
    }
    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
    if (k == 1) {
      // rhs is copied into x, then gtsv replaces x with solution.
      CopyDeviceToDevice(context, rhs.data(), x.data(), m);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), x.data(), m, 1);
    } else {
      // Gtsv expects rhs in column-major form, so we have to transpose.
      // rhs is transposed into temp, gtsv replaces temp with solution, then
      // temp is transposed into x.
      std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
      Tensor temp;
      TensorShape temp_shape({k, m});
      OP_REQUIRES_OK(context,
                     cublas_solver->allocate_scoped_tensor(
                         DataTypeToEnum<Scalar>::value, temp_shape, &temp));
      TransposeWithGeam(context, cublas_solver, rhs.data(),
                        temp.flat<Scalar>().data(), m, k);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), temp.flat<Scalar>().data(), m, k);
      TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
                        x.data(), k, m);
    }
  }

 private:
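  // Out-of-place transpose of a row-major src_rows x src_cols matrix via
  // cuBLAS geam, which computes C = alpha * op(A) + beta * op(B); with
  // alpha = 1, op(A) = transpose and beta = 0 this reduces to a plain
  // transpose, so the B operand is passed as nullptr.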
  void TransposeWithGeam(OpKernelContext* context,
                         const std::unique_ptr<CudaSolver>& cublas_solver,
                         const Scalar* src, Scalar* dst, const int src_rows,
                         const int src_cols) const {
    const Scalar zero(0), one(1);
    OP_REQUIRES_OK(context,
                   cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
                                       src_cols, &one, src, src_cols, &zero,
                                       static_cast<const Scalar*>(nullptr),
                                       src_rows, dst, src_rows));
  }

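  // Solves num_eqs x num_eqs tridiagonal systems with num_rhs right-hand
  // sides stored column-major in rhs; the solution overwrites rhs. Before
  // CUDA 9 this calls gtsv / gtsvNoPivot directly; from CUDA 9 on it uses the
  // gtsv2 variants, which require querying and allocating an explicit scratch
  // buffer first.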
  void SolveWithGtsv(OpKernelContext* context,
                     std::unique_ptr<GpuSparse>& cusparse_solver,
                     const Scalar* superdiag, const Scalar* diag,
                     const Scalar* subdiag, Scalar* rhs, const int num_eqs,
                     const int num_rhs) const {
#if CUDA_VERSION < 9000
    auto function =
        pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
    OP_REQUIRES_OK(
        context, (cusparse_solver.get()->*function)(
                     num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
#else
    auto buffer_function = pivoting_
                               ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
                               : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
    size_t buffer_size;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();

    auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
                                     : &GpuSparse::Gtsv2NoPivot<Scalar>;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, buffer));
#endif  // CUDA_VERSION < 9000
  }

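  // Launches SolveForSizeOneOrTwoKernel and copies the invertibility flag
  // back to the host with a synchronous cudaMemcpy, so that a singular matrix
  // is reported through OP_REQUIRES before the op returns.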
  void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                            const Scalar* rhs, Scalar* output, int m, int k) {
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
    bool* not_invertible_dev;
    cudaMalloc(&not_invertible_dev, sizeof(bool));
    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                cfg.block_count, cfg.thread_per_block, 0,
                                device.stream(), m, diagonals, rhs, k, output,
                                not_invertible_dev));
    bool not_invertible_host;
    cudaMemcpy(&not_invertible_host, not_invertible_dev, sizeof(bool),
               cudaMemcpyDeviceToHost);
    cudaFree(not_invertible_dev);
    OP_REQUIRES(context, !not_invertible_host,
                errors::InvalidArgument(m == 1 ? kNotInvertibleScalarMsg
                                               : kNotInvertibleMsg));
  }

  bool pivoting_;
};

template <class Scalar>
class TridiagonalSolveOpGpu : public OpKernel {
 public:
  explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
      : OpKernel(context), linalgOp_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void Compute(OpKernelContext* context) final {
    const Tensor& lhs = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = lhs.dims();
    const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
    const int64 matrix_size = lhs.dim_size(ndims - 1);
    int64 batch_size = 1;
    for (int i = 0; i < ndims - 2; i++) {
      batch_size *= lhs.dim_size(i);
    }

    // The batching mechanism of LinearAlgebraOp is used when it's not
    // possible or desirable to use GtsvBatched.
    const bool use_linalg_op =
        pivoting_            // GtsvBatched doesn't do pivoting
        || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
        || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
        || batch_size == 1;  // No point in using GtsvBatched for a single batch

    if (use_linalg_op) {
      linalgOp_.Compute(context);
    } else {
      ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);

  void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                              const Tensor& rhs, const int batch_size) {
    const Scalar* rhs_data = rhs.flat<Scalar>().data();
    const int ndims = lhs.dims();

    // To use GtsvBatched we need to transpose the left-hand side from shape
    // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
    // between corresponding diagonal elements of consecutive batch components
    // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
    // GtsvBatched requires the strides to be the same. For this reason we
    // transpose into [3, ..., M], so that superdiagonals, diagonals, and
    // subdiagonals are separated from each other and all have stride M.
    Tensor lhs_transposed;
    TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
    int matrix_size = lhs.dim_size(ndims - 1);
    const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
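    // After transposition to [3, ..., M], each of the three diagonals
    // occupies a contiguous block of batch_size * matrix_size elements,
    // superdiagonals first.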
    const Scalar* superdiag = lhs_data;
    const Scalar* diag = lhs_data + matrix_size * batch_size;
    const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;

    // Copy right-hand side into the output. GtsvBatched will replace it with
    // the solution.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
    CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
                       rhs.flat<Scalar>().size());
    Scalar* x = output->flat<Scalar>().data();

    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));

    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
#if CUDA_VERSION < 9000
    OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size));
#else
    size_t buffer_size;
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, buffer));
#endif  // CUDA_VERSION < 9000
  }

  void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                                  Tensor& lhs_transposed) {
    const int ndims = lhs.dims();

    // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
    // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
    std::vector<int> perm(ndims);
    perm[0] = ndims - 2;
    for (int i = 0; i < ndims - 2; ++i) {
      perm[i + 1] = i;
    }
    perm[ndims - 1] = ndims - 1;

    std::vector<int64> dims;
    for (int index : perm) {
      dims.push_back(lhs.dim_size(index));
    }
    TensorShape lhs_transposed_shape(
        gtl::ArraySlice<int64>(dims.data(), ndims));

    std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
    OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
                                DataTypeToEnum<Scalar>::value,
                                lhs_transposed_shape, &lhs_transposed));
    auto device = context->eigen_device<Eigen::GpuDevice>();
    OP_REQUIRES_OK(
        context,
        DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
                    &lhs_transposed));
  }

  TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
  bool pivoting_;
};

REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA