/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
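//
// Summary (derived from the shape handling below; see the op docs above for
// the authoritative description): BandedTriangularSolve solves A X = B, or
// adjoint(A) X = B when the `adjoint` attribute is set, where A is an M x M
// triangular matrix stored compactly as its K bands. The first input holds
// the bands with shape [..., K, M], the second input holds the right-hand
// side with shape [..., M, N], and the output solution has shape [..., M, N].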

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

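// Conjugation helper used by the adjoint solves below; for real Scalar types
// it reduces to the identity.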
template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
  return Eigen::numext::conj<Scalar>(scalar);
}

// Sequential batch banded triangular solve kernel. The solve is written out
// directly with Eigen matrix operations over the stored bands, rather than
// calling into Eigen's dense triangular solver.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

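  // The two helpers below map the `slice`-th matrix of a [batch, rows, cols]
  // tensor as an Eigen matrix, without copying any data.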
  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    int num_bands = in_x.dim_size(1);
    int matrix_size = in_x.dim_size(2);

    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      // Below, we use the standard algorithm for computing a triangular
      // solve, except we band limit it.
      // Given A x = b, where A is lower triangular,
      //   x_i = (b_i - sum_{j=0}^{i-1} a_ij * x_j) / a_ii.
      //
      // In a banded triangular matrix with `num_bands` bands (diagonal
      // included), once i >= num_bands the sum only needs to run from
      // j = i - (num_bands - 1) to i - 1, since all earlier elements of
      // row i are zero.
      //
      // Finally, given the band storage format, we need to translate the
      // dense indices (i, j) into band-storage indices.
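      //
      // As a concrete illustration (added example, matching the indexing used
      // below): with num_bands = 2 and matrix_size = 3, the lower bidiagonal
      // matrix
      //   [a00   0    0 ]
      //   [a10  a11   0 ]
      //   [ 0   a21  a22]
      // is stored as the 2 x 3 band matrix
      //   [a00  a11  a22]
      //   [ *   a10  a21]
      // i.e. matrix(k, i) holds the dense entry A(i, i - k), with the
      // subdiagonal rows padded on the left.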
      if (lower) {
        if (!adjoint) {
          output.row(0) = rhs.row(0) / matrix(0, 0);
          for (int i = 1; i < matrix_size; ++i) {
            if (i < num_bands) {
              output.row(i).noalias() =
                  (rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
                                    output.topRows(i)) /
                  matrix(0, i);
            } else {
              output.row(i).noalias() =
                  (rhs.row(i) -
                   matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(i - (num_bands - 1), num_bands - 1)) /
                  matrix(0, i);
            }
          }
        } else {
          // In the adjoint case, here and below, we are effectively solving
          // with an upper (lower) triangular matrix, and thus need to use the
          // algorithm for the other case. We can't simply conjugate `matrix`
          // and reuse the upper (lower) algorithm because the band storage
          // formats for upper and lower triangular matrices differ: in the
          // lower case we pad entries on the left, and in the upper case we
          // pad entries on the right.
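          //
          // For example (added illustration): with num_bands = 2 and
          // matrix_size = 3, the upper bidiagonal matrix
          //   [a00  a01   0 ]
          //   [ 0   a11  a12]
          //   [ 0    0   a22]
          // is stored as the 2 x 3 band matrix
          //   [a01  a12   * ]
          //   [a00  a11  a22]
          // with the diagonal in the last band row and the superdiagonal
          // padded on the right, whereas the lower format above keeps the
          // diagonal in band row 0.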
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
          for (int i = matrix_size - 1; i >= 0; --i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(j - i, j)) * output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(0, i));
          }
        }
      } else {
        if (!adjoint) {
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
          for (int i = 1; i < matrix_size; ++i) {
            int k = matrix_size - 1 - i;
            if (i < num_bands) {
              output.row(k).noalias() =
                  (rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
                                        .reverse()
                                        .transpose() *
                                    output.bottomRows(i)) /
                  matrix(num_bands - 1, k);
            } else {
              output.row(k).noalias() =
                  (rhs.row(k) -
                   matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(k + 1, num_bands - 1)) /
                  matrix(num_bands - 1, k);
            }
          }
        } else {
          output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
          for (int i = 1; i < matrix_size; ++i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
                  output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
          }
        }
      }
    }
  }
};

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of banded matrix triangular solves, i.e. the size of the batch.
    const int64 batch_size = bcast.output_batch_size();
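    // Rough cost of one banded solve for the work sharder below: one
    // multiply-add per stored band element per right-hand-side column.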
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check the diagonal for zeros before doing any solves: the diagonal is
    // stored in the first band row in the lower case and in the last band
    // row in the upper case.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
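    // Note that only the first matrix of the batch is mapped above, so the
    // invertibility check below inspects just that matrix's diagonal.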
    RealScalar min_abs_pivot;
    if (lower) {
      min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
    } else {
      min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
    }
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

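    // Shard the independent solves in the batch across the intra-op thread
    // pool; each invocation of the kernel handles batch entries
    // [start, limit).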
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int64 start,
                                                      int64 limit) {
            SequentialBandedTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
 public:
  explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BandedTriangularSolveOpCpu() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    // Stop early if validation failed; OP_REQUIRES only returns from
    // ValidateInputTensors itself, not from Compute.
    if (!ctx->status().ok()) return;

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);  // Number of bands.
    auto d1 = in0.dim_size(in0.dims() - 1);  // Matrix size.
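    // Despite the name, Tensor::CopyFrom below only shares the underlying
    // buffer while collapsing the leading batch dimensions; no element data
    // is copied.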
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d1);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchBandedTriangularSolve<Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
  bool lower_;
  bool adjoint_;
};

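// Registers the CPU kernel for each supported dtype. From the Python API this
// kernel is typically reached through tf.linalg.banded_triangular_solve; an
// illustrative call for a 2-band lower matrix (subdiagonal row left-padded,
// per the storage format described above) might look like:
//
//   bands = tf.constant([[4., 5., 6.], [0., 1., 2.]])
//   rhs = tf.constant([[1.], [2.], [3.]])
//   x = tf.linalg.banded_triangular_solve(bands, rhs, lower=True)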
#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE)        \
  REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          BandedTriangularSolveOpCpu<TYPE>);

REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);

}  // namespace tensorflow