1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
17
18 #include <cstdint>
19 #include <numeric>
20 #include <string>
21 #include <string_view>
22 #include <vector>
23
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/client/lib/constants.h"
26 #include "tensorflow/compiler/xla/client/lib/loops.h"
27 #include "tensorflow/compiler/xla/client/lib/slicing.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33
34 namespace xla {
35 namespace tridiagonal {
36
37 namespace {
38
CheckSecondToLastDimension(const Shape & op_shape,int64_t rank,int64_t expected,const std::string & op_name)39 Status CheckSecondToLastDimension(const Shape& op_shape, int64_t rank,
40 int64_t expected,
41 const std::string& op_name) {
42 const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
43
44 if (actual_num_dims != expected) {
45 return InvalidArgument(
46 "Second to last dimension of %s should be %d but is %d.", op_name,
47 expected, actual_num_dims);
48 }
49
50 return OkStatus();
51 }
52
CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)53 StatusOr<int64_t> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
54 XlaOp main_diagonal,
55 XlaOp upper_diagonal,
56 XlaOp rhs) {
57 XlaBuilder* builder = lower_diagonal.builder();
58
59 TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
60 builder->GetShape(lower_diagonal));
61 TF_ASSIGN_OR_RETURN(Shape main_diagonal_shape,
62 builder->GetShape(main_diagonal));
63 TF_ASSIGN_OR_RETURN(Shape upper_diagonal_shape,
64 builder->GetShape(upper_diagonal));
65 TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs));
66
67 const auto lower_diagonal_rank = lower_diagonal_shape.rank();
68 const auto main_diagonal_rank = main_diagonal_shape.rank();
69 const auto upper_diagonal_rank = upper_diagonal_shape.rank();
70 const auto rhs_rank = rhs_shape.rank();
71 if (!((lower_diagonal_rank == main_diagonal_rank) &&
72 (lower_diagonal_rank == upper_diagonal_rank) &&
73 (lower_diagonal_rank == rhs_rank))) {
74 return InvalidArgument(
75 "All inputs should have the same rank but got rank "
76 "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
77 "%d for rhs",
78 lower_diagonal_rank, main_diagonal_rank, upper_diagonal_rank, rhs_rank);
79 }
80 const auto rank = lower_diagonal_rank;
81 if (rank < 2) {
82 return InvalidArgument("Arguments must have rank >=2; got rank %d.", rank);
83 }
84
85 const auto lower_diagonal_num_eqs =
86 ShapeUtil::GetDimension(lower_diagonal_shape, rank - 1);
87 const auto main_diagonal_num_eqs =
88 ShapeUtil::GetDimension(main_diagonal_shape, rank - 1);
89 const auto upper_diagonal_num_eqs =
90 ShapeUtil::GetDimension(upper_diagonal_shape, rank - 1);
91 const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1);
92 if (!((lower_diagonal_num_eqs == main_diagonal_num_eqs) &&
93 (lower_diagonal_num_eqs == upper_diagonal_num_eqs) &&
94 (lower_diagonal_num_eqs == rhs_num_eqs))) {
95 return InvalidArgument(
96 "All inputs should have the same innermost dimension but got "
97 "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
98 "%d for rhs",
99 lower_diagonal_num_eqs, main_diagonal_num_eqs, upper_diagonal_num_eqs,
100 rhs_num_eqs);
101 }
102 const auto num_equations = lower_diagonal_num_eqs;
103
104 TF_RETURN_IF_ERROR(CheckSecondToLastDimension(lower_diagonal_shape, rank, 1,
105 "lower diagonal"));
106 TF_RETURN_IF_ERROR(
107 CheckSecondToLastDimension(main_diagonal_shape, rank, 1, "diagonal"));
108 TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
109 "upper diagonal"));
110
111 return num_equations;
112 }
113
114 // Information about matrix with shape [..., M, N].
115 struct TridiagonalMatMulShapeParams {
116 int64_t rank;
117 int64_t m;
118 int64_t n;
119 PrimitiveType element_type;
120 };
121
ValidateTridiagonalMatMulDiagonal(const Shape & diagonal_shape,const std::string_view diagonal_name,const Shape & rhs_shape)122 Status ValidateTridiagonalMatMulDiagonal(const Shape& diagonal_shape,
123 const std::string_view diagonal_name,
124 const Shape& rhs_shape) {
125 const int64_t diagonal_rank = diagonal_shape.rank();
126 const int64_t rhs_rank = rhs_shape.rank();
127 if (diagonal_rank != rhs_rank) {
128 return InvalidArgument("%s must have same rank as rhs, but got %d and %d.",
129 diagonal_name, diagonal_rank, rhs_rank);
130 }
131 for (int64_t i = 0; i < rhs_rank - 2; i++) {
132 const int64_t diagonal_dimension =
133 ShapeUtil::GetDimension(diagonal_shape, i);
134 const int64_t rhs_dimension = ShapeUtil::GetDimension(rhs_shape, i);
135 if (diagonal_dimension != rhs_dimension) {
136 return InvalidArgument(
137 "%s must have same outer dimensions as rhs, but for index %d, got %d "
138 "and %d.",
139 diagonal_name, i, diagonal_dimension, rhs_dimension);
140 }
141 }
142 if (const int64_t digonal_second_last_dimension =
143 ShapeUtil::GetDimension(diagonal_shape, rhs_rank - 2);
144 digonal_second_last_dimension != 1) {
145 return InvalidArgument(
146 "%s's second-to-last dimension must be 1, but got %d.", diagonal_name,
147 digonal_second_last_dimension);
148 }
149
150 const int64_t digonal_last_dimension =
151 ShapeUtil::GetDimension(diagonal_shape, rhs_rank - 1);
152 const int64_t rhs_second_last_dimension =
153 ShapeUtil::GetDimension(rhs_shape, rhs_rank - 2);
154 if (digonal_last_dimension != rhs_second_last_dimension) {
155 return InvalidArgument(
156 "%s's last dimension size must be rhs's second-to-last dimension size, "
157 "but got %d and %d.",
158 diagonal_name, digonal_last_dimension, rhs_second_last_dimension);
159 }
160 return Status::OK();
161 }
162
CheckMatMulSystemAndReturnShapeParams(XlaOp upper_diagonal,XlaOp main_diagonal,XlaOp lower_diagonal,XlaOp rhs)163 StatusOr<TridiagonalMatMulShapeParams> CheckMatMulSystemAndReturnShapeParams(
164 XlaOp upper_diagonal, XlaOp main_diagonal, XlaOp lower_diagonal,
165 XlaOp rhs) {
166 XlaBuilder* builder = upper_diagonal.builder();
167
168 TF_ASSIGN_OR_RETURN(const Shape upper_diagonal_shape,
169 builder->GetShape(upper_diagonal));
170 TF_ASSIGN_OR_RETURN(const Shape main_diagonal_shape,
171 builder->GetShape(main_diagonal));
172 TF_ASSIGN_OR_RETURN(const Shape lower_diagonal_shape,
173 builder->GetShape(lower_diagonal));
174 TF_ASSIGN_OR_RETURN(const Shape rhs_shape, builder->GetShape(rhs));
175
176 const int64_t rank = rhs_shape.rank();
177 if (rank < 2) {
178 return InvalidArgument("Input must have rank >= 2, but got %d.", rank);
179 }
180
181 TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(upper_diagonal_shape,
182 "superdiag", rhs_shape));
183 TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(main_diagonal_shape,
184 "maindiag", rhs_shape));
185 TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(lower_diagonal_shape,
186 "subdiag", rhs_shape));
187
188 const int64_t rhs_height = ShapeUtil::GetDimension(rhs_shape, rank - 2);
189 const int64_t rhs_width = ShapeUtil::GetDimension(rhs_shape, rank - 1);
190
191 TridiagonalMatMulShapeParams shape_params;
192 shape_params.rank = rank;
193 shape_params.m = rhs_height;
194 shape_params.n = rhs_width;
195 shape_params.element_type = rhs_shape.element_type();
196 return shape_params;
197 }
198
Coefficient(XlaOp operand,int32_t i)199 XlaOp Coefficient(XlaOp operand, int32_t i) {
200 return DynamicSliceInMinorDims(operand,
201 /*starts=*/{ConstantR0(operand.builder(), i)},
202 /*sizes=*/{1});
203 }
204
Coefficient(XlaOp operand,XlaOp i)205 XlaOp Coefficient(XlaOp operand, XlaOp i) {
206 return DynamicSliceInMinorDims(operand,
207 /*starts=*/{i}, /*sizes=*/{1});
208 }
209
UpdateEq(XlaOp updated,int32_t i,XlaOp update)210 XlaOp UpdateEq(XlaOp updated, int32_t i, XlaOp update) {
211 return DynamicUpdateSliceInMinorDims(
212 updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
213 }
214
UpdateEq(XlaOp updated,XlaOp i,XlaOp update)215 XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
216 return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
217 }
218
// Primary template for the tridiagonal solver implementations, selected by
// `algo`. No generic definition is provided; only explicit specializations
// (currently kThomas, below) are defined in this file.
template <SolverAlgorithm algo>
StatusOr<XlaOp> TridiagonalSolverImpl(XlaOp lower_diagonal, XlaOp main_diagonal,
                                      XlaOp upper_diagonal, XlaOp rhs);

// Applies the Thomas algorithm to solve a linear system whose operator is a
// tridiagonal matrix.
// See https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm for a simple
// reference on the Thomas algorithm.
// The three diagonals are expected as tensors of shape
// [..., 1, num_equations], where num_equations is the number of unknowns in
// each linear system.
// The first innermost element of `lower_diagonal` (`lower_diagonal[..., :,
// 0]`) is ignored, as is the last innermost element of `upper_diagonal`
// (`upper_diagonal[..., :, num_equations - 1]`). The right-hand side `rhs`
// should have shape [..., num_rhs, num_equations]; the solution has the same
// shape as `rhs`.
// No pivoting is performed: each elimination step divides by the already
// eliminated main-diagonal entry of the previous row. A zero pivot therefore
// is not reported as an error; for floating-point inputs it yields inf/NaN.
template <>
StatusOr<XlaOp> TridiagonalSolverImpl<kThomas>(XlaOp lower_diagonal,
                                               XlaOp main_diagonal,
                                               XlaOp upper_diagonal,
                                               XlaOp rhs) {
  XlaBuilder* builder = lower_diagonal.builder();

  // Validates the operand shapes; num_eqs is the innermost dimension size.
  TF_ASSIGN_OR_RETURN(int64_t num_eqs,
                      CheckSystemAndReturnNumEquations(
                          lower_diagonal, main_diagonal, upper_diagonal, rhs));

  // Working buffers. They start as zeros and are filled column by column by
  // the loops below.
  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
  XlaOp rhs_after_elimination = ZerosLike(rhs);
  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
  XlaOp x_coeffs = ZerosLike(rhs);

  // Row 0 is not eliminated, so it is copied as-is.
  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
  main_diag_after_elimination =
      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));

  // rhs_after_elimination[:, 0] = rhs[:, 0];
  rhs_after_elimination =
      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));

  // Copies the first num_eqs - 1 entries of the upper diagonal into
  // upper_diagonal_coeffs. The last entry stays zero, which is why the last
  // element of `upper_diagonal` is ignored.
  auto preparation_body_fn =
      [](XlaOp i, absl::Span<const XlaOp> values,
         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto upper_diagonal_coeffs = values[0];
    auto upper_diagonal = values[1];
    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
    upper_diagonal_coeffs =
        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
  };
  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
                                   {upper_diagonal_coeffs, upper_diagonal},
                                   "preparation", builder));
  upper_diagonal_coeffs = values_after_preparation[0];

  // Forward transformation.
  // The loop index is i - 1, so row i runs over [1, num_eqs - 1] and is
  // eliminated using row i - 1. The body lambda captures nothing, so every
  // operand it reads (including the unchanged inputs) is threaded through the
  // loop state in `values` and returned again.
  auto forward_transformation_fn =
      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto lower_diagonal = values[0];
    auto main_diagonal = values[1];
    auto rhs = values[2];
    auto main_diag_after_elimination = values[3];
    auto upper_diagonal_coeffs = values[4];
    auto rhs_after_elimination = values[5];

    auto one = ScalarLike(i_minus_one, 1);
    auto i = i_minus_one + one;
    auto lower_diagonal_i = Coefficient(lower_diagonal, i);
    auto main_diagonal_i = Coefficient(main_diagonal, i);
    auto rhs_i = Coefficient(rhs, i);

    // Elimination multiplier for row i; divides by the previous pivot with
    // no pivoting or zero check.
    auto w_i =
        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);

    // main_diag_after_elimination[:, i] =
    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
    main_diag_after_elimination = UpdateEq(
        main_diag_after_elimination, i,
        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
    // rhs_after_elimination[:, i] =
    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
    rhs_after_elimination =
        UpdateEq(rhs_after_elimination, i,
                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));

    return std::vector<XlaOp>{lower_diagonal,
                              main_diagonal,
                              rhs,
                              main_diag_after_elimination,
                              upper_diagonal_coeffs,
                              rhs_after_elimination};
  };
  TF_ASSIGN_OR_RETURN(
      auto values_after_fwd_transformation,
      ForEachIndex(
          num_eqs - 1, S32, forward_transformation_fn,
          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
           upper_diagonal_coeffs, rhs_after_elimination},
          "forward_transformation", builder));
  lower_diagonal = values_after_fwd_transformation[0];
  main_diagonal = values_after_fwd_transformation[1];
  rhs = values_after_fwd_transformation[2];
  main_diag_after_elimination = values_after_fwd_transformation[3];
  upper_diagonal_coeffs = values_after_fwd_transformation[4];
  rhs_after_elimination = values_after_fwd_transformation[5];

  // Backward reduction.
  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
  //                            main_diag_after_elimination[:, num_eqs - 1];
  // NOTE(review): `num_eqs - 1` is int64_t but binds to the int32_t overloads
  // of UpdateEq/Coefficient (implicit narrowing). The loops also index with
  // S32, so num_eqs is presumably assumed to fit in int32 — confirm.
  x_coeffs =
      UpdateEq(x_coeffs, num_eqs - 1,
               Coefficient(rhs_after_elimination, num_eqs - 1) /
                   Coefficient(main_diag_after_elimination, num_eqs - 1));
  // The loop index j runs over [0, num_eqs - 2]. Row i = num_eqs - 2 - j
  // therefore walks from the second-to-last row back to row 0, and each step
  // uses the already solved x_coeffs[:, i + 1].
  auto bwd_reduction_fn =
      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto x_coeffs = values[0];
    auto rhs_after_elimination = values[1];
    auto upper_diagonal_coeffs = values[2];
    auto main_diag_after_elimination = values[3];
    auto n = ScalarLike(j, num_eqs - 2);
    auto one = ScalarLike(j, 1);
    auto i = n - j;
    // for (int i = num_eqs - 2; i >= 0; i--)
    //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
    //     upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
    //     main_diag_after_elimination[:, i];
    x_coeffs = UpdateEq(x_coeffs, i,
                        (Coefficient(rhs_after_elimination, i) -
                         Coefficient(upper_diagonal_coeffs, i) *
                             Coefficient(x_coeffs, i + one)) /
                            Coefficient(main_diag_after_elimination, i));
    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
                              upper_diagonal_coeffs,
                              main_diag_after_elimination};
  };

  TF_ASSIGN_OR_RETURN(
      auto values_after_bwd_reduction,
      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
                    main_diag_after_elimination},
                   "backward_reduction", builder));
  x_coeffs = values_after_bwd_reduction[0];

  return x_coeffs;
}
368
369 } // namespace
370
TridiagonalSolver(SolverAlgorithm algo,XlaOp lower_diagonal,XlaOp main_diagonal,XlaOp upper_diagonal,XlaOp rhs)371 StatusOr<XlaOp> TridiagonalSolver(SolverAlgorithm algo, XlaOp lower_diagonal,
372 XlaOp main_diagonal, XlaOp upper_diagonal,
373 XlaOp rhs) {
374 switch (algo) {
375 case kThomas:
376 return TridiagonalSolverImpl<kThomas>(lower_diagonal, main_diagonal,
377 upper_diagonal, rhs);
378 default:
379 return Unimplemented(
380 "Only algorithm kThomas (%d) is implemented, got: %d",
381 static_cast<int>(kThomas), algo);
382 }
383 }
384
385 // Solves a linear system where the linear operand is a tri-diagonal matrix.
386 // It is expected that the tree diagonals are stacked into a tensors of shape
387 // [..., 3, num_equations] where num_equations is the number of spatial
388 // dimensions considered in the system.
389 // diagonals[..., 0, :] represents the upper diagonal whose last inner
390 // dimension will be ignored.
391 // diagonals[..., 1, :] represents the main diagonal.
392 // diagonals[..., 2, :] represents the lower diagonal whose first inner
393 // dimension will be ignored.
394 // The right-hand-side d is expected to have dimension
395 // [..., num_rhs, num_equations].
396 // The solution will have size [..., num_rhs, num_equations].
TridiagonalSolver(SolverAlgorithm algo,XlaOp diagonals,XlaOp rhs)397 StatusOr<XlaOp> TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals,
398 XlaOp rhs) {
399 XlaBuilder* builder = diagonals.builder();
400 TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals));
401 const int64_t rank = diagonals_shape.rank();
402
403 auto upper_diagonal =
404 SliceInDim(diagonals, /*start_index=*/0, /*limit_index=*/1,
405 /*stride=*/1, /*dimno=*/rank - 2);
406 auto main_diagonal =
407 SliceInDim(diagonals, /*start_index=*/1, /*limit_index=*/2,
408 /*stride=*/1, /*dimno=*/rank - 2);
409 auto lower_diagonal =
410 SliceInDim(diagonals, /*start_index=*/2, /*limit_index=*/3,
411 /*stride=*/1, /*dimno=*/rank - 2);
412
413 // TODO(belletti): Get rid of the transposes here.
414 std::vector<int64_t> transpose_order(rank);
415 std::iota(transpose_order.begin(), transpose_order.end(), 0);
416 transpose_order[rank - 2] = rank - 1;
417 transpose_order[rank - 1] = rank - 2;
418 // Swap the last two dimensions.
419 rhs = Transpose(rhs, transpose_order);
420
421 switch (algo) {
422 case kThomas: {
423 TF_ASSIGN_OR_RETURN(
424 XlaOp x, TridiagonalSolverImpl<kThomas>(lower_diagonal, main_diagonal,
425 upper_diagonal, rhs));
426 return Transpose(x, transpose_order);
427 }
428 default:
429 return Unimplemented(
430 "Only algorithm kThomas (%d) is implemented, got: %d",
431 static_cast<int>(kThomas), algo);
432 }
433 }
434
435 // Multiplies tridiagonal matrix by matrix.
436 // `upper_diagonal` is expected to have dimension [..., 1, M]. Element
437 // [..., M - 1] is ignored.
438 // `main_diagonal` is expected to have dimension [..., 1, M].
439 // `lower_diagonal` is expected to have dimension [..., 1, M]. Element
440 // [..., 0] is ignored.
441 // The `right-hand-side` is expected to have dimension [..., M, N].
442 // The solution will have size [..., M, N].
TridiagonalMatMul(XlaOp upper_diagonal,XlaOp main_diagonal,XlaOp lower_diagonal,XlaOp rhs)443 StatusOr<XlaOp> TridiagonalMatMul(XlaOp upper_diagonal, XlaOp main_diagonal,
444 XlaOp lower_diagonal, XlaOp rhs) {
445 TF_ASSIGN_OR_RETURN(const TridiagonalMatMulShapeParams shape_params,
446 CheckMatMulSystemAndReturnShapeParams(
447 upper_diagonal, main_diagonal, lower_diagonal, rhs));
448 XlaBuilder* builder = main_diagonal.builder();
449
450 std::vector<int64_t> broadcasted_dims(shape_params.rank);
451 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
452 std::vector<int64_t> transpose_dims = broadcasted_dims;
453 std::swap(transpose_dims[shape_params.rank - 2],
454 transpose_dims[shape_params.rank - 1]);
455
456 // Shape [..., 1, M] -> [..., M, 1]
457 main_diagonal = xla::Transpose(main_diagonal, transpose_dims);
458 XlaOp diag_part = xla::Mul(main_diagonal, rhs, broadcasted_dims);
459
460 upper_diagonal = SliceInMinorDims(upper_diagonal, /*start=*/{0},
461 /*end=*/{shape_params.m - 1});
462 upper_diagonal = xla::Transpose(upper_diagonal, transpose_dims);
463 XlaOp adjusted_upper_rhs = SliceInMinorDims(
464 rhs, /*start=*/{1, 0}, /*end=*/{shape_params.m, shape_params.n});
465 XlaOp upper_diag_part =
466 xla::Mul(upper_diagonal, adjusted_upper_rhs, broadcasted_dims);
467 upper_diag_part = xla::PadInDim(
468 upper_diag_part, xla::Zero(builder, shape_params.element_type),
469 /*dimno=*/shape_params.rank - 2, /*pad_lo=*/0, /*pad_hi=*/1);
470
471 lower_diagonal = SliceInMinorDims(lower_diagonal, /*start=*/{1},
472 /*end=*/{shape_params.m});
473 lower_diagonal = xla::Transpose(lower_diagonal, transpose_dims);
474 XlaOp adjusted_lower_rhs = SliceInMinorDims(
475 rhs, /*start=*/{0, 0}, /*end=*/{shape_params.m - 1, shape_params.n});
476 XlaOp lower_diag_part =
477 xla::Mul(lower_diagonal, adjusted_lower_rhs, broadcasted_dims);
478 lower_diag_part = xla::PadInDim(
479 lower_diag_part, xla::Zero(builder, shape_params.element_type),
480 /*dimno=*/shape_params.rank - 2, /*pad_lo=*/1, /*pad_hi=*/0);
481
482 return diag_part + upper_diag_part + lower_diag_part;
483 }
484
485 } // namespace tridiagonal
486 } // namespace xla
487