/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_

// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times. See the registration sketch at the bottom of
// this file for how those files typically look.
#include <algorithm>
#include <numeric>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif

namespace tensorflow {

template <class Scalar>
class QrOp : public LinearAlgebraOp<Scalar> {
 public:
  typedef LinearAlgebraOp<Scalar> Base;

  explicit QrOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("full_matrices", &full_matrices_));
  }

  using TensorShapes = typename Base::TensorShapes;

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSingleMatrix(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    int64 m = input_matrix_shapes[0].dim_size(0);
    int64 n = input_matrix_shapes[0].dim_size(1);
    int64 min_size = std::min(m, n);
    if (full_matrices_) {
      return TensorShapes({TensorShape({m, m}), TensorShape({m, n})});
    } else {
      return TensorShapes(
          {TensorShape({m, min_size}), TensorShape({min_size, n})});
    }
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
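    // Flop count of a Householder QR of an m x n matrix is roughly
    // 2 * max(m, n) * min(m, n)^2 - 2 * min(m, n)^3 / 3 (see e.g.
    // Golub & Van Loan, "Matrix Computations"), which is the formula below.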
    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    double max_size = std::max(m, n);
    double min_size = std::min(m, n);
    double cost = 2 * max_size * min_size * min_size -
                  2 * min_size * min_size * min_size / 3.;
    // TODO(jpoulson): Increase the cost if full_matrices is true in a manner
    // that reflects the algorithm used for the expansion.
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  using Matrix = typename Base::Matrix;
  using MatrixMaps = typename Base::MatrixMaps;
  using ConstMatrixMap = typename Base::ConstMatrixMap;
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Eigen::HouseholderQR<Matrix> qr(inputs[0]);
    const int m = inputs[0].rows();
    const int n = inputs[0].cols();
    const int min_size = std::min(m, n);

    if (full_matrices_) {
      outputs->at(0) = qr.householderQ();
      outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
    } else {
      // TODO(jpoulson): Exploit the fact that Householder transformations can
      // be expanded faster than they can be applied to an arbitrary matrix
      // (Cf. LAPACK's DORGQR).
      Matrix tmp = Matrix::Identity(m, min_size);
      outputs->at(0) = qr.householderQ() * tmp;
      auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
      outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
    }
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class QrOpGpu : public AsyncOpKernel {
 public:
  explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("full_matrices", &full_matrices_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();

    // Validate inputs before reading any dimensions off the input shape.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);

    const int64 m = input.dim_size(ndims - 2);
    const int64 n = input.dim_size(ndims - 1);
    const int64 min_size = std::min(m, n);
    const int64 batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Allocate output.
    // If full_matrices_ is true then Q is m x m and R is m x n.
    // Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
    Tensor* q;
    TensorShape q_shape = input.shape();
    q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
                         done);
    Tensor* r;
    TensorShape r_shape = input.shape();
    r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
                         done);

    if (input.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Allocate temporaries.
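    // input_transposed holds a transposed (i.e. column-major, as expected by
    // cuSolver) copy of the input batch; the factorization is computed
    // in-place in this buffer. tau holds the Householder scalars produced by
    // Geqrf, one vector of length min(m, n) per batch entry.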
    Tensor input_transposed;
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
    transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));

    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_shape, &input_transposed),
        done);

    Tensor tau;
    OP_REQUIRES_OK_ASYNC(context,
                         solver->allocate_scoped_tensor(
                             DataTypeToEnum<Scalar>::value,
                             TensorShape({batch_size, min_size}), &tau),
                         done);

    // Transpose input, since cuSolver uses column-major, while TensorFlow uses
    // row-major storage.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    OP_REQUIRES_OK_ASYNC(
        context, DoMatrixTranspose(device, input, &input_transposed), done);

    // Compute QR decomposition in-place in input_transposed.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "geqrf"));
    auto input_transposed_reshaped =
        input_transposed.flat_inner_dims<Scalar, 3>();
    auto tau_matrix = tau.matrix<Scalar>();
    auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
                        &tau_matrix(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }

    // Generate R. R is equal to the upper triangle of the decomposition
    // stored in input_transposed. Crop, transpose (to get back to row-major)
    // and copy it to the output buffer.
    if (full_matrices_ || m == n) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, r), done);
    } else {
      const Scalar alpha(1);
      const Scalar beta(0);
      const Scalar* dummy = nullptr;
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, n,
                         full_matrices_ ? m : min_size, &alpha,
                         &input_transposed_reshaped(batch, 0, 0), m, &beta,
                         dummy, n, &r_reshaped(batch, 0, 0), n),
            done);
      }
    }
    // Extract the upper triangle of r (i.e. zero out the strictly lower
    // triangle).
    functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
    auto r_reshaped_const =
        const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
    band_part(context, device, 0 /* num_lower_diags */,
              -1 /* num_upper_diags */, r_reshaped_const, r_reshaped);

    // Generate Q from the decomposition in input_transposed.
    if (m != n && (full_matrices_ || m < n)) {
      // Generate full m x m matrix Q by computing the product Q^T * I,
      // where the transpose is to get back to row-major form.
      // In the complex case we actually form Q^H * I and conjugate it
      // to get Q in row-major form.
      functor::EyeFunctor<GPUDevice, Scalar> eye;
      auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
      eye(device, q_reshaped);
      for (int batch = 0; batch < batch_size; ++batch) {
        // Notice: It appears that Unmqr does not write a zero into *info upon
        // success (probably a bug), so we simply re-use the info array already
        // zeroed by Geqrf above.
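        // Unmqr applies the adjoint of the min_size Householder reflectors
        // stored in input_transposed (with scalars in tau) from the left to
        // the m x m identity in q_reshaped.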
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Unmqr(CUBLAS_SIDE_LEFT, CublasAdjointOp<Scalar>(), m, m,
                          min_size, &input_transposed_reshaped(batch, 0, 0), m,
                          &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
                          dev_info.back().mutable_data() + batch),
            done);
      }
      if (Eigen::NumTraits<Scalar>::IsComplex) {
        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
        conj(device, q->flat<Scalar>() /*out*/,
             const_cast<const Tensor*>(q)->flat<Scalar>() /*in*/);
      }
    } else {
      // Generate m x n matrix Q. In this case we can use the more efficient
      // algorithm in Ungqr to generate Q in place.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "orgqr"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Ungqr(
                m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
                &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
            done);
      }
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, q), done);
    }

    // Asynchronously check return status from cuSolver kernels.
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(done));
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
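// Registration sketch (illustrative only; the exact contents of the per-type
// qr_*.cc files may differ): each translation unit includes this header and
// registers the kernels for one scalar type, e.g. for float:
//
//   #include "tensorflow/core/kernels/linalg/qr_op_impl.h"
//
//   namespace tensorflow {
//   REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
//   #if GOOGLE_CUDA
//   REGISTER_KERNEL_BUILDER(
//       Name("Qr").Device(DEVICE_GPU).TypeConstraint<float>("T"),
//       QrOpGpu<float>);
//   #endif
//   }  // namespace tensorflow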