1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
17
18 #include <numeric>
19 #include <string>
20 #include <vector>
21
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/lib/loops.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31
32 namespace xla {
33 namespace tridiagonal {
34
35 namespace {
36
CheckSecondToLastDimension(const Shape & op_shape,int64 rank,int64 expected,const std::string & op_name)37 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
38 int64 expected, const std::string& op_name) {
39 const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
40
41 if (actual_num_dims != expected) {
42 return InvalidArgument(
43 "Second to last dimension of %s should be %d but is %d.", op_name,
44 expected, actual_num_dims);
45 }
46
47 return Status::OK();
48 }
49
CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)50 StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
51 XlaOp main_diagonal,
52 XlaOp upper_diagonal,
53 XlaOp rhs) {
54 XlaBuilder* builder = lower_diagonal.builder();
55
56 TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
57 builder->GetShape(lower_diagonal));
58 TF_ASSIGN_OR_RETURN(Shape main_diagonal_shape,
59 builder->GetShape(main_diagonal));
60 TF_ASSIGN_OR_RETURN(Shape upper_diagonal_shape,
61 builder->GetShape(upper_diagonal));
62 TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs));
63
64 const auto lower_diagonal_rank = lower_diagonal_shape.rank();
65 const auto main_diagonal_rank = main_diagonal_shape.rank();
66 const auto upper_diagonal_rank = upper_diagonal_shape.rank();
67 const auto rhs_rank = rhs_shape.rank();
68 if (!((lower_diagonal_rank == main_diagonal_rank) &&
69 (lower_diagonal_rank == upper_diagonal_rank) &&
70 (lower_diagonal_rank == rhs_rank))) {
71 return InvalidArgument(
72 "All inputs should have the same rank but got rank "
73 "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
74 "%d for rhs",
75 lower_diagonal_rank, main_diagonal_rank, upper_diagonal_rank, rhs_rank);
76 }
77 const auto rank = lower_diagonal_rank;
78 if (rank < 2) {
79 return InvalidArgument("Arguments must have rank >=2; got rank %d.", rank);
80 }
81
82 const auto lower_diagonal_num_eqs =
83 ShapeUtil::GetDimension(lower_diagonal_shape, rank - 1);
84 const auto main_diagonal_num_eqs =
85 ShapeUtil::GetDimension(main_diagonal_shape, rank - 1);
86 const auto upper_diagonal_num_eqs =
87 ShapeUtil::GetDimension(upper_diagonal_shape, rank - 1);
88 const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1);
89 if (!((lower_diagonal_num_eqs == main_diagonal_num_eqs) &&
90 (lower_diagonal_num_eqs == upper_diagonal_num_eqs) &&
91 (lower_diagonal_num_eqs == rhs_num_eqs))) {
92 return InvalidArgument(
93 "All inputs should have the same innermost dimension but got "
94 "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
95 "%d for rhs",
96 lower_diagonal_num_eqs, main_diagonal_num_eqs, upper_diagonal_num_eqs,
97 rhs_num_eqs);
98 }
99 const auto num_equations = lower_diagonal_num_eqs;
100
101 TF_RETURN_IF_ERROR(CheckSecondToLastDimension(lower_diagonal_shape, rank, 1,
102 "lower diagonal"));
103 TF_RETURN_IF_ERROR(
104 CheckSecondToLastDimension(main_diagonal_shape, rank, 1, "diagonal"));
105 TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
106 "upper diagonal"));
107
108 return num_equations;
109 }
110
Coefficient(XlaOp operand,int32 i)111 XlaOp Coefficient(XlaOp operand, int32 i) {
112 return DynamicSliceInMinorDims(operand,
113 /*starts=*/{ConstantR0(operand.builder(), i)},
114 /*sizes=*/{1});
115 }
116
Coefficient(XlaOp operand,XlaOp i)117 XlaOp Coefficient(XlaOp operand, XlaOp i) {
118 return DynamicSliceInMinorDims(operand,
119 /*starts=*/{i}, /*sizes=*/{1});
120 }
121
UpdateEq(XlaOp updated,int32 i,XlaOp update)122 XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
123 return DynamicUpdateSliceInMinorDims(
124 updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
125 }
126
UpdateEq(XlaOp updated,XlaOp i,XlaOp update)127 XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
128 return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
129 }
130
131 } // namespace
132
133 // Applies Thomas algorithm to solve a linear system where the linear operand
134 // is a tri-diagonal matrix.
135 // See https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm for a simple
136 // reference on the Thomas algorithm.
137 // It is expected that the three diagonals are represented as tensors of shape
138 // [..., 1, num_equations] where num_equations is the number of dimensions of
139 // the unknowns considered in the linear systems.
140 // The first innermost dimension of `lower_diagonal` (`lower_diagonal[..., :,
141 // 0]`) will be ignored. The last innermost dimension of `upper_diagonal`
142 // (`upper_diagonal[..., :, num_equations - 1]`) will be ignored. The shape of
143 // the right-hand-side `rhs` should be [..., num_rhs, num_equations]. The
144 // solution will have the shape [..., num_rhs, num_equations].
ThomasSolver(XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)145 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
146 XlaOp upper_diagonal, XlaOp rhs) {
147 XlaBuilder* builder = lower_diagonal.builder();
148
149 TF_ASSIGN_OR_RETURN(int64 num_eqs,
150 CheckSystemAndReturnNumEquations(
151 lower_diagonal, main_diagonal, upper_diagonal, rhs));
152
153 XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
154 XlaOp rhs_after_elimination = ZerosLike(rhs);
155 XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
156 XlaOp x_coeffs = ZerosLike(rhs);
157
158 // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
159 main_diag_after_elimination =
160 UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
161
162 // rhs_after_elimination[:, 0] = rhs[:, 0];
163 rhs_after_elimination =
164 UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
165
166 auto preparation_body_fn =
167 [](XlaOp i, absl::Span<const XlaOp> values,
168 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
169 auto upper_diagonal_coeffs = values[0];
170 auto upper_diagonal = values[1];
171 // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
172 upper_diagonal_coeffs =
173 UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
174 return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
175 };
176 TF_ASSIGN_OR_RETURN(auto values_after_preparation,
177 ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
178 {upper_diagonal_coeffs, upper_diagonal},
179 "preparation", builder));
180 upper_diagonal_coeffs = values_after_preparation[0];
181
182 // Forward transformation.
183 auto forward_transformation_fn =
184 [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
185 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
186 auto lower_diagonal = values[0];
187 auto main_diagonal = values[1];
188 auto rhs = values[2];
189 auto main_diag_after_elimination = values[3];
190 auto upper_diagonal_coeffs = values[4];
191 auto rhs_after_elimination = values[5];
192
193 auto one = ScalarLike(i_minus_one, 1);
194 auto i = i_minus_one + one;
195 auto lower_diagonal_i = Coefficient(lower_diagonal, i);
196 auto main_diagonal_i = Coefficient(main_diagonal, i);
197 auto rhs_i = Coefficient(rhs, i);
198
199 auto w_i =
200 lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
201
202 // main_diag_after_elimination[:, i] =
203 // main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
204 main_diag_after_elimination = UpdateEq(
205 main_diag_after_elimination, i,
206 main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
207 // rhs_after_elimination[:, i] =
208 // rhs_i - w_i * rhs_after_elimination[:, i - 1];
209 rhs_after_elimination =
210 UpdateEq(rhs_after_elimination, i,
211 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
212
213 return std::vector<XlaOp>{lower_diagonal,
214 main_diagonal,
215 rhs,
216 main_diag_after_elimination,
217 upper_diagonal_coeffs,
218 rhs_after_elimination};
219 };
220 TF_ASSIGN_OR_RETURN(
221 auto values_after_fwd_transformation,
222 ForEachIndex(
223 num_eqs - 1, S32, forward_transformation_fn,
224 {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
225 upper_diagonal_coeffs, rhs_after_elimination},
226 "forward_transformation", builder));
227 lower_diagonal = values_after_fwd_transformation[0];
228 main_diagonal = values_after_fwd_transformation[1];
229 rhs = values_after_fwd_transformation[2];
230 main_diag_after_elimination = values_after_fwd_transformation[3];
231 upper_diagonal_coeffs = values_after_fwd_transformation[4];
232 rhs_after_elimination = values_after_fwd_transformation[5];
233
234 // Backward reduction.
235 // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
236 // main_diag_after_elimination[:, num_eqs - 1];
237 x_coeffs =
238 UpdateEq(x_coeffs, num_eqs - 1,
239 Coefficient(rhs_after_elimination, num_eqs - 1) /
240 Coefficient(main_diag_after_elimination, num_eqs - 1));
241 auto bwd_reduction_fn =
242 [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
243 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
244 auto x_coeffs = values[0];
245 auto rhs_after_elimination = values[1];
246 auto upper_diagonal_coeffs = values[2];
247 auto main_diag_after_elimination = values[3];
248 auto n = ScalarLike(j, num_eqs - 2);
249 auto one = ScalarLike(j, 1);
250 auto i = n - j;
251 // for (int i = num_eqs - 2; i >= 0; i--)
252 // x_coeffs[:, i] = (rhs_after_elimination[:, i] -
253 // upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
254 // main_diag_after_elimination[:, i];
255 x_coeffs = UpdateEq(x_coeffs, i,
256 (Coefficient(rhs_after_elimination, i) -
257 Coefficient(upper_diagonal_coeffs, i) *
258 Coefficient(x_coeffs, i + one)) /
259 Coefficient(main_diag_after_elimination, i));
260 return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
261 upper_diagonal_coeffs,
262 main_diag_after_elimination};
263 };
264
265 TF_ASSIGN_OR_RETURN(
266 auto values_after_bwd_reduction,
267 ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
268 {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
269 main_diag_after_elimination},
270 "backward_reduction", builder));
271 x_coeffs = values_after_bwd_reduction[0];
272
273 return x_coeffs;
274 }
275
276 // Applies Thomas algorithm to solve a linear system where the linear operand
277 // is a tri-diagonal matrix.
278 // It is expected that the tree diagonals are stacked into a tensors of shape
279 // [..., 3, num_equations] where num_equations is the number of spatial
280 // dimensions considered in the system.
281 // diagonals[..., 0, :] represents the upper diagonal whose last inner
282 // dimension will be ignored.
283 // diagonals[..., 1, :] represents the main diagonal.
284 // diagonals[..., 2, :] represents the lower diagonal whose first inner
285 // dimension will be ignored.
286 // The right-hand-side d is expected to have dimension
287 // [..., num_rhs, num_equations].
288 // The solution will have size [..., num_rhs, num_equations].
ThomasSolver(XlaOp diagonals,XlaOp rhs)289 StatusOr<XlaOp> ThomasSolver(XlaOp diagonals, XlaOp rhs) {
290 XlaBuilder* builder = diagonals.builder();
291 TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals));
292 const int64 rank = diagonals_shape.rank();
293
294 auto upper_diagonal =
295 SliceInDim(diagonals, /*start_index=*/0, /*limit_index=*/1,
296 /*stride=*/1, /*dimno=*/rank - 2);
297 auto main_diagonal =
298 SliceInDim(diagonals, /*start_index=*/1, /*limit_index=*/2,
299 /*stride=*/1, /*dimno=*/rank - 2);
300 auto lower_diagonal =
301 SliceInDim(diagonals, /*start_index=*/2, /*limit_index=*/3,
302 /*stride=*/1, /*dimno=*/rank - 2);
303
304 // TODO(belletti): Get rid of the transposes here.
305 std::vector<int64> transpose_order(rank);
306 std::iota(transpose_order.begin(), transpose_order.end(), 0);
307 transpose_order[rank - 2] = rank - 1;
308 transpose_order[rank - 1] = rank - 2;
309 // Swap the last two dimensions.
310 rhs = Transpose(rhs, transpose_order);
311
312 TF_ASSIGN_OR_RETURN(XlaOp x, ThomasSolver(lower_diagonal, main_diagonal,
313 upper_diagonal, rhs));
314 return Transpose(x, transpose_order);
315 }
316
317 } // namespace tridiagonal
318 } // namespace xla
319