/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#ifdef GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

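// Solves a 1x1 or 2x2 tridiagonal system directly on the GPU. `diags` holds
// the compact 3 x m representation (row 0: superdiagonal, row 1: main
// diagonal, row 2: subdiagonal, row-major), and `rhs`/`x` are m x num_rhs
// row-major. The 2x2 case is solved with Cramer's rule; if the single
// diagonal element (m == 1) or the determinant (m == 2) is zero,
// `not_invertible` is set instead of producing a solution.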
template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m,
                                           const Scalar* __restrict__ diags,
                                           const Scalar* __restrict__ rhs,
                                           const int num_rhs,
                                           Scalar* __restrict__ x,
                                           bool* __restrict__ not_invertible) {
  if (m == 1) {
    if (diags[1] == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = rhs[i] / diags[1];
    }
  } else {
    Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
    if (det == Scalar(0)) {
      *not_invertible = true;
      return;
    }
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
      x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
    }
  }
}

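// Wraps a raw device pointer in a StreamExecutor DeviceMemory handle so it
// can be passed to Stream::ThenMemcpyD2D below.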
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

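// Enqueues a device-to-device copy of `num_elements` elements of type Scalar
// on the op's compute stream, reporting an Internal error on failure.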
template <typename Scalar>
void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
                        Scalar* dst, const int num_elements) {
  auto src_device_mem = AsDeviceMemory(src);
  auto dst_device_mem = AsDeviceMemory(dst);
  auto* stream = context->op_device_context()->stream();
  bool copy_status = stream
                         ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                         sizeof(Scalar) * num_elements)
                         .ok();

  if (!copy_status) {
    context->SetStatus(errors::Internal("Copying device-to-device failed."));
  }
}

// This implementation is used in cases when the batching mechanism of
// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
template <class Scalar>
class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
      : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 columns, got ",
                                num_diags, " columns."));

    auto num_rows1 = input_matrix_shapes[0].dim_size(1);
    auto num_rows2 = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(context, num_rows1 == num_rows2,
                errors::InvalidArgument("Expected same number of rows in both "
                                        "arguments, got ",
                                        num_rows1, " and ", num_rows2, "."));
  }

  bool EnableInputForwarding() const final { return false; }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];
    // Superdiagonal elements, last is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements, first is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];
    MatrixMap& x = outputs->at(0);
    const int m = diag.size();
    const int k = rhs.cols();

    if (m == 0) {
      return;
    }
    if (m < 3) {
      // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3.
      SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
                           k);
      return;
    }
    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
    if (k == 1) {
      // rhs is copied into x, then gtsv replaces x with solution.
      CopyDeviceToDevice(context, rhs.data(), x.data(), m);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), x.data(), m, 1);
    } else {
      // Gtsv expects rhs in column-major form, so we have to transpose.
      // rhs is transposed into temp, gtsv replaces temp with solution, then
      // temp is transposed into x.
      std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
      Tensor temp;
      TensorShape temp_shape({k, m});
      OP_REQUIRES_OK(context,
                     cublas_solver->allocate_scoped_tensor(
                         DataTypeToEnum<Scalar>::value, temp_shape, &temp));
      TransposeWithGeam(context, cublas_solver, rhs.data(),
                        temp.flat<Scalar>().data(), m, k);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), temp.flat<Scalar>().data(), m, k);
      TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
                        x.data(), k, m);
    }
  }

 private:
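  // Transposes the row-major src_rows x src_cols matrix `src` into `dst`
  // using cuBLAS geam with alpha = 1 and beta = 0, i.e. dst = transpose(src).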
  void TransposeWithGeam(OpKernelContext* context,
                         const std::unique_ptr<CudaSolver>& cublas_solver,
                         const Scalar* src, Scalar* dst, const int src_rows,
                         const int src_cols) const {
    const Scalar zero(0), one(1);
    OP_REQUIRES_OK(context,
                   cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
                                       src_cols, &one, src, src_cols, &zero,
                                       static_cast<const Scalar*>(nullptr),
                                       src_rows, dst, src_rows));
  }

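  // Solves the tridiagonal system(s) in place with cusparse gtsv2 (with
  // partial pivoting) or gtsv2_nopivot, depending on the partial_pivoting
  // attribute, after allocating the scratch buffer the routine requires.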
  void SolveWithGtsv(OpKernelContext* context,
                     std::unique_ptr<GpuSparse>& cusparse_solver,
                     const Scalar* superdiag, const Scalar* diag,
                     const Scalar* subdiag, Scalar* rhs, const int num_eqs,
                     const int num_rhs) const {
    auto buffer_function = pivoting_
                               ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
                               : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
    size_t buffer_size;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag,
                                rhs, num_eqs, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();

    auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
                                     : &GpuSparse::Gtsv2NoPivot<Scalar>;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag,
                                rhs, num_eqs, buffer));
  }

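  // Handles m == 1 and m == 2 by launching SolveForSizeOneOrTwoKernel and
  // raising InvalidArgument if the kernel reports a singular matrix.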
  void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                            const Scalar* rhs, Scalar* output, int m, int k) {
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
    bool* not_invertible_dev;
    cudaMalloc(&not_invertible_dev, sizeof(bool));
    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                cfg.block_count, cfg.thread_per_block, 0,
                                device.stream(), m, diagonals, rhs, k, output,
                                not_invertible_dev));
    bool not_invertible_host;
    cudaMemcpy(&not_invertible_host, not_invertible_dev, sizeof(bool),
               cudaMemcpyDeviceToHost);
    cudaFree(not_invertible_dev);
    OP_REQUIRES(context, !not_invertible_host,
                errors::InvalidArgument(m == 1 ? kNotInvertibleScalarMsg
                                               : kNotInvertibleMsg));
  }

  bool pivoting_;
};

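// Top-level op kernel. It either delegates to the LinearAlgebraOp-based
// implementation above or, for batches of single-rhs systems without
// pivoting, uses the batched gtsv2StridedBatch path directly.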
template <class Scalar>
class TridiagonalSolveOpGpu : public OpKernel {
 public:
  explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
      : OpKernel(context), linalgOp_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void Compute(OpKernelContext* context) final {
    const Tensor& lhs = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = lhs.dims();
    const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
    const int64 matrix_size = lhs.dim_size(ndims - 1);
    int64 batch_size = 1;
    for (int i = 0; i < ndims - 2; i++) {
      batch_size *= lhs.dim_size(i);
    }

    // The batching mechanism of LinearAlgebraOp is used when it's not
    // possible or desirable to use GtsvBatched.
    const bool use_linalg_op =
        pivoting_            // GtsvBatched doesn't do pivoting
        || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
        || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
        || batch_size == 1;  // No point to use GtsvBatched

    if (use_linalg_op) {
      linalgOp_.Compute(context);
    } else {
      ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);

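  // Solves a batch of tridiagonal systems, each with a single right-hand
  // side, in one cusparse gtsv2StridedBatch call.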
  void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                              const Tensor& rhs, const int batch_size) {
    const Scalar* rhs_data = rhs.flat<Scalar>().data();
    const int ndims = lhs.dims();

    // To use GtsvBatched we need to transpose the left-hand side from shape
    // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
    // between corresponding diagonal elements of consecutive batch components
    // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
    // GtsvBatched requires the strides to be the same. For this reason we
    // transpose into [3, ..., M], so that diagonals, superdiagonals, and
    // subdiagonals are separated from each other, and have stride M.
    Tensor lhs_transposed;
    TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
    int matrix_size = lhs.dim_size(ndims - 1);
    const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
    const Scalar* superdiag = lhs_data;
    const Scalar* diag = lhs_data + matrix_size * batch_size;
    const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;

    // Copy right-hand side into the output. GtsvBatched will replace it with
    // the solution.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
    CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
                       rhs.flat<Scalar>().size());
    Scalar* x = output->flat<Scalar>().data();

    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));

    OP_REQUIRES_OK(context, cusparse_solver->Initialize());

    size_t buffer_size;
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, buffer));
  }

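  // Transposes lhs from shape [..., 3, M] into [3, ..., M] and writes the
  // result into lhs_transposed (see the layout discussion above).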
  void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                                  Tensor& lhs_transposed) {
    const int ndims = lhs.dims();

    // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
    // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
    std::vector<int> perm(ndims);
    perm[0] = ndims - 2;
    for (int i = 0; i < ndims - 2; ++i) {
      perm[i + 1] = i;
    }
    perm[ndims - 1] = ndims - 1;

    std::vector<int64> dims;
    for (int index : perm) {
      dims.push_back(lhs.dim_size(index));
    }
    TensorShape lhs_transposed_shape(
        gtl::ArraySlice<int64>(dims.data(), ndims));

    std::unique_ptr<CudaSolver> cublas_solver(new CudaSolver(context));
    OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
                                DataTypeToEnum<Scalar>::value,
                                lhs_transposed_shape, &lhs_transposed));
    auto device = context->eigen_device<Eigen::GpuDevice>();
    OP_REQUIRES_OK(
        context,
        DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
                    &lhs_transposed));
  }

  TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
  bool pivoting_;
};

REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA