/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include <algorithm>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {
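// Converts the sequence of row transpositions produced by the LU
// factorization (the 1-based `pivots` array) into an explicit permutation
// written to `permutation_indices`. For example, with num_rows = 3 and
// pivots = {2, 3, 3}, the identity [0, 1, 2] becomes [1, 0, 2] after the
// first swap and [1, 2, 0] after the second; the third swap is a no-op,
// so the result is [1, 2, 0].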
template <typename Scalar>
__device__ void ComputePermutationFromTranspositions(
    int64 num_rows, const int* __restrict__ pivots,
    Scalar* __restrict__ permutation_indices) {
  // Fill in the output array with the identity permutation.
  for (int i = 0; i < num_rows; ++i) {
    permutation_indices[i] = Scalar(i);
  }

  // Compute the permutation from a sequence of transpositions encoded
  // in the pivot array by applying the transpositions in order on the
  // identity permutation.
  for (int i = 0; i < num_rows; ++i) {
    // Note: Internally, cuBLAS uses the Fortran convention (1-based
    // indexing), so the i-th row was swapped with the (pivots[i] - 1)-th row
    // in 0-based indexing.
    Scalar t = permutation_indices[i];
    permutation_indices[i] = permutation_indices[pivots[i] - 1];
    permutation_indices[pivots[i] - 1] = t;
  }
}
}  // namespace

// Kernel to compute the inverse of a permutation from a sequence of
// transpositions.
template <typename Scalar>
__global__ void ComputePermutationFromTranspositionsKernel(
    GpuLaunchConfig config, const int64 num_rows,
    const int* __restrict__ all_pivots,
    Scalar* __restrict__ all_permutation_indices) {
  // We only parallelize over batches here. Performance is not critical,
  // since this cheap O(num_rows) kernel always follows an O(num_rows^3)
  // LU factorization.
  GPU_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
    ComputePermutationFromTranspositions(
        num_rows, all_pivots + index * num_rows,
        all_permutation_indices + index * num_rows);
  }
}

template <class Scalar, class Tidx>
class LuOpGpu : public AsyncOpKernel {
 public:
  explicit LuOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);

    // Analyze shape and validate inputs.
    const int input_rank = input.dims();

    OP_REQUIRES_ASYNC(
        context, input_rank >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", input_rank),
        done);

    const int64 num_rows = input.dim_size(input_rank - 2);
    const int64 num_cols = input.dim_size(input_rank - 1);

    OP_REQUIRES_ASYNC(
        context, num_rows == num_cols,
        errors::InvalidArgument("Input matrices must be square, got ",
                                num_rows, " != ", num_cols),
        done);

    TensorShape batch_shape;
    for (int dim = 0; dim < input_rank - 2; ++dim) {
      batch_shape.AddDim(input.dim_size(dim));
    }
    TensorShape permutation_indices_shape = batch_shape;
    permutation_indices_shape.AddDim(num_rows);

    const GPUDevice& device = context->eigen_device<GPUDevice>();
    auto solver = absl::make_unique<CudaSolver>(context);

    // We output the packed triangular factors in a dense form.
    // The lower triangular factor L corresponds to the strictly lower
    // triangular part of packed_triangular_factors with an implicit unit
    // diagonal. The upper triangular factor U is the upper triangular part of
    // packed_triangular_factors. The triangular factors satisfy the equation
    //     P * input_matrix = L * U
    // where P is the permutation matrix corresponding to the indices in
    // permutation_indices.
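    //
    // For example, for a single 3x3 input matrix the two factors are packed
    // together as
    //     [[u00, u01, u02],
    //      [l10, u11, u12],
    //      [l20, l21, u22]],
    // with the unit diagonal of L omitted.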
    //
    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor* packed_triangular_factors;
    OP_REQUIRES_OK_ASYNC(context,
                         context->forward_input_or_allocate_output(
                             {0}, 0, input.shape(), &packed_triangular_factors),
                         done);
    if (!packed_triangular_factors->SharesBufferWith(input)) {
      device.memcpy(packed_triangular_factors->flat<Scalar>().data(),
                    input.flat<Scalar>().data(),
                    input.NumElements() * sizeof(Scalar));
    }

    // Allocate output permutation.
    Tensor* permutation_indices = nullptr;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_output(1, permutation_indices_shape,
                                                  &permutation_indices),
                         done);

    if (input.NumElements() == 0) {
      done();
      return;
    }

    // Allocate a temporary Tensor to store the transposed packed triangular
    // factors.
    Tensor packed_triangular_factors_transpose;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_temp(DataTypeToEnum<Scalar>::value, input.shape(),
                               &packed_triangular_factors_transpose),
        done);
    auto packed_triangular_factors_transpose_reshaped =
        packed_triangular_factors_transpose
            .template flat_inner_dims<Scalar, 3>();
    const int64 batch_size =
        packed_triangular_factors_transpose_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(context,
                         solver->allocate_scoped_tensor(
                             DataTypeToEnum<int32>::value,
                             TensorShape{batch_size, num_rows}, &pivots),
                         done);
    auto pivots_mat = pivots.template matrix<int32>();
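    // Note: Getrf/GetrfBatched below write 1-based (Fortran convention) pivot
    // indices into pivots_mat; the kernel launched further down converts them
    // into a 0-based permutation in permutation_indices.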

    // Transpose the input. This is necessary because cuBLAS assumes
    // column-major storage while TensorFlow uses row-major.
    OP_REQUIRES_OK_ASYNC(
        context,
        DoMatrixTranspose(device, *packed_triangular_factors,
                          &packed_triangular_factors_transpose),
        done);

    std::vector<DeviceLapackInfo> dev_info;
    if (num_rows == num_cols && num_rows / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched
      // interface from cuBLAS.
      auto packed_triangular_factors_ptrs = solver->GetScratchSpace<uint8>(
          sizeof(Scalar*) * batch_size, "packed_triangular_factors_ptrs",
          /* on_host */ true);
      const Scalar** packed_triangular_factors_ptrs_base =
          reinterpret_cast<const Scalar**>(
              packed_triangular_factors_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        packed_triangular_factors_ptrs_base[batch] =
            &packed_triangular_factors_transpose_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(num_rows, packed_triangular_factors_ptrs_base,
                               num_rows, pivots_mat.data(), &dev_info.back(),
                               batch_size),
          done);
    } else {
      // For small batch sizes we use the non-batched interface from cuSolver,
      // which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(
                num_rows, num_cols,
                &packed_triangular_factors_transpose_reshaped(batch, 0, 0),
                num_rows, &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // Transpose the result since we had transposed the input.
    OP_REQUIRES_OK_ASYNC(
        context,
        DoMatrixTranspose(device, packed_triangular_factors_transpose,
                          packed_triangular_factors),
        done);

    // The pivots encode the permutation of the rows as a sequence of row
    // swaps: for each index i, row i was swapped with row pivots[i] - 1
    // (the pivots use 1-based indexing).
    int* pivots_ptr = pivots.flat<int>().data();
    Tidx* permutation_indices_ptr =
        permutation_indices->template flat<Tidx>().data();
    GpuLaunchConfig cfgPivots = GetGpuLaunchConfig(batch_size, device);
    TF_CHECK_OK(GpuLaunchKernel(
        ComputePermutationFromTranspositionsKernel<Tidx>, cfgPivots.block_count,
        cfgPivots.thread_per_block, 0, device.stream(), cfgPivots, num_rows,
        pivots_ptr, permutation_indices_ptr));

    // Callback for checking info after kernels finish. Also capture the
    // temporary Tensors/ScratchSpace so they don't get deallocated before the
    // kernels run.
    // TODO(rmlarsen): Use move capture once C++14 becomes available.
    auto info_checker = [context, done, dev_info](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // Match the CPU error message for singular matrices. Otherwise
          // just print the original error message from the status below.
          OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
                            errors::InvalidArgument("Input is not invertible."),
                            done);
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };

    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

#define REGISTER_LU_GPU(type, idx_type)                                     \
  REGISTER_KERNEL_BUILDER(Name("Lu")                                        \
                              .Device(DEVICE_GPU)                           \
                              .TypeConstraint<type>("T")                    \
                              .TypeConstraint<idx_type>("output_idx_type"), \
                          LuOpGpu<type, idx_type>);

REGISTER_LU_GPU(float, int32);
REGISTER_LU_GPU(double, int32);
REGISTER_LU_GPU(complex64, int32);
REGISTER_LU_GPU(complex128, int32);

REGISTER_LU_GPU(float, int64);
REGISTER_LU_GPU(double, int64);
REGISTER_LU_GPU(complex64, int64);
REGISTER_LU_GPU(complex128, int64);
}  // namespace tensorflow

#endif  // GOOGLE_CUDA