/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/rocm_solvers.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}
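
// Illustrative usage of AsDeviceMemory (a sketch, not part of the kernel
// logic): it wraps raw device pointers obtained from Tensor::flat<T>().data()
// so they can be handed to StreamExecutor calls, as done further below, e.g.
//
//   auto dst = AsDeviceMemory(out->flat<Scalar>().data());
//   auto src = AsDeviceMemory(in.flat<Scalar>().data());
//   stream->ThenMemcpyD2D(&dst, src, num_bytes);  // hypothetical buffers.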

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialMatrixTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      if (lower) {
        auto triangle = matrix.template triangularView<Eigen::Lower>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      } else {
        auto triangle = matrix.template triangularView<Eigen::Upper>();
        if (adjoint) {
          output.noalias() = triangle.adjoint().solve(rhs);
        } else {
          output.noalias() = triangle.solve(rhs);
        }
      }
    }
  }
};
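
// A minimal standalone sketch of the per-slice solve performed above
// (illustrative only; A and B are hypothetical matrices): for a
// lower-triangular A, Eigen solves A * X = B, or A^H * X = B when adjoint
// is requested.
//
//   Eigen::MatrixXf A = Eigen::MatrixXf::Random(3, 3);
//   Eigen::MatrixXf B = Eigen::MatrixXf::Random(3, 2);
//   Eigen::MatrixXf X = A.triangularView<Eigen::Lower>().solve(B);
//   // X now satisfies A.triangularView<Eigen::Lower>() * X == B.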

template <typename Device, typename Scalar>
struct LaunchBatchMatrixTriangularSolve;

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of matrix triangular solves i.e. size of the batch.
    const int64_t batch_size = bcast.output_batch_size();
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
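    // cost_per_unit approximates the work of one triangular solve: roughly
    // n^2 / 2 multiply-adds per right-hand-side column, with n =
    // in_x.dim_size(1) and in_y.dim_size(2) columns. Shard() below uses it to
    // split the batch across the CPU worker thread pool.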
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check diagonal before doing any solves.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
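    // Note: the ConstMatrixMap above covers only the first batch slice of
    // in_x, so only that matrix's diagonal is validated here; a zero pivot in
    // a later slice is not rejected by this check.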

    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
            SequentialMatrixTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Device, typename Scalar>
class BaseMatrixTriangularSolveOp : public OpKernel {
 public:
  explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BaseMatrixTriangularSolveOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    if (!ctx->status().ok()) {
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adjoint_) std::swap(d0, d1);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                    const Tensor& in1) = 0;
  bool lower_;
  bool adjoint_;
};
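
// Worked shape example for Compute() above (illustrative, hypothetical
// shapes): with In[0] of shape [2, 1, 3, 3] and In[1] of shape [1, 5, 3, 4],
// the batch dimensions broadcast to [2, 5], so the inputs are reshaped to
// [2, 3, 3] and [5, 3, 4], the output is allocated with shape [2, 5, 3, 4],
// and the launcher runs 10 triangular solves on 3 x 3 / 3 x 4 slices.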

template <class Device, class Scalar>
class MatrixTriangularSolveOp
    : public BaseMatrixTriangularSolveOp<Device, Scalar> {
 public:
  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}

  ~MatrixTriangularSolveOp() override {}

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) override {
    const auto in0_num_dims = in0.dims();
    OP_REQUIRES(
        ctx, in0_num_dims >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));

    const auto in1_num_dims = in1.dims();
    OP_REQUIRES(
        ctx, in1_num_dims >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));

    const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
    const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
    OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
                errors::InvalidArgument(
                    "In[0] matrices in the last dimensions must be square (",
                    in0_last_dim, " =/= ", in0_prev_dim, ")"));
  }
};

#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>);
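
// Illustrative use of the registration macro above (a sketch; the concrete
// instantiations live in the op's .cc translation units, and the exact set of
// registered types is defined there, not here):
//
//   REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(float);
//   REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(double);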

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    auto* stream = context->op_device_context()->stream();

    const uint64 m = in_x.dim_size(1);
    const uint64 n = out->dim_size(2);

    // The Trsm calls below solve in place on their B argument, so the
    // right-hand side is first copied into the output buffer: a single
    // memcpy when we don't need to broadcast, otherwise one copy per slice.
    if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
      auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
      auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
      OP_REQUIRES(
          context,
          stream
              ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                              bcast.y_batch_size() * m * n * sizeof(Scalar))
              .ok(),
          errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                           "from device"));
    } else {
      std::vector<Scalar*> out_ptrs;
      std::vector<const Scalar*> b_tmp_ptrs;
      auto* b_base_ptr = in_y.template flat<Scalar>().data();
      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
      for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
        b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
      }
      for (int64_t i = 0; i < bcast.output_batch_size(); ++i) {
        auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
        auto dst_device_mem =
            AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
        OP_REQUIRES(
            context,
            stream
                ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                m * n * sizeof(Scalar))
                .ok(),
            errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
                             "from device"));
      }
    }

    if (out->NumElements() == 0) {
      return;
    }

#if GOOGLE_CUDA

    cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
    cublasFillMode_t uplo;
    cublasOperation_t trans;
    cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;

    // Cublas does
    // output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // output' = rhs' / matrix' (' stands for transpose)
    // Upper/lower needs to be swapped for this.
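    //
    // Concretely (an illustrative restatement of the identity above): the
    // row-major buffers for matrix, rhs and output are reinterpreted by
    // cuBLAS as matrix', rhs' and output' in column major. Solving
    // matrix * output = rhs is equivalent to solving
    // output' * matrix' = rhs', which is Trsm with side = CUBLAS_SIDE_RIGHT;
    // a lower-triangular matrix is seen as the upper-triangular matrix',
    // hence the swapped fill mode, and the row/column extents passed below
    // are those of the transposed (column-major) view.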

    uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
    trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
    auto solver = absl::make_unique<CudaSolver>(context);

#elif TENSORFLOW_USE_ROCM
    rocblas_side side = rocblas_side_right;
    rocblas_fill uplo;
    rocblas_operation trans;
    rocblas_diagonal diag = rocblas_diagonal_non_unit;

    // rocblas does
    // output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // output' = rhs' / matrix' (' stands for transpose)
    // Upper/lower needs to be swapped for this.

    uplo = lower ? rocblas_fill_upper : rocblas_fill_lower;
    trans = adjoint ? rocblas_operation_conjugate_transpose
                    : rocblas_operation_none;
    auto solver = absl::make_unique<ROCmSolver>(context);

#endif

    const uint64 leading_dim_matrix = m;
    const uint64 leading_dim_output = n;
    const uint64 colmajor_rows = n;
    const uint64 colmajor_cols = m;

    const int64_t batch_size = bcast.output_batch_size();
    std::vector<const Scalar*> a_ptrs;
    std::vector<Scalar*> out_ptrs;
    std::vector<const Scalar*> a_tmp_ptrs;
    a_ptrs.reserve(batch_size);
    out_ptrs.reserve(batch_size);
    a_tmp_ptrs.reserve(bcast.x_batch_size());
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* out_base_ptr = out->template flat<Scalar>().data();

    if (!bcast.IsBroadcastingRequired()) {
      for (int64_t i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_base_ptr + i * m * m);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    } else {
      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
      for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
        a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
      }
      for (int64_t i = 0; i < batch_size; ++i) {
        a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
        out_ptrs.push_back(out_base_ptr + i * m * n);
      }
    }

    typedef Scalar Coefficient;
    const Scalar alpha = Scalar(1.0);

#if GOOGLE_CUDA

    // TODO(b/146763573): Consider using Trsv here when the right hand side is
    // a vector. This will require an explicit transpose since Trsv assumes
    // CUBLAS_SIDE_LEFT.
    if (batch_size == 1) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
                       out_ptrs[0], leading_dim_output /*ldb*/));
    } else {
      // Heuristic for choosing between batched interface vs. non-batched
      // interface. This is inspired by matrix_solve_op and can probably be
      // tuned.
      // TODO(b/146763573): Tune this heuristic.
      const int kMaxMatrixSizeToBatchSizeRatio = 128;
      const bool use_batched_solver =
          m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
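      // For example (hypothetical sizes): with batch_size = 4 the batched
      // TrsmBatched path below is taken for matrices up to 512 x 512
      // (4 * 128); larger matrices fall back to one Trsm call per batch
      // member.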
      if (use_batched_solver) {
        OP_REQUIRES_OK(
            context, solver->TrsmBatched(
                         side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                         &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
                         &out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
      } else {
        for (int batch = 0; batch < batch_size; ++batch) {
          OP_REQUIRES_OK(
              context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
                                    colmajor_cols, &alpha, a_ptrs[batch],
                                    leading_dim_matrix /*lda*/, out_ptrs[batch],
                                    leading_dim_output /*ldb*/));
        }
      }
    }
#elif TENSORFLOW_USE_ROCM
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK(
          context,
          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
                       &alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
                       out_ptrs[batch], leading_dim_output /*ldb*/));
    }
#endif
  }
};

#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<TYPE>("T"),            \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_