/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/rocm_solvers.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
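// Wraps a raw device pointer in a typed StreamExecutor DeviceMemory handle so
// it can be passed to the memcpy and triangular-solve calls below. The handle
// does not own the underlying memory.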
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialMatrixTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

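  // Maps the `slice`-th matrix of a [batch, rows, cols] tensor as an Eigen
  // matrix without copying.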
  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

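  // For each batch index i in [start, limit), solves
  //   matrix[x_batch_index] * out[i] = rhs[y_batch_index]
  // (or the adjoint system when `adjoint` is true), treating the matrix as
  // lower or upper triangular according to `lower`.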
  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      if (lower) {
        auto triangle = matrix.template triangularView<Eigen::Lower>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      } else {
        auto triangle = matrix.template triangularView<Eigen::Upper>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      }
    }
  }
};

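// Launches a batch of triangular solves on the given device; specialized
// below for CPU (Eigen) and GPU (cuBLAS/rocBLAS).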
template <typename Device, typename Scalar>
struct LaunchBatchMatrixTriangularSolve;

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of matrix triangular solves, i.e., the size of the batch.
    const int64 batch_size = bcast.output_batch_size();
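    // Rough flop count per batch entry: a triangular solve with an m x m
    // matrix and n right-hand sides takes about m * m * n / 2 multiply-adds.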
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check diagonal before doing any solves.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
            SequentialMatrixTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

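// Validates and broadcasts the two inputs, reshapes them to rank-3
// [batch, rows, cols] tensors, and dispatches to the device-specific
// LaunchBatchMatrixTriangularSolve functor.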
template <typename Device, typename Scalar>
class BaseMatrixTriangularSolveOp : public OpKernel {
 public:
  explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BaseMatrixTriangularSolveOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

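    // Flatten all leading batch dimensions into a single batch dimension so
    // the launch functor only has to iterate over one batch index.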
    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
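    // Solving op(In[0]) * X = In[1] requires the column count of op(In[0]) to
    // match the row count of In[1]; taking the adjoint swaps In[0]'s row and
    // column counts.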
    if (adjoint_) std::swap(d0, d1);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                    const Tensor& in1) = 0;
  bool lower_;
  bool adjoint_;
};

template <class Device, class Scalar>
class MatrixTriangularSolveOp
    : public BaseMatrixTriangularSolveOp<Device, Scalar> {
 public:
  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}

  ~MatrixTriangularSolveOp() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
};

#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    auto* stream = context->op_device_context()->stream();

    const uint64 m = in_x.dim_size(1);
    const uint64 n = out->dim_size(2);

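    // Trsm solves the triangular system in place, overwriting its right-hand
    // side, so the rhs is first copied into the output buffer (broadcasting
    // batch entries as needed).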
    // Do a memcpy when we don't need to broadcast.
    if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
      auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
      auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
      OP_REQUIRES(
          context,
          stream
              ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                              bcast.y_batch_size() * m * n * sizeof(Scalar))
              .ok(),
          errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                           "from device"));
    } else {
      std::vector<Scalar*> out_ptrs;
      std::vector<const Scalar*> b_tmp_ptrs;
      auto* b_base_ptr = in_y.template flat<Scalar>().data();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
        b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
      }
      for (int64 i = 0; i < bcast.output_batch_size(); ++i) {
        auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
        auto dst_device_mem =
            AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
        OP_REQUIRES(
            context,
            stream
                ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                m * n * sizeof(Scalar))
                .ok(),
            errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                             "from device"));
      }
    }

    if (out->NumElements() == 0) {
      return;
    }

#if GOOGLE_CUDA

    cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
    cublasFillMode_t uplo;
    cublasOperation_t trans;
    cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;

    // cuBLAS computes
    //   output = matrix \ rhs
    // where matrix, rhs, and output are assumed to be column-major.
    // We want the output in row-major, so instead we compute
    //   output' = rhs' / matrix'   (' denotes transpose),
    // which requires swapping upper/lower and solving from the right.

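    // For example, with `lower` set and no adjoint we want L * X = B in
    // row-major storage. Reinterpreted as column-major, the buffers hold L'
    // and B', and L * X = B is equivalent to X' * L' = B', i.e. a right-sided
    // solve against the upper-triangular matrix L'; hence
    // side = CUBLAS_SIDE_RIGHT and the swapped fill mode below.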
    uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
    trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
    auto solver = absl::make_unique<CudaSolver>(context);

#elif TENSORFLOW_USE_ROCM
    rocblas_side side = rocblas_side_right;
    rocblas_fill uplo;
    rocblas_operation trans;
    rocblas_diagonal diag = rocblas_diagonal_non_unit;

    // rocBLAS computes
    //   output = matrix \ rhs
    // where matrix, rhs, and output are assumed to be column-major.
    // We want the output in row-major, so instead we compute
    //   output' = rhs' / matrix'   (' denotes transpose),
    // which requires swapping upper/lower and solving from the right.

    uplo = lower ? rocblas_fill_upper : rocblas_fill_lower;
    trans = adjoint ? rocblas_operation_conjugate_transpose
                    : rocblas_operation_none;
    auto solver = absl::make_unique<ROCmSolver>(context);

#endif

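    // Viewed column-major, the right-hand-side/output block is n x m, so its
    // leading dimension is n, while the m x m triangular matrix has leading
    // dimension m.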
    const uint64 leading_dim_matrix = m;
    const uint64 leading_dim_output = n;
    const uint64 colmajor_rows = n;
    const uint64 colmajor_cols = m;

    const int64 batch_size = bcast.output_batch_size();
    std::vector<const Scalar*> a_ptrs;
    std::vector<Scalar*> out_ptrs;
    std::vector<const Scalar*> a_tmp_ptrs;
    a_ptrs.reserve(batch_size);
    out_ptrs.reserve(batch_size);
    a_tmp_ptrs.reserve(bcast.x_batch_size());
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* out_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64 i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_base_ptr + i * m * m);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
        a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
      }
      for (int64 i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    }

    typedef Scalar Coefficient;
    const Scalar alpha = Scalar(1.0);

#if GOOGLE_CUDA

    // TODO(b/146763573): Consider using Trsv here when the right hand side is
    // a vector. This will require an explicit transpose since Trsv assumes
    // CUBLAS_SIDE_LEFT.
    if (batch_size == 1) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
                       out_ptrs[0], leading_dim_output /*ldb*/));
    } else {
      // Heuristic for choosing between batched interface vs. non-batched
      // interface. This is inspired by matrix_solve_op and can probably be
      // tuned.
      // TODO(b/146763573): Tune this heuristic.
      const int kMaxMatrixSizeToBatchSizeRatio = 128;
      const bool use_batched_solver =
          m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
      if (use_batched_solver) {
        OP_REQUIRES_OK(
            context, solver->TrsmBatched(
                         side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                         &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
                         &out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
      } else {
        for (int batch = 0; batch < batch_size; ++batch) {
          OP_REQUIRES_OK(
              context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
                                    colmajor_cols, &alpha, a_ptrs[batch],
                                    leading_dim_matrix /*lda*/, out_ptrs[batch],
                                    leading_dim_output /*ldb*/));
        }
      }
    }
#elif TENSORFLOW_USE_ROCM
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
                       out_ptrs[batch], leading_dim_output /*ldb*/));
    }
#endif
  }
};

#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_