/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA
namespace {
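// Wraps a raw device pointer in a typed StreamExecutor DeviceMemory handle so
// it can be passed to BLAS routines; the handle does not own the memory.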
template <typename Scalar>
perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
    const Scalar* cuda_memory) {
  perftools::gputools::DeviceMemoryBase wrapped(
      const_cast<Scalar*>(cuda_memory));
  perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}
}  // namespace
#endif  // GOOGLE_CUDA

template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

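  // A triangular solve with an n x n matrix and k right-hand sides performs
  // roughly n^2 * k / 2 multiply-add pairs; the estimate below uses n^2 * k.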
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }
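    // A triangular matrix is singular iff it has a zero on its diagonal, so
    // checking the smallest absolute diagonal entry detects non-invertible
    // inputs before solving.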
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
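    // Solve using Eigen's triangular view: triangle.solve(rhs) runs forward or
    // back substitution, and adjoint() solves with the conjugate transpose of
    // the selected triangle.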
    if (lower_) {
      auto triangle = matrix.template triangularView<Eigen::Lower>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    } else {
      auto triangle = matrix.template triangularView<Eigen::Upper>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};

REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);

#ifdef GOOGLE_CUDA

// TODO(rmlarsen): Re-factor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Save Memcpy when buffer forwarding is used.
// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }

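    // cuBLAS trsm solves in place, overwriting its right-hand side, so rhs is
    // first copied into the output buffer and the solve is run on that copy.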
    auto matrix_ptr = AsDeviceMemory(matrix.data());
    auto rhs_ptr = AsDeviceMemory(rhs.data());
    auto out_ptr = AsDeviceMemory(output.data());

    auto* stream = context->op_device_context()->stream();
    uint64 rhs_elems = rhs.rows() * rhs.cols();
    bool copy_status =
        stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
            .ok();
    if (!copy_status) {
      context->SetStatus(
          errors::Internal("Failed to copy rhs into output before solve"));
    }

    // cuBLAS solves
    //   output = matrix \ rhs,
    // with matrix, rhs and output assumed to be in column-major order. Our
    // buffers are row-major, so viewed as column-major they hold the
    // transposed matrices. We therefore compute
    //   output' = rhs' / matrix'   (' denotes transpose)
    // by solving from the right, which also requires swapping upper/lower for
    // the triangular factor.
    perftools::gputools::blas::UpperLower upper_lower_matrix;
    perftools::gputools::blas::Transpose transpose_matrix;
    if (lower_) {
      upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper;
    } else {
      upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
    }
    if (adjoint_) {
      transpose_matrix =
          perftools::gputools::blas::Transpose::kConjugateTranspose;
    } else {
      transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
    }
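    // Viewed as column-major, the row-major [n, k] output is a k x n matrix,
    // so its leading dimension is output.cols(); likewise the square matrix's
    // leading dimension is matrix.cols().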
    uint64 leading_dim_matrix = matrix.cols();
    uint64 leading_dim_output = output.cols();
    uint64 colmajor_rows = output.cols();
    uint64 colmajor_cols = output.rows();
    bool blas_launch_status =
        stream
            ->ThenBlasTrsm(
                perftools::gputools::blas::Side::kRight /*side*/,
                upper_lower_matrix /*uplo*/, transpose_matrix /*trans*/,
                perftools::gputools::blas::Diagonal::kNonUnit /*diag*/,
                colmajor_rows /*m*/, colmajor_cols /*n*/, Scalar(1.0) /*alpha*/,
                matrix_ptr, leading_dim_matrix /*lda*/, &out_ptr,
                leading_dim_output /*ldb*/)
            .ok();
    if (!blas_launch_status) {
      context->SetStatus(errors::Internal("Blas TRSM launch failed"));
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};

REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow