/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_

// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
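//
// For example, a per-type registration file might look roughly like this
// (sketch only; see the actual qr_*op*.cc files for the exact registrations,
// and linalg_ops_common.h for the REGISTER_LINALG_OP macro):
//
//   #include "tensorflow/core/kernels/linalg/qr_op_impl.h"
//
//   namespace tensorflow {
//   REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
//   #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
//   REGISTER_KERNEL_BUILDER(
//       Name("Qr").Device(DEVICE_GPU).TypeConstraint<float>("T"),
//       QrOpGpu<float>);
//   #endif
//   }  // namespace tensorflow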
#include <algorithm>
#include <numeric>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/gpu_solvers.h"
#endif

namespace tensorflow {

template <class Scalar>
class QrOp : public LinearAlgebraOp<Scalar> {
 public:
  typedef LinearAlgebraOp<Scalar> Base;

  explicit QrOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("full_matrices", &full_matrices_));
  }

  using TensorShapes = typename Base::TensorShapes;

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSingleMatrix(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    int64_t m = input_matrix_shapes[0].dim_size(0);
    int64_t n = input_matrix_shapes[0].dim_size(1);
    int64_t min_size = std::min(m, n);
    if (full_matrices_) {
      return TensorShapes({TensorShape({m, m}), TensorShape({m, n})});
    } else {
      return TensorShapes(
          {TensorShape({m, min_size}), TensorShape({min_size, n})});
    }
  }

  int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    double max_size = std::max(m, n);
    double min_size = std::min(m, n);
    double cost = 2 * max_size * min_size * min_size -
                  2 * min_size * min_size * min_size / 3.;
    // TODO(jpoulson): Increase the cost if full_matrices is true in a manner
    // that reflects the algorithm used for the expansion.
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64_t>(cost);
  }
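
  // Example (illustrative numbers): for a tall 1000 x 100 input the estimate
  // above is roughly
  //   2 * 1000 * 100^2 - (2/3) * 100^3 ~= 2.0e7 - 6.7e5 ~= 1.9e7
  // flops per matrix, which is the per-unit cost the LinearAlgebraOp base
  // class feeds to its thread-pool sharding over the batch dimension.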

  using Matrix = typename Base::Matrix;
  using MatrixMaps = typename Base::MatrixMaps;
  using ConstMatrixMap = typename Base::ConstMatrixMap;
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Eigen::HouseholderQR<Matrix> qr(inputs[0]);
    const int m = inputs[0].rows();
    const int n = inputs[0].cols();
    const int min_size = std::min(m, n);

    if (full_matrices_) {
      outputs->at(0) = qr.householderQ();
      outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
    } else {
      // TODO(jpoulson): Exploit the fact that Householder transformations can
      // be expanded faster than they can be applied to an arbitrary matrix
      // (Cf. LAPACK's DORGQR).
      Matrix tmp = Matrix::Identity(m, min_size);
      outputs->at(0) = qr.householderQ() * tmp;
      auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
      outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
    }
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class QrOpGpu : public AsyncOpKernel {
 public:
  explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("full_matrices", &full_matrices_));
  }
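
  // High-level outline of ComputeAsync below; the inline comments in the
  // body have the details:
  //   1. Transpose the row-major input into column-major storage, since the
  //      cuSolver/rocSolver routines expect column-major data.
  //   2. Run Geqrf on each batch member to compute the compact QR
  //      factorization in place (Householder reflectors plus tau).
  //   3. Copy out R: transpose back directly in the full/square case, or
  //      crop-and-transpose with Geam in the thin case, then zero the
  //      strictly lower triangle with MatrixBandPartFunctor.
  //   4. Form Q: apply the adjoint of the reflectors to an identity matrix
  //      with Unmqr (full or wide case, with a final conjugation for complex
  //      types), or expand the reflectors in place with Ungqr and transpose
  //      back (square, or tall inputs with full_matrices = false).
  //   5. Check the cuSolver/rocSolver info arrays asynchronously before
  //      calling done.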

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64_t m = input.dim_size(ndims - 2);
    const int64_t n = input.dim_size(ndims - 1);
    const int64_t min_size = std::min(m, n);
    const int64_t batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);

    // Allocate output.
    // If full_matrices_ is true then Q is m x m and R is m x n.
    // Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
    Tensor* q;
    TensorShape q_shape = input.shape();
    q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
                         done);
    Tensor* r;
    TensorShape r_shape = input.shape();
    r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
                         done);

    if (input.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<GpuSolver> solver(new GpuSolver(context));

    // Allocate temporaries.
    Tensor input_transposed;
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
    transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));

    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_shape, &input_transposed),
        done);

    Tensor tau;
    OP_REQUIRES_OK_ASYNC(context,
                         solver->allocate_scoped_tensor(
                             DataTypeToEnum<Scalar>::value,
                             TensorShape({batch_size, min_size}), &tau),
                         done);

    // Transpose input, since cuSolver uses column-major, while TensorFlow uses
    // row-major storage.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    OP_REQUIRES_OK_ASYNC(
        context, DoMatrixTranspose(device, input, &input_transposed), done);

    // Compute QR decomposition in-place in input_transposed.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "geqrf"));
    auto input_transposed_reshaped =
        input_transposed.flat_inner_dims<Scalar, 3>();
    auto tau_matrix = tau.matrix<Scalar>();
    auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
                        &tau_matrix(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }

#if GOOGLE_CUDA
    cublasOperation_t transa = CUBLAS_OP_T;
    cublasOperation_t transb = CUBLAS_OP_N;
    cublasSideMode_t side = CUBLAS_SIDE_LEFT;
#elif TENSORFLOW_USE_ROCM
    rocblas_operation transa = rocblas_operation_transpose;
    rocblas_operation transb = rocblas_operation_none;
    rocblas_side side = rocblas_side_left;
#endif

    // Generate R. R is equal to the upper triangle of the decomposition
    // stored in input_transposed. Crop, transpose (to get back to row-major)
    // and copy it to the output buffer.
    if (full_matrices_ || m == n) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, r), done);
    } else {
      const Scalar alpha(1);
      const Scalar beta(0);
      const Scalar* dummy = nullptr;
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Geam(transa, transb, n, full_matrices_ ? m : min_size,
                         &alpha, &input_transposed_reshaped(batch, 0, 0), m,
                         &beta, dummy, n, &r_reshaped(batch, 0, 0), n),
            done);
      }
    }
    // Extract the upper triangle of r (i.e. zero out the strictly lower
    // triangle).
    functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
    auto r_reshaped_const =
        const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
    band_part(context, device, 0 /* num_lower_diags */,
              -1 /* num_upper_diags */, r_reshaped_const, r_reshaped);

    // Generate Q from the decomposition in input_transposed.
    if (m != n && (full_matrices_ || m < n)) {
      // Generate full m x m matrix Q by computing the product Q^T * I,
      // where the transpose is to get back to row-major form.
      // In the complex case we actually form Q^H * I and conjugate it
      // to get Q in row-major form.
      functor::EyeFunctor<GPUDevice, Scalar> eye;
      auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
      eye(device, q_reshaped);
#if GOOGLE_CUDA
      cublasOperation_t trans = CublasAdjointOp<Scalar>();
#elif TENSORFLOW_USE_ROCM
      rocblas_operation trans = RocblasAdjointOp<Scalar>();
#endif
      for (int batch = 0; batch < batch_size; ++batch) {
        // Notice: It appears that Unmqr does not write a zero into *info upon
        // success (probably a bug), so we simply re-use the info array already
        // zeroed by Geqrf above.
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Unmqr(side, trans, m, m, min_size,
                          &input_transposed_reshaped(batch, 0, 0), m,
                          &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
                          dev_info.back().mutable_data() + batch),
            done);
      }
      if (Eigen::NumTraits<Scalar>::IsComplex) {
        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
        conj(device, q->flat<Scalar>() /*out*/,
             const_cast<const Tensor*>(q)->flat<Scalar>() /*in*/);
      }
    } else {
      // Generate m x n matrix Q. In this case we can use the more efficient
      // algorithm in Ungqr to generate Q in place.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "orgqr"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Ungqr(
                m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
                &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
            done);
      }
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, q), done);
    }

    // Asynchronously check return status from cuSolver kernels.
    GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                   std::move(done));
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_