/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

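// Returns the complex conjugate of `scalar`; for real Scalar types this is
// the identity.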
template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
  return Eigen::numext::conj<Scalar>(scalar);
}

// Sequential batch banded-matrix triangular solve kernel. Implements the
// band-limited forward/back substitution directly with Eigen block
// operations.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    int num_bands = in_x.dim_size(1);
    int matrix_size = in_x.dim_size(2);

    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      // Below, we use the standard algorithm for computing a triangular
      // solve, except that we band-limit it.
      // Given A x = b, where A is lower triangular,
      // x_i = (b_i - sum a_ij * x_j) / a_ii, where the sum is from
      // j = 0 to i - 1.
      //
      // Now, in a banded triangular matrix, when i exceeds the band size,
      // the sum goes from j = i - band_size to i - 1, since the other
      // elements are zero.
      //
      // Finally, given the band storage format, we'll need to change the
      // indexing.
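      //
      // As a concrete (illustrative) example, for num_bands == 2 and
      // matrix_size == 4, the lower triangular matrix
      //   [a00   0    0    0 ]
      //   [a10  a11   0    0 ]
      //   [ 0   a21  a22   0 ]
      //   [ 0    0   a32  a33]
      // is stored as the 2 x 4 band matrix
      //   [a00  a11  a22  a33]   // row 0: the diagonal, matrix(0, i) == a_ii
      //   [ *   a10  a21  a32]   // row 1: the subdiagonal, padded on the left
      // so the dense entry A(i, j), i >= j, lives at matrix(i - j, i).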
      if (lower) {
        if (!adjoint) {
          output.row(0) = rhs.row(0) / matrix(0, 0);
          for (int i = 1; i < matrix_size; ++i) {
            if (i < num_bands) {
              output.row(i).noalias() =
                  (rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
                                    output.topRows(i)) /
                  matrix(0, i);
            } else {
              output.row(i).noalias() =
                  (rhs.row(i) -
                   matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(i - (num_bands - 1), num_bands - 1)) /
                  matrix(0, i);
            }
          }
        } else {
          // In the adjoint case, here and below, we now have an upper (lower)
          // triangular matrix, and thus need to work through it as in the
          // other case. We can't simply conjugate `matrix` and use the upper
          // (lower) algorithm because the band storage formats for upper and
          // lower triangular matrices are different (in the lower case, we
          // pad entries on the left, and in the upper case we pad entries on
          // the right).
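          //
          // Illustrative example: for num_bands == 2 the upper triangular
          // matrix
          //   [a00  a01   0    0 ]
          //   [ 0   a11  a12   0 ]
          //   [ 0    0   a22  a23]
          //   [ 0    0    0   a33]
          // is stored as
          //   [a01  a12  a23   * ]   // superdiagonal, padded on the right
          //   [a00  a11  a22  a33]   // row num_bands - 1: the diagonal
          // so A(i, j), j >= i, lives at matrix(num_bands - 1 - (j - i), i).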
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
          for (int i = matrix_size - 1; i >= 0; --i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(j - i, j)) * output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(0, i));
          }
        }
      } else {
        if (!adjoint) {
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
          for (int i = 1; i < matrix_size; ++i) {
            int k = matrix_size - 1 - i;
            if (i < num_bands) {
              output.row(k).noalias() =
                  (rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
                                        .reverse()
                                        .transpose() *
                                    output.bottomRows(i)) /
                  matrix(num_bands - 1, k);
            } else {
              output.row(k).noalias() =
                  (rhs.row(k) -
                   matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(k + 1, num_bands - 1)) /
                  matrix(num_bands - 1, k);
            }
          }
        } else {
          output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
          for (int i = 1; i < matrix_size; ++i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
                  output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
          }
        }
      }
    }
  }
};

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of banded matrix triangular solves, i.e. the batch size.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check the diagonal for zeros before doing any solves. The diagonal is
    // stored in the first row of the band matrix in the lower case, and in
    // the last row otherwise.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    RealScalar min_abs_pivot;
    if (lower) {
      min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
    } else {
      min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
    }
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

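    // Shard the independent solves in the batch across the CPU worker
    // threads; each shard invocation handles batch entries [start, limit).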
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int64 start,
                                                      int64 limit) {
            SequentialBandedTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
 public:
  explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BandedTriangularSolveOpCpu() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    // OP_REQUIRES inside the helper only returns from the helper itself, so
    // check the context status before proceeding.
    if (!ctx->status().ok()) return;

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);  // Band size.
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d1);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchBandedTriangularSolve<Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
  bool lower_;
  bool adjoint_;
};

#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE)        \
  REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          BandedTriangularSolveOpCpu<TYPE>);

REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);
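
// For reference: this kernel backs the BandedTriangularSolve op, exposed in
// Python as tf.linalg.banded_triangular_solve. An illustrative call, with
// `bands` holding A in the band storage described above:
//   x = tf.linalg.banded_triangular_solve(bands, rhs, lower=True)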

}  // namespace tensorflow
