• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/linalg_ops.cc.
17 //
18 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
19 #define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
20 
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/kernels/fill_functor.h"
27 #include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/matmul_bcast.h"
33 
34 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
35 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
36 #include "tensorflow/core/kernels/transpose_functor.h"
37 #include "tensorflow/core/platform/stream_executor.h"
38 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
39 
40 #if GOOGLE_CUDA
41 #include "tensorflow/core/util/cuda_solvers.h"
42 #elif TENSORFLOW_USE_ROCM
43 #include "tensorflow/core/util/rocm_solvers.h"
44 #endif
45 
46 namespace tensorflow {
47 
48 typedef Eigen::ThreadPoolDevice CPUDevice;
49 typedef Eigen::GpuDevice GPUDevice;
50 
51 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
52 template <typename Scalar>
AsDeviceMemory(const Scalar * gpu_memory)53 se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
54   se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
55   se::DeviceMemory<Scalar> typed(wrapped);
56   return typed;
57 }
58 
59 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
60 
61 // Sequential batch matrix triangular solve kernel that calls Eigen's
62 // matrix triangular solve.
63 template <typename Scalar>
64 struct SequentialMatrixTriangularSolveKernel {
65   using Matrix =
66       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
67   using ConstMatrixMap = Eigen::Map<const Matrix>;
68   using MatrixMap = Eigen::Map<Matrix>;
69   using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
70 
ConstTensorSliceToEigenMatrixSequentialMatrixTriangularSolveKernel71   static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
72                                                       int slice) {
73     return ConstMatrixMap(
74         t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
75         t.dim_size(1), t.dim_size(2));
76   }
77 
TensorSliceToEigenMatrixSequentialMatrixTriangularSolveKernel78   static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
79     return MatrixMap(
80         t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
81         t->dim_size(1), t->dim_size(2));
82   }
83 
RunSequentialMatrixTriangularSolveKernel84   static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
85                   bool adjoint, const MatMulBCast& bcast, Tensor* out,
86                   int start, int limit) {
87     const bool should_bcast = bcast.IsBroadcastingRequired();
88     const auto& x_batch_indices = bcast.x_batch_indices();
89     const auto& y_batch_indices = bcast.y_batch_indices();
90     for (int64 i = start; i < limit; ++i) {
91       const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
92       const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
93       auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
94       auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
95       auto output = TensorSliceToEigenMatrix(out, i);
96       if (lower) {
97         auto triangle = matrix.template triangularView<Eigen::Lower>();
98         if (adjoint) {
99           output.noalias() = triangle.adjoint().solve(rhs);
100         } else {
101           output.noalias() = triangle.solve(rhs);
102         }
103       } else {
104         auto triangle = matrix.template triangularView<Eigen::Upper>();
105         if (adjoint) {
106           output.noalias() = triangle.adjoint().solve(rhs);
107         } else {
108           output.noalias() = triangle.solve(rhs);
109         }
110       }
111     }
112   }
113 };
114 
115 template <typename Device, typename Scalar>
116 struct LaunchBatchMatrixTriangularSolve;
117 
118 template <typename Scalar>
119 struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
120   static void Launch(OpKernelContext* context, const Tensor& in_x,
121                      const Tensor& in_y, bool adjoint, bool lower,
122                      const MatMulBCast& bcast, Tensor* out) {
123     // Number of matrix triangular solves i.e. size of the batch.
124     const int64 batch_size = bcast.output_batch_size();
125     const int64 cost_per_unit =
126         in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
127     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
128 
129     using Matrix =
130         Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
131     using ConstMatrixMap = Eigen::Map<const Matrix>;
132     using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
133     // Check diagonal before doing any solves.
134     auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
135                                  in_x.dim_size(2));
136     const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
137     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
138                 errors::InvalidArgument("Input matrix is not invertible."));
139 
140     Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
141           cost_per_unit,
142           [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
143             SequentialMatrixTriangularSolveKernel<Scalar>::Run(
144                 in_x, in_y, lower, adjoint, bcast, out, start, limit);
145           });
146   }
147 };
148 
149 template <typename Device, typename Scalar>
150 class BaseMatrixTriangularSolveOp : public OpKernel {
151  public:
152   explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
153       : OpKernel(context) {
154     OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
155     OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
156   }
157 
158   ~BaseMatrixTriangularSolveOp() override {}
159 
160   void Compute(OpKernelContext* ctx) override {
161     const Tensor& in0 = ctx->input(0);
162     const Tensor& in1 = ctx->input(1);
163 
164     ValidateInputTensors(ctx, in0, in1);
165 
166     MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
167     OP_REQUIRES(
168         ctx, bcast.IsValid(),
169         errors::InvalidArgument(
170             "In[0] and In[1] must have compatible batch dimensions: ",
171             in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
172 
173     TensorShape out_shape = bcast.output_batch_shape();
174     auto batch_size = bcast.output_batch_size();
175     auto d0 = in0.dim_size(in0.dims() - 2);
176     auto d1 = in0.dim_size(in0.dims() - 1);
177     Tensor in0_reshaped;
178     OP_REQUIRES(
179         ctx,
180         in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
181         errors::Internal("Failed to reshape In[0] from ",
182                          in0.shape().DebugString()));
183     auto d2 = in1.dim_size(in1.dims() - 2);
184     auto d3 = in1.dim_size(in1.dims() - 1);
185     Tensor in1_reshaped;
186     OP_REQUIRES(
187         ctx,
188         in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
189         errors::Internal("Failed to reshape In[1] from ",
190                          in1.shape().DebugString()));
191     if (adjoint_) std::swap(d0, d1);
192     OP_REQUIRES(ctx, d1 == d2,
193                 errors::InvalidArgument(
194                     "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
195                     in0.shape().DebugString(), " ", in1.shape().DebugString(),
196                     " ", lower_, " ", adjoint_));
197     out_shape.AddDim(d0);
198     out_shape.AddDim(d3);
199     Tensor* out = nullptr;
200     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
201     if (out->NumElements() == 0) {
202       return;
203     }
204     Tensor out_reshaped;
205     OP_REQUIRES(ctx,
206                 out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
207                 errors::Internal("Failed to reshape output from ",
208                                  out->shape().DebugString()));
209     LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
210         ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
211         &out_reshaped);
212   }
213 
214  private:
215   virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
216                                     const Tensor& in1) = 0;
217   bool lower_;
218   bool adjoint_;
219 };
220 
221 template <class Device, class Scalar>
222 class MatrixTriangularSolveOp
223     : public BaseMatrixTriangularSolveOp<Device, Scalar> {
224  public:
225   explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
226       : BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
227 
228   ~MatrixTriangularSolveOp() override {}
229 
230  private:
231   void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
232                             const Tensor& in1) override {
233     OP_REQUIRES(
234         ctx, in0.dims() >= 2,
235         errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
236 
237     OP_REQUIRES(
238         ctx, in1.dims() >= 2,
239         errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
240   }
241 };
242 
// Registers MatrixTriangularSolveOp<CPUDevice, TYPE> as the CPU kernel for
// element type TYPE under both op names "MatrixTriangularSolve" and
// "BatchMatrixTriangularSolve".
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .TypeConstraint<TYPE>("T")             \
                              .Device(DEVICE_CPU),                   \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .TypeConstraint<TYPE>("T")             \
                              .Device(DEVICE_CPU),                   \
                          MatrixTriangularSolveOp<CPUDevice, TYPE>);
252 
253 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
254 
255 template <typename Scalar>
256 struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
257   static void Launch(OpKernelContext* context, const Tensor& in_x,
258                      const Tensor& in_y, bool adjoint, bool lower,
259                      const MatMulBCast& bcast, Tensor* out) {
260     auto* stream = context->op_device_context()->stream();
261 
262     const uint64 m = in_x.dim_size(1);
263     const uint64 n = out->dim_size(2);
264 
265     //  Do a memcpy when we don't need to broadcast.
266     if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
267       auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
268       auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
269       OP_REQUIRES(
270           context,
271           stream
272               ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
273                               bcast.y_batch_size() * m * n * sizeof(Scalar))
274               .ok(),
275           errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
276                            "from device"));
277     } else {
278       std::vector<Scalar*> out_ptrs;
279       std::vector<const Scalar*> b_tmp_ptrs;
280       auto* b_base_ptr = in_y.template flat<Scalar>().data();
281       const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
282       for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
283         b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
284       }
285       for (int64 i = 0; i < bcast.output_batch_size(); ++i) {
286         auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
287         auto dst_device_mem =
288             AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
289         OP_REQUIRES(
290             context,
291             stream
292                 ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
293                                 m * n * sizeof(Scalar))
294                 .ok(),
295             errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
296                              "from device"));
297       }
298     }
299 
300     if (out->NumElements() == 0) {
301       return;
302     }
303 
304 #if GOOGLE_CUDA
305 
306     cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
307     cublasFillMode_t uplo;
308     cublasOperation_t trans;
309     cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
310 
311     // Cublas does
312     // output = matrix \ rhs
313     // where matrix, rhs and output are assumed to be in column major.
314     // We want the output to be in row-major, so we can compute
315     // output' = rhs' / matrix' (' stands for transpose)
316     // Upper/lower needs to be swapped for this.
317 
318     uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
319     trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
320     auto solver = absl::make_unique<CudaSolver>(context);
321 
322 #elif TENSORFLOW_USE_ROCM
323     rocblas_side side = rocblas_side_right;
324     rocblas_fill uplo;
325     rocblas_operation trans;
326     rocblas_diagonal diag = rocblas_diagonal_non_unit;
327 
328     // rocblas does
329     // output = matrix \ rhs
330     // where matrix, rhs and output are assumed to be in column major.
331     // We want the output to be in row-major, so we can compute
332     // output' = rhs' / matrix' (' stands for transpose)
333     // Upper/lower needs to be swapped for this.
334 
335     uplo = lower ? rocblas_fill_upper : rocblas_fill_lower;
336     trans = adjoint ? rocblas_operation_conjugate_transpose
337                     : rocblas_operation_none;
338     auto solver = absl::make_unique<ROCmSolver>(context);
339 
340 #endif
341 
342     const uint64 leading_dim_matrix = m;
343     const uint64 leading_dim_output = n;
344     const uint64 colmajor_rows = n;
345     const uint64 colmajor_cols = m;
346 
347     const int64 batch_size = bcast.output_batch_size();
348     std::vector<const Scalar*> a_ptrs;
349     std::vector<Scalar*> out_ptrs;
350     std::vector<const Scalar*> a_tmp_ptrs;
351     a_ptrs.reserve(batch_size);
352     out_ptrs.reserve(batch_size);
353     a_tmp_ptrs.reserve(bcast.x_batch_size());
354     auto* a_base_ptr = in_x.template flat<Scalar>().data();
355     auto* out_base_ptr = out->template flat<Scalar>().data();
356 
357     if (!bcast.IsBroadcastingRequired()) {
358       for (int64 i = 0; i < batch_size; ++i) {
359         a_ptrs.push_back(a_base_ptr + i * m * m);
360         out_ptrs.push_back(out_base_ptr + i * m * n);
361       }
362     } else {
363       const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
364       for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
365         a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
366       }
367       for (int64 i = 0; i < batch_size; ++i) {
368         a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
369         out_ptrs.push_back(out_base_ptr + i * m * n);
370       }
371     }
372 
373     typedef Scalar Coefficient;
374     const Scalar alpha = Scalar(1.0);
375 
376 #if GOOGLE_CUDA
377 
378     // TODO(b/146763573): Consider using Trsv here when the right hand side is
379     // a vector. This will require an explicit transpose since Trsv assumes
380     // CUBLAS_SIDE_LEFT.
381     if (batch_size == 1) {
382       OP_REQUIRES_OK(
383           context,
384           solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
385                        &alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
386                        out_ptrs[0], leading_dim_output /*ldb*/));
387     } else {
388       // Heuristic for choosing between batched interface vs. non-batched
389       // interface. This is inspired by matrix_solve_op and can probably be
390       // tuned.
391       // TODO(b/146763573): Tune this heuristic.
392       const int kMaxMatrixSizeToBatchSizeRatio = 128;
393       const bool use_batched_solver =
394           m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
395       if (use_batched_solver) {
396         OP_REQUIRES_OK(
397             context, solver->TrsmBatched(
398                          side, uplo, trans, diag, colmajor_rows, colmajor_cols,
399                          &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
400                          &out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
401       } else {
402         for (int batch = 0; batch < batch_size; ++batch) {
403           OP_REQUIRES_OK(
404               context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
405                                     colmajor_cols, &alpha, a_ptrs[batch],
406                                     leading_dim_matrix /*lda*/, out_ptrs[batch],
407                                     leading_dim_output /*ldb*/));
408         }
409       }
410     }
411 #elif TENSORFLOW_USE_ROCM
412     for (int batch = 0; batch < batch_size; ++batch) {
413       OP_REQUIRES_OK(
414           context,
415           solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
416                        &alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
417                        out_ptrs[batch], leading_dim_output /*ldb*/));
418     }
419 #endif
420   }
421 };
422 
// Registers MatrixTriangularSolveOp<GPUDevice, TYPE> as the GPU kernel for
// element type TYPE under both op names "MatrixTriangularSolve" and
// "BatchMatrixTriangularSolve".
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE)             \
  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
                              .TypeConstraint<TYPE>("T")             \
                              .Device(DEVICE_GPU),                   \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
                              .TypeConstraint<TYPE>("T")             \
                              .Device(DEVICE_GPU),                   \
                          MatrixTriangularSolveOp<GPUDevice, TYPE>);
432 
433 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
434 
435 }  // namespace tensorflow
436 
437 #endif  // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
438