/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

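// Returns the complex conjugate of `scalar`. For real Scalar types this is
// the identity.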
template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
  return Eigen::numext::conj<Scalar>(scalar);
}

// Sequential batch banded triangular solve kernel. The banded solve is
// implemented directly with Eigen block operations, by band-limiting the
// standard triangular substitution algorithms.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

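  // Maps the `slice`-th matrix of a [batch, rows, cols] tensor to a
  // read-only Eigen matrix view, without copying.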
  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

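  // Mutable counterpart of ConstTensorSliceToEigenMatrix above.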
  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    int num_bands = in_x.dim_size(1);
    int matrix_size = in_x.dim_size(2);

    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      // Below, we use the standard algorithm for computing a triangular
      // solve, except we band-limit it.
      // Given A x = b, where A is lower triangular,
      //   x_i = (b_i - sum_{j=0}^{i-1} a_ij * x_j) / a_ii.
      //
      // Now, in a banded triangular matrix, when i exceeds the band size,
      // the sum runs from j = i - band_size to i - 1, since the other
      // elements are zero.
      //
      // Finally, given the band storage format, we need to change the
      // indexing.
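      //
      // To make the band storage concrete: with num_bands = 2 and
      // matrix_size = 3, the lower bidiagonal matrix
      //   [a00   0    0 ]
      //   [a10  a11   0 ]
      //   [ 0   a21  a22]
      // is stored as the 2 x 3 band matrix
      //   [a00  a11  a22]   (row 0: the main diagonal)
      //   [ *   a10  a21]   (row d: the d-th subdiagonal, left-padded)
      // so that matrix(d, j) holds A(j, j - d).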
      if (lower) {
        if (!adjoint) {
          output.row(0) = rhs.row(0) / matrix(0, 0);
          for (int i = 1; i < matrix_size; ++i) {
            if (i < num_bands) {
              output.row(i).noalias() =
                  (rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
                                    output.topRows(i)) /
                  matrix(0, i);
            } else {
              output.row(i).noalias() =
                  (rhs.row(i) -
                   matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(i - (num_bands - 1), num_bands - 1)) /
                  matrix(0, i);
            }
          }
        } else {
          // In the adjoint case, here and below, we now have an upper
          // (lower) triangular matrix, and thus need to work through with
          // the other case. We can't simply conjugate `matrix` and use the
          // upper (lower) algorithm because the band storage formats for
          // upper and lower triangular matrices are different (in the lower
          // case, we pad entries on the left, and in the upper case we pad
          // entries on the right).
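          // Back substitution for adjoint(A) x = b with A lower banded:
          //   x_i = (b_i - sum_{j=i+1}^{min(n-1, i+k-1)} conj(a_ji) * x_j)
          //         / conj(a_ii),
          // where k = num_bands and a_ji is stored at matrix(j - i, j).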
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
          for (int i = matrix_size - 1; i >= 0; --i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(j - i, j)) * output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(0, i));
          }
        }
      } else {
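        // Upper triangular band storage mirrors the lower case but is
        // right-padded: with num_bands = 2 and matrix_size = 3,
        //   [a00  a01   0 ]
        //   [ 0   a11  a12]
        //   [ 0    0   a22]
        // is stored as
        //   [a01  a12   * ]   (row 0: the first superdiagonal)
        //   [a00  a11  a22]   (row num_bands - 1: the main diagonal)
        // so that matrix(num_bands - 1 - d, j) holds A(j, j + d).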
        if (!adjoint) {
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
          for (int i = 1; i < matrix_size; ++i) {
            int k = matrix_size - 1 - i;
            if (i < num_bands) {
              output.row(k).noalias() =
                  (rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
                                        .reverse()
                                        .transpose() *
                                    output.bottomRows(i)) /
                  matrix(num_bands - 1, k);
            } else {
              output.row(k).noalias() =
                  (rhs.row(k) -
                   matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(k + 1, num_bands - 1)) /
                  matrix(num_bands - 1, k);
            }
          }
        } else {
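          // Forward substitution for adjoint(A) x = b with A upper banded:
          //   x_i = (b_i - sum_{j=max(0, i-k+1)}^{i-1} conj(a_ji) * x_j)
          //         / conj(a_ii),
          // where k = num_bands and a_ji is stored at
          // matrix(num_bands - 1 - (i - j), j).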
          output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
          for (int i = 1; i < matrix_size; ++i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
                  output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
          }
        }
      }
    }
  }
};

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of banded matrix triangular solves, i.e. the batch size.
    const int64_t batch_size = bcast.output_batch_size();
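    // Cost model for sharding: one solve costs roughly
    // num_bands * matrix_size multiply-adds per right-hand-side column.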
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check the diagonal for zero pivots before doing any solves. In the
    // band storage format the diagonal is the first row in the lower case
    // and the last row in the upper case. Note that this maps (and checks)
    // only the first matrix of the batch.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    RealScalar min_abs_pivot;
    if (lower) {
      min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
    } else {
      min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
    }
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

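    // Shard the batch of independent solves across the CPU worker threads;
    // each shard handles the contiguous range [start, limit) sequentially.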
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int64_t start,
                                                      int64_t limit) {
            SequentialBandedTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
 public:
  explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BandedTriangularSolveOpCpu() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);
    if (!ctx->status().ok()) return;

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
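    // in0 holds the bands of the triangular matrix as a
    // [..., num_bands, matrix_size] tensor; in1 is the right-hand side of
    // shape [..., matrix_size, num_rhs].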
    auto d0 = in0.dim_size(in0.dims() - 2);  // Band size.
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] and In[1] shapes mismatch: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d1);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchBandedTriangularSolve<Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));

    OP_REQUIRES(ctx, in0.NumElements() > 0,
                errors::InvalidArgument("In[0] must not be an empty tensor: ",
                                        in0.DebugString()));

    OP_REQUIRES(ctx, in1.NumElements() > 0,
                errors::InvalidArgument("In[1] must not be an empty tensor: ",
                                        in1.DebugString()));
  }
  bool lower_;
  bool adjoint_;
};

#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE)        \
  REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          BandedTriangularSolveOpCpu<TYPE>);

REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);
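
// A usage sketch from Python via the public wrapper, which lowers to this
// kernel on CPU (the left-padded entry in the subdiagonal row is ignored):
//   bands = tf.constant([[2., 3., 4.],    # main diagonal
//                        [9., 1., 1.]])   # first subdiagonal, left-padded
//   rhs = tf.constant([[1.], [1.], [1.]])
//   x = tf.linalg.banded_triangular_solve(bands, rhs, lower=True)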

}  // namespace tensorflow