/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
  return Eigen::numext::conj<Scalar>(scalar);
}

// Sequential batch banded triangular solve kernel that performs a
// band-limited triangular solve for each matrix in the batch using Eigen
// matrix operations.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    int num_bands = in_x.dim_size(1);
    int matrix_size = in_x.dim_size(2);
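    // Here `in_x` holds the band representation with shape
    // [batch, num_bands, matrix_size], and `in_y` holds the right-hand
    // sides with shape [batch, matrix_size, num_rhs].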

    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      // Below, we use the standard algorithm for computing a triangular solve,
      // except we band-limit it.
      // Given A x = b, where A is lower triangular,
      // x_i = (b_i - sum a_ij * x_j) / a_ii, where the sum runs from
      // j = 0 to i - 1.
      //
      // Now, in a banded triangular matrix, when i exceeds the band size,
      // the sum only runs from j = i - (num_bands - 1) to i - 1, since the
      // other elements are zero.
      //
      // Finally, given the band storage format, we'll need to change the
      // indexing.
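      // For example, with lower=true, num_bands = 2 and a 4x4 matrix A,
      // the band representation is ('*' marks padding that is ignored):
      //   matrix row 0:  a00 a11 a22 a33   (the diagonal)
      //   matrix row 1:   *  a10 a21 a32   (subdiagonal, padded on the left)
      // i.e. A(i, i - k) is stored at matrix(k, i). With lower=false the
      // diagonal is instead the bottom row and superdiagonals are padded on
      // the right:
      //   matrix row 0:  a01 a12 a23  *    (superdiagonal)
      //   matrix row 1:  a00 a11 a22 a33   (the diagonal)
      // i.e. A(i, i + k) is stored at matrix(num_bands - 1 - k, i).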
      if (lower) {
        if (!adjoint) {
          output.row(0) = rhs.row(0) / matrix(0, 0);
          for (int i = 1; i < matrix_size; ++i) {
            if (i < num_bands) {
              output.row(i).noalias() =
                  (rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
                                    output.topRows(i)) /
                  matrix(0, i);
            } else {
              output.row(i).noalias() =
                  (rhs.row(i) -
                   matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(i - (num_bands - 1), num_bands - 1)) /
                  matrix(0, i);
            }
          }
        } else {
          // In the adjoint case, here and below, we are effectively solving
          // with an upper (lower) triangular matrix, and thus need to work
          // through the logic of the other case. We can't simply conjugate
          // `matrix` and reuse the upper (lower) algorithm because the band
          // storage format for upper and lower triangular matrices is
          // different (in the lower case we pad entries on the left, and in
          // the upper case we pad entries on the right).
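          // Solving A^H x = b with A lower banded means
          //   x_i = (b_i - sum_{j > i} conj(a_ji) * x_j) / conj(a_ii),
          // where the sum runs up to min(matrix_size - 1, i + num_bands - 1)
          // and a_ji is stored at matrix(j - i, j), so we substitute from the
          // bottom row upward.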
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
          for (int i = matrix_size - 1; i >= 0; --i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(j - i, j)) * output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(0, i));
          }
        }
      } else {
        if (!adjoint) {
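          // Upper triangular: back substitution from the last row,
          //   x_k = (b_k - sum_{j > k} a_kj * x_j) / a_kk,
          // where the sum runs up to min(matrix_size - 1, k + num_bands - 1)
          // and a_kj is stored at matrix(num_bands - 1 - (j - k), k).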
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
          for (int i = 1; i < matrix_size; ++i) {
            int k = matrix_size - 1 - i;
            if (i < num_bands) {
              output.row(k).noalias() =
                  (rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
                                        .reverse()
                                        .transpose() *
                                    output.bottomRows(i)) /
                  matrix(num_bands - 1, k);
            } else {
              output.row(k).noalias() =
                  (rhs.row(k) -
                   matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(k + 1, num_bands - 1)) /
                  matrix(num_bands - 1, k);
            }
          }
        } else {
          output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
          for (int i = 1; i < matrix_size; ++i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
                  output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
          }
        }
      }
    }
  }
};

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of banded matrix triangular solves, i.e. the size of the batch.
    const int64_t batch_size = bcast.output_batch_size();
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check the diagonal for zero entries before doing any solves. The
    // diagonal is the first row of the band representation in the lower case
    // and otherwise the last row.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    RealScalar min_abs_pivot;
    if (lower) {
      min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
    } else {
      min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
    }
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

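    // Shard the independent solves in the batch across the CPU worker
    // threads; each shard handles the half-open range [start, limit) of
    // batch indices, with cost_per_unit guiding the shard granularity.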
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int64_t start,
                                                      int64_t limit) {
            SequentialBandedTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
 public:
  explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BandedTriangularSolveOpCpu() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    if (!ctx->status().ok()) return;

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);  // Band size.
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d1);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchBandedTriangularSolve<Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));

    OP_REQUIRES(ctx, in0.NumElements() > 0,
                errors::InvalidArgument("In[0] must not be an empty tensor: ",
                                        in0.DebugString()));

    OP_REQUIRES(ctx, in1.NumElements() > 0,
                errors::InvalidArgument("In[1] must not be an empty tensor: ",
                                        in1.DebugString()));
  }
  bool lower_;
  bool adjoint_;
};

#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE)        \
  REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          BandedTriangularSolveOpCpu<TYPE>);

REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);

}  // namespace tensorflow