/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/rocm_solvers.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
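
// This header implements the MatrixTriangularSolve op: for each (possibly
// broadcast) batch slice it solves matrix * output = rhs, where `matrix` is
// lower or upper triangular and may optionally be used in adjoint form.
// From Python the op is typically reached through tf.linalg.triangular_solve,
// e.g. (illustrative sketch, not part of this header):
//
//   # a: [..., M, M] triangular, b: [..., M, K]; batch dims broadcast.
//   x = tf.linalg.triangular_solve(a, b, lower=True, adjoint=False)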

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
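// Wraps a raw device pointer in a StreamExecutor DeviceMemory handle. The
// handle does not own the memory; the caller must keep the underlying
// allocation alive for as long as the handle is in use.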
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialMatrixTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

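  // Solves matrix * output = rhs (or matrix^H * output = rhs when `adjoint`
  // is true) for every batch slice in [start, limit), resolving broadcast
  // batch indices through `bcast`. Each slice is solved with Eigen's
  // triangularView solver, one after the other.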
  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      if (lower) {
        auto triangle = matrix.template triangularView<Eigen::Lower>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      } else {
        auto triangle = matrix.template triangularView<Eigen::Upper>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      }
    }
  }
};

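// Launches the batched triangular solves; one specialization per device type.
// `in_x` holds the batch of triangular matrices, `in_y` the right-hand sides
// (both already reshaped to rank 3 by the caller), and `out` receives the
// solutions.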
template <typename Device, typename Scalar>
struct LaunchBatchMatrixTriangularSolve;

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of matrix triangular solves, i.e. the batch size.
    const int64_t batch_size = bcast.output_batch_size();
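    // A triangular solve with an M x M matrix and an M x K right-hand side
    // costs roughly M^2 * K / 2 multiply-adds; use that as the per-batch cost
    // estimate for sharding.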
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check diagonal before doing any solves.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
            SequentialMatrixTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

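// Op kernel that validates the inputs, broadcast-reshapes them to rank-3
// tensors, allocates the output, and dispatches the actual solves to
// LaunchBatchMatrixTriangularSolve. Subclasses provide input validation via
// ValidateInputTensors().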
template <typename Device, typename Scalar>
class BaseMatrixTriangularSolveOp : public OpKernel {
 public:
  explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BaseMatrixTriangularSolveOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    if (!ctx->status().ok()) {
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adjoint_) std::swap(d0, d1);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                    const Tensor& in1) = 0;
  bool lower_;
  bool adjoint_;
};

template <class Device, class Scalar>
class MatrixTriangularSolveOp
    : public BaseMatrixTriangularSolveOp<Device, Scalar> {
 public:
  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}

  ~MatrixTriangularSolveOp() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    const auto in0_num_dims = in0.dims();
    OP_REQUIRES(
        ctx, in0_num_dims >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));

    const auto in1_num_dims = in1.dims();
    OP_REQUIRES(
        ctx, in1_num_dims >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));

    const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
    const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
    OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
                errors::InvalidArgument(
                    "In[0] matrices in the last dimensions must be square (",
                    in0_last_dim, " =/= ", in0_prev_dim, ")"));
  }
};

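// The registration macros below are not expanded in this header; they are
// intended to be instantiated from the per-dtype translation units that
// include it (the matrix_triangular_solve_op_*.cc files).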
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    auto* stream = context->op_device_context()->stream();

    const uint64 m = in_x.dim_size(1);
    const uint64 n = out->dim_size(2);

    // Do a memcpy when we don't need to broadcast.
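    // Trsm overwrites its right-hand side in place, so the rhs is first copied
    // into the output buffer (replicated per output batch when broadcasting),
    // and the solve below then runs directly on `out`.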
    if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
      auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
      auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
      OP_REQUIRES(
          context,
          stream
              ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                              bcast.y_batch_size() * m * n * sizeof(Scalar))
              .ok(),
          errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                           "from device"));
    } else {
      std::vector<Scalar*> out_ptrs;
      std::vector<const Scalar*> b_tmp_ptrs;
      auto* b_base_ptr = in_y.template flat<Scalar>().data();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
        b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
      }
      for (int64_t i = 0; i < bcast.output_batch_size(); ++i) {
        auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
        auto dst_device_mem =
            AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
        OP_REQUIRES(
            context,
            stream
                ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                m * n * sizeof(Scalar))
                .ok(),
            errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                             "from device"));
      }
    }

    if (out->NumElements() == 0) {
      return;
    }

#if GOOGLE_CUDA

    cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
    cublasFillMode_t uplo;
    cublasOperation_t trans;
    cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;

    // cuBLAS does
    //   output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   output' = rhs' / matrix'   (' stands for transpose)
    // Upper/lower needs to be swapped for this.
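    // Concretely: the buffers here are row-major, which cuBLAS reads as the
    // transposes of the intended operands. Solving
    //   output' = rhs' * inv(matrix')
    // is a right-sided Trsm on those transposed views (hence
    // CUBLAS_SIDE_RIGHT), and transposing a triangular matrix flips lower to
    // upper, hence the swap.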

    uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
    trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
    auto solver = absl::make_unique<CudaSolver>(context);

#elif TENSORFLOW_USE_ROCM
    rocblas_side side = rocblas_side_right;
    rocblas_fill uplo;
    rocblas_operation trans;
    rocblas_diagonal diag = rocblas_diagonal_non_unit;

    // rocBLAS does
    //   output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    //   output' = rhs' / matrix'   (' stands for transpose)
    // Upper/lower needs to be swapped for this.

    uplo = lower ? rocblas_fill_upper : rocblas_fill_lower;
    trans = adjoint ? rocblas_operation_conjugate_transpose
                    : rocblas_operation_none;
    auto solver = absl::make_unique<ROCmSolver>(context);

#endif

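    // Viewed as column-major, each row-major m x n output slice is addressed
    // as an n x m matrix, so the "colmajor" dimensions below are swapped.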
    const uint64 leading_dim_matrix = m;
    const uint64 leading_dim_output = n;
    const uint64 colmajor_rows = n;
    const uint64 colmajor_cols = m;

    const int64_t batch_size = bcast.output_batch_size();
    std::vector<const Scalar*> a_ptrs;
    std::vector<Scalar*> out_ptrs;
    std::vector<const Scalar*> a_tmp_ptrs;
    a_ptrs.reserve(batch_size);
    out_ptrs.reserve(batch_size);
    a_tmp_ptrs.reserve(bcast.x_batch_size());
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* out_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64_t i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_base_ptr + i * m * m);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
        a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
      }
      for (int64_t i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    }

    typedef Scalar Coefficient;
    const Scalar alpha = Scalar(1.0);

#if GOOGLE_CUDA

    // TODO(b/146763573): Consider using Trsv here when the right hand side is
    // a vector. This will require an explicit transpose since Trsv assumes
    // CUBLAS_SIDE_LEFT.
    if (batch_size == 1) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
                       out_ptrs[0], leading_dim_output /*ldb*/));
    } else {
      // Heuristic for choosing between batched interface vs. non-batched
      // interface. This is inspired by matrix_solve_op and can probably be
      // tuned.
      // TODO(b/146763573): Tune this heuristic.
      const int kMaxMatrixSizeToBatchSizeRatio = 128;
      const bool use_batched_solver =
          m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
      if (use_batched_solver) {
        OP_REQUIRES_OK(
            context, solver->TrsmBatched(
                         side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                         &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
                         &out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
      } else {
        for (int batch = 0; batch < batch_size; ++batch) {
          OP_REQUIRES_OK(
              context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
                                    colmajor_cols, &alpha, a_ptrs[batch],
                                    leading_dim_matrix /*lda*/, out_ptrs[batch],
                                    leading_dim_output /*ldb*/));
        }
      }
    }
#elif TENSORFLOW_USE_ROCM
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
                       out_ptrs[batch], leading_dim_output /*ldb*/));
    }
#endif
  }
};

#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_