/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA
namespace {
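// Wraps a raw device pointer in a StreamExecutor DeviceMemory<Scalar> handle
// so it can be passed to the Memcpy and BLAS calls below.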
template <typename Scalar>
perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
    const Scalar* cuda_memory) {
  perftools::gputools::DeviceMemoryBase wrapped(
      const_cast<Scalar*>(cuda_memory));
  perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}
}  // namespace
#endif  // GOOGLE_CUDA

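// Solves A * X = B (or adjoint(A) * X = B when the "adjoint" attribute is
// set), where A is a lower- or upper-triangular matrix, using Eigen's
// triangular solver.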
template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

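  // The solution has the same number of rows as the (square) coefficient
  // matrix and the same number of columns as the right-hand side.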
  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

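  // A triangular solve by back substitution performs roughly
  // n^2 * num_rhs multiply-add pairs for an n x n system; the estimate is
  // clamped to kint64max to avoid overflow.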
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }
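    // A triangular matrix is singular iff one of its diagonal entries is
    // zero, so the smallest absolute diagonal entry detects exact singularity.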
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
    if (lower_) {
      auto triangle = matrix.template triangularView<Eigen::Lower>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    } else {
      auto triangle = matrix.template triangularView<Eigen::Upper>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};

REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);

#ifdef GOOGLE_CUDA

// TODO(rmlarsen): Refactor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Avoid the Memcpy when buffer forwarding is used.
// 3. Copy the entire rhs in a single Memcpy when forwarding is not used.
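// Solves the same triangular system on the GPU: the right-hand side is copied
// into the output buffer and the in-place solve is delegated to cuBLAS trsm
// via StreamExecutor.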
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }

    auto matrix_ptr = AsDeviceMemory(matrix.data());
    auto rhs_ptr = AsDeviceMemory(rhs.data());
    auto out_ptr = AsDeviceMemory(output.data());

    auto* stream = context->op_device_context()->stream();
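    // cuBLAS trsm overwrites its right-hand-side argument with the solution,
    // so copy rhs into the output buffer on the device first.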
    uint64 rhs_elems = rhs.rows() * rhs.cols();
    bool copy_status =
        stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
            .ok();
    if (!copy_status) {
      context->SetStatus(
          errors::Internal("Failed to copy rhs into output before solve"));
    }

    // cuBLAS computes
    //   output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column-major order.
    // Our data is row-major, so we instead compute
    //   output' = rhs' / matrix'   (' stands for transpose)
    // which requires swapping upper/lower and putting the matrix on the right.
    perftools::gputools::blas::UpperLower upper_lower_matrix;
    perftools::gputools::blas::Transpose transpose_matrix;
    if (lower_) {
      upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper;
    } else {
      upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
    }
    if (adjoint_) {
      transpose_matrix =
          perftools::gputools::blas::Transpose::kConjugateTranspose;
    } else {
      transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
    }
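    // A row-major matrix reinterpreted as column-major is its transpose, so
    // the leading dimension of each column-major view is the row-major column
    // count (i.e. the row stride).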
    uint64 leading_dim_matrix = matrix.cols();
    uint64 leading_dim_output = output.cols();
    uint64 colmajor_rows = output.cols();
    uint64 colmajor_cols = output.rows();
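    // With Side::kRight, trsm solves X * op(matrix) = alpha * B in the
    // column-major view and overwrites out_ptr (currently holding rhs') with
    // the solution X.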
    bool blas_launch_status =
        stream
            ->ThenBlasTrsm(
                perftools::gputools::blas::Side::kRight /*side*/,
                upper_lower_matrix /*uplo*/, transpose_matrix /*trans*/,
                perftools::gputools::blas::Diagonal::kNonUnit /*diag*/,
                colmajor_rows /*m*/, colmajor_cols /*n*/, Scalar(1.0) /*alpha*/,
                matrix_ptr, leading_dim_matrix /*lda*/, &out_ptr,
                leading_dim_output /*ldb*/)
            .ok();
    if (!blas_launch_status) {
      context->SetStatus(errors::Internal("Blas TRSM launch failed"));
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};

REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow