/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <complex>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/linalg/determinant_op.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_solvers.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;
namespace {
__device__ int PermutationOrder(int n, const int* __restrict__ pivots) {
  // Compute the order of the permutation from the number of transpositions
  // encoded in the pivot array, see:
  // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
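  // Illustrative example (values assumed here, not taken from the original
  // comment): for n == 3 and pivots == {2, 2, 3}, step 0 swapped rows
  // (pivots[0] == 2 != 1) while step 1 did not (pivots[1] == 2 == 2), so
  // order == 1 and the permutation contributes a sign of (-1)^1 == -1.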
  int order = 0;
  for (int i = 0; i < n - 1; ++i) {
    // Notice: Internally, the cuBLAS code uses the Fortran (1-based) indexing
    // convention, so we expect pivots[i] == i + 1 for rows that were not
    // moved.
    order += pivots[i] != (i + 1);
  }
  return order;
}
}  // namespace

// This kernel computes either the determinant or log_abs_determinant,
// depending on the value of the template parameter. If compute_log_abs_det
// is false, the sign argument is ignored.
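// Background note (added here for clarity): getrf-style factorizations
// produce P * A = L * U with unit-diagonal L, so
// det(A) = (-1)^order * prod_i U(i, i), where order is the number of row
// transpositions encoded in the pivots. Accumulating log|U(i, i)| and the
// sign separately avoids overflow/underflow of the raw product for large n.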
template <typename Scalar, bool compute_log_abs_det = true>
__global__ void DeterminantFromPivotedLUKernel(
    int nthreads, int n, const Scalar* __restrict__ lu_factor,
    const int* __restrict__ all_pivots, Scalar* __restrict__ sign,
    Scalar* __restrict__ log_abs_det) {
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  const int matrix_size = n * n;
  const int stride = n + 1;
  // We only parallelize over batches here. Performance is not critical,
  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
  // The main purpose is to avoid having to copy the LU decomposition to
  // host memory.
  GPU_1D_KERNEL_LOOP(o_idx, nthreads) {
    // Initialize sign to (-1)^order.
    const int order = PermutationOrder(n, all_pivots + o_idx * n);
    Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
    RealScalar sum_log_abs_det = RealScalar(0);
    int i_idx = matrix_size * o_idx;
    // i_idx steps along the diagonal of the o_idx-th matrix; stride == n + 1
    // moves one row down and one column right per iteration.
    for (int i = 0; i < n; ++i, i_idx += stride) {
      const RealScalar abs_i = Eigen::numext::abs(lu_factor[i_idx]);
      sum_log_abs_det += Eigen::numext::log(abs_i);
      prod_sign = prod_sign * (lu_factor[i_idx] / abs_i);
    }
    if (!Eigen::numext::isfinite(sum_log_abs_det)) {
      prod_sign = Scalar(0);
      // Map a non-finite sum to +/-infinity: -log(0) evaluates to +inf and
      // log(0) to -inf.
      sum_log_abs_det = sum_log_abs_det > 0 ? -Eigen::numext::log(RealScalar(0))
                                            : Eigen::numext::log(RealScalar(0));
    }
    if (compute_log_abs_det) {
      sign[o_idx] = prod_sign;
      log_abs_det[o_idx] = Scalar(sum_log_abs_det);
    } else {
      log_abs_det[o_idx] = prod_sign * Eigen::numext::exp(sum_log_abs_det);
    }
  }
}

template <typename Scalar>
struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
                  int* info) {
    const int64 num_matrices = output.size();
    const int64 n = lu_factor.dimension(2);
    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);

    TF_CHECK_OK(GpuLaunchKernel(
        DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>,
        config.block_count, config.thread_per_block, 0, device.stream(),
        config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
        output.data()));
  }
};
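
// Usage sketch (illustrative only; the tensor and pointer names below are
// hypothetical, not defined in this file): the determinant op first runs a
// batched getrf through GpuSolver, then hands the factors to this functor:
//
//   functor::DeterminantFromPivotedLUFunctor<GPUDevice, float> det_from_lu;
//   det_from_lu(device, lu_tensor.tensor<float, 3>(), pivots_ptr,
//               output_tensor.tensor<float, 1>(), info_ptr);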

template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

template <typename Scalar>
struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
                  typename TTypes<Scalar, 1>::Tensor log_abs_det) {
    const int64 num_matrices = sign.size();
    const int64 n = lu_factor.dimension(2);
    GpuLaunchConfig config = GetGpuLaunchConfig(num_matrices, device);
    TF_CHECK_OK(GpuLaunchKernel(
        DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>,
        config.block_count, config.thread_per_block, 0, device.stream(),
        config.virtual_thread_count, n, lu_factor.data(), pivots, sign.data(),
        log_abs_det.data()));
  }
};
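
// Recombination note (added): given this functor's outputs, the actual
// determinant is sign * exp(log_abs_det), computed elementwise. Keeping the
// two parts separate keeps the result well-defined even when
// exp(log_abs_det) alone would overflow or underflow.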

template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA