/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#ifdef GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

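// The kernel below operates on the compact [3, m] diagonals representation,
// stored row-major as {superdiag | diag | subdiag}, the same layout used by
// ComputeMatrix further down. A worked 2x2 example, derived from the indexing
// in the m == 2 branch (u, d, l, b are shorthand here, not identifiers used
// elsewhere in this file): the six entries are
//   diags = {u[0], u[1], d[0], d[1], l[0], l[1]},
// the system solved is
//   [ d[0]  u[0] ]   i.e.  [ diags[2]  diags[0] ]
//   [ l[1]  d[1] ]         [ diags[5]  diags[3] ],
// and Cramer's rule gives
//   det  = diags[2] * diags[3] - diags[0] * diags[5],
//   x[0] = (d[1] * b[0] - u[0] * b[1]) / det,
//   x[1] = (d[0] * b[1] - l[1] * b[0]) / det,
// which is what the kernel computes for each right-hand side column.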
template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m,
                                           const Scalar* __restrict__ diags,
                                           const Scalar* __restrict__ rhs,
                                           const int num_rhs,
                                           Scalar* __restrict__ x,
                                           bool* __restrict__ not_invertible) {
  if (m == 1) {
    if (diags[1] == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = rhs[i] / diags[1];
    }
  } else {
    Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
    if (det == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
      x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
    }
  }
}

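// Wraps a raw device pointer in a StreamExecutor DeviceMemory handle so that
// it can be passed to Stream::ThenMemcpyD2D in CopyDeviceToDevice below.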
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

template <typename Scalar>
void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
                        Scalar* dst, const int num_elements) {
  auto src_device_mem = AsDeviceMemory(src);
  auto dst_device_mem = AsDeviceMemory(dst);
  auto* stream = context->op_device_context()->stream();
  bool copy_status = stream
                         ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                         sizeof(Scalar) * num_elements)
                         .ok();

  if (!copy_status) {
    context->SetStatus(errors::Internal("Copying device-to-device failed."));
  }
}

// This implementation is used in cases when the batching mechanism of
// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
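// Concretely, TridiagonalSolveOpGpu::Compute dispatches here when partial
// pivoting is requested, when there is more than one right-hand side, when
// the matrix has fewer than three rows, or when the batch size is one;
// otherwise it calls the strided-batched gtsv routine directly.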
template <class Scalar>
class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
      : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_rows1 = input_matrix_shapes[0].dim_size(1);
    auto num_rows2 = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(context, num_rows1 == num_rows2,
                errors::InvalidArgument("Expected same number of rows in both "
                                        "arguments, got ",
                                        num_rows1, " and ", num_rows2, "."));
  }

  bool EnableInputForwarding() const final { return false; }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];
    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];
    MatrixMap& x = outputs->at(0);
    const int m = diag.size();
    const int k = rhs.cols();

    if (m == 0) {
      return;
    }
    if (m < 3) {
      // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3.
      SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
                           k);
      return;
    }
    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
    if (k == 1) {
      // rhs is copied into x, then gtsv replaces x with solution.
      CopyDeviceToDevice(context, rhs.data(), x.data(), m);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), x.data(), m, 1);
    } else {
      // Gtsv expects rhs in column-major form, so we have to transpose.
      // rhs is transposed into temp, gtsv replaces temp with solution, then
      // temp is transposed into x.
      std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
      Tensor temp;
      TensorShape temp_shape({k, m});
      OP_REQUIRES_OK(context,
                     cublas_solver->allocate_scoped_tensor(
                         DataTypeToEnum<Scalar>::value, temp_shape, &temp));
      TransposeWithGeam(context, cublas_solver, rhs.data(),
                        temp.flat<Scalar>().data(), m, k);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), temp.flat<Scalar>().data(), m, k);
      TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
                        x.data(), k, m);
    }
  }

 private:
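  // Uses cuBLAS geam with alpha = 1, op(A) = A^T, beta = 0 to write the
  // transpose of a row-major src_rows x src_cols matrix into dst. This is how
  // the row-major right-hand sides above are converted to the column-major
  // layout gtsv expects, and back.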
  void TransposeWithGeam(OpKernelContext* context,
                         const std::unique_ptr<CudaSolver>& cublas_solver,
                         const Scalar* src, Scalar* dst, const int src_rows,
                         const int src_cols) const {
    const Scalar zero(0), one(1);
    OP_REQUIRES_OK(context,
                   cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
                                       src_cols, &one, src, src_cols, &zero,
                                       static_cast<const Scalar*>(nullptr),
                                       src_rows, dst, src_rows));
  }

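  // Dispatches to the cuSPARSE tridiagonal solver, honoring the
  // partial_pivoting attribute. For CUDA < 9.0 this uses the legacy
  // gtsv/gtsvNoPivot interface; for CUDA >= 9.0 it uses the gtsv2 interface,
  // which requires querying the workspace size first and passing an
  // explicitly allocated scratch buffer.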
  void SolveWithGtsv(OpKernelContext* context,
                     std::unique_ptr<GpuSparse>& cusparse_solver,
                     const Scalar* superdiag, const Scalar* diag,
                     const Scalar* subdiag, Scalar* rhs, const int num_eqs,
                     const int num_rhs) const {
#if CUDA_VERSION < 9000
    auto function =
        pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
    OP_REQUIRES_OK(
        context, (cusparse_solver.get()->*function)(
                     num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
#else
    auto buffer_function = pivoting_
                               ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
                               : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
    size_t buffer_size;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();

    auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
                                     : &GpuSparse::Gtsv2NoPivot<Scalar>;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, buffer));
#endif  // CUDA_VERSION < 9000
  }

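  // Handles m == 1 and m == 2 directly on the GPU (cuSPARSE gtsv requires
  // m >= 3). Launches SolveForSizeOneOrTwoKernel, copies the singularity flag
  // back to the host, and reports InvalidArgument if the matrix was found to
  // be non-invertible.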
  void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                            const Scalar* rhs, Scalar* output, int m, int k) {
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
    bool* not_invertible_dev;
    cudaMalloc(&not_invertible_dev, sizeof(bool));
    // The kernel only writes the flag when it detects a singular matrix, so
    // clear the freshly allocated (uninitialized) device memory first.
    cudaMemset(not_invertible_dev, 0, sizeof(bool));
    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                cfg.block_count, cfg.thread_per_block, 0,
                                device.stream(), m, diagonals, rhs, k, output,
                                not_invertible_dev));
    bool not_invertible_host;
    cudaMemcpy(&not_invertible_host, not_invertible_dev, sizeof(bool),
               cudaMemcpyDeviceToHost);
    cudaFree(not_invertible_dev);
    OP_REQUIRES(context, !not_invertible_host,
                errors::InvalidArgument(m == 1 ? kNotInvertibleScalarMsg
                                               : kNotInvertibleMsg));
  }

  bool pivoting_;
};

template <class Scalar>
class TridiagonalSolveOpGpu : public OpKernel {
 public:
  explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
      : OpKernel(context), linalgOp_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void Compute(OpKernelContext* context) final {
    const Tensor& lhs = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = lhs.dims();
    const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
    const int64 matrix_size = lhs.dim_size(ndims - 1);
    int64 batch_size = 1;
    for (int i = 0; i < ndims - 2; i++) {
      batch_size *= lhs.dim_size(i);
    }

    // The batching mechanism of LinearAlgebraOp is used when it's not
    // possible or desirable to use GtsvBatched.
    const bool use_linalg_op =
        pivoting_            // GtsvBatched doesn't do pivoting
        || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
        || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
        || batch_size == 1;  // No point to use GtsvBatched

    if (use_linalg_op) {
      linalgOp_.Compute(context);
    } else {
      ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);

  void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                              const Tensor& rhs, const int batch_size) {
    const Scalar* rhs_data = rhs.flat<Scalar>().data();
    const int ndims = lhs.dims();

    // To use GtsvBatched we need to transpose the left-hand side from shape
    // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
    // between corresponding diagonal elements of consecutive batch components
    // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
    // GtsvBatched requires the strides to be the same. For this reason we
    // transpose into [3, ..., M], so that superdiagonals, diagonals, and
    // subdiagonals are separated from each other and have stride M.
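    // For example, for lhs of shape [batch_size, 3, M] the transposed tensor
    // has shape [3, batch_size, M]: superdiagonals start at offset 0,
    // diagonals at offset batch_size * M, and subdiagonals at offset
    // 2 * batch_size * M, which is how the three pointers below are computed.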
    Tensor lhs_transposed;
    TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
    int matrix_size = lhs.dim_size(ndims - 1);
    const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
    const Scalar* superdiag = lhs_data;
    const Scalar* diag = lhs_data + matrix_size * batch_size;
    const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;

    // Copy right-hand side into the output. GtsvBatched will replace it with
    // the solution.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
    CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
                       rhs.flat<Scalar>().size());
    Scalar* x = output->flat<Scalar>().data();

    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));

    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
#if CUDA_VERSION < 9000
    OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size));
#else
    size_t buffer_size;
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, buffer));
#endif  // CUDA_VERSION < 9000
  }

  void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                                  Tensor& lhs_transposed) {
    const int ndims = lhs.dims();

    // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
    // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
    std::vector<int> perm(ndims);
    perm[0] = ndims - 2;
    for (int i = 0; i < ndims - 2; ++i) {
      perm[i + 1] = i;
    }
    perm[ndims - 1] = ndims - 1;

    std::vector<int64> dims;
    for (int index : perm) {
      dims.push_back(lhs.dim_size(index));
    }
    TensorShape lhs_transposed_shape(
        gtl::ArraySlice<int64>(dims.data(), ndims));

    std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
    OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
                                DataTypeToEnum<Scalar>::value,
                                lhs_transposed_shape, &lhs_transposed));
    auto device = context->eigen_device<Eigen::GpuDevice>();
    OP_REQUIRES_OK(
        context,
        DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
                    &lhs_transposed));
  }

  TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
  bool pivoting_;
};

REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA