/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/tridiagonal.h"

#include <cstdint>
#include <numeric>
#include <string>
#include <string_view>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace tridiagonal {

namespace {

Status CheckSecondToLastDimension(const Shape& op_shape, int64_t rank,
                                  int64_t expected,
                                  const std::string& op_name) {
  const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);

  if (actual_num_dims != expected) {
    return InvalidArgument(
        "Second to last dimension of %s should be %d but is %d.", op_name,
        expected, actual_num_dims);
  }

  return OkStatus();
}

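// Verifies that the three diagonals and the right-hand side all have the same
// rank (>= 2) and the same innermost dimension, and that each diagonal has a
// second-to-last dimension of 1; returns that innermost dimension. For
// example (illustration only), diagonals of shape [batch, 1, N] together with
// an rhs of shape [batch, num_rhs, N] are accepted and N is returned.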
StatusOr<int64_t> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
                                                   XlaOp main_diagonal,
                                                   XlaOp upper_diagonal,
                                                   XlaOp rhs) {
  XlaBuilder* builder = lower_diagonal.builder();

  TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
                      builder->GetShape(lower_diagonal));
  TF_ASSIGN_OR_RETURN(Shape main_diagonal_shape,
                      builder->GetShape(main_diagonal));
  TF_ASSIGN_OR_RETURN(Shape upper_diagonal_shape,
                      builder->GetShape(upper_diagonal));
  TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs));

  const auto lower_diagonal_rank = lower_diagonal_shape.rank();
  const auto main_diagonal_rank = main_diagonal_shape.rank();
  const auto upper_diagonal_rank = upper_diagonal_shape.rank();
  const auto rhs_rank = rhs_shape.rank();
  if (!((lower_diagonal_rank == main_diagonal_rank) &&
        (lower_diagonal_rank == upper_diagonal_rank) &&
        (lower_diagonal_rank == rhs_rank))) {
    return InvalidArgument(
        "All inputs should have the same rank but got rank "
        "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
        "%d for rhs",
        lower_diagonal_rank, main_diagonal_rank, upper_diagonal_rank, rhs_rank);
  }
  const auto rank = lower_diagonal_rank;
  if (rank < 2) {
    return InvalidArgument("Arguments must have rank >=2; got rank %d.", rank);
  }

  const auto lower_diagonal_num_eqs =
      ShapeUtil::GetDimension(lower_diagonal_shape, rank - 1);
  const auto main_diagonal_num_eqs =
      ShapeUtil::GetDimension(main_diagonal_shape, rank - 1);
  const auto upper_diagonal_num_eqs =
      ShapeUtil::GetDimension(upper_diagonal_shape, rank - 1);
  const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1);
  if (!((lower_diagonal_num_eqs == main_diagonal_num_eqs) &&
        (lower_diagonal_num_eqs == upper_diagonal_num_eqs) &&
        (lower_diagonal_num_eqs == rhs_num_eqs))) {
    return InvalidArgument(
        "All inputs should have the same innermost dimension but got "
        "%d for lower diagonal, %d for diagonal, %d for upper diagonal, "
        "%d for rhs",
        lower_diagonal_num_eqs, main_diagonal_num_eqs, upper_diagonal_num_eqs,
        rhs_num_eqs);
  }
  const auto num_equations = lower_diagonal_num_eqs;

  TF_RETURN_IF_ERROR(CheckSecondToLastDimension(lower_diagonal_shape, rank, 1,
                                                "lower diagonal"));
  TF_RETURN_IF_ERROR(
      CheckSecondToLastDimension(main_diagonal_shape, rank, 1, "diagonal"));
  TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
                                                "upper diagonal"));

  return num_equations;
}

// Information about a matrix with shape [..., M, N].
struct TridiagonalMatMulShapeParams {
  int64_t rank;
  int64_t m;
  int64_t n;
  PrimitiveType element_type;
};

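// Checks that `diagonal_shape` is compatible with `rhs_shape` for
// TridiagonalMatMul: same rank, matching outer (batch) dimensions, a
// second-to-last dimension of 1, and a last dimension equal to rhs's
// second-to-last dimension. For example (illustration only), a diagonal of
// shape [batch, 1, M] is valid for an rhs of shape [batch, M, N].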
Status ValidateTridiagonalMatMulDiagonal(const Shape& diagonal_shape,
                                         const std::string_view diagonal_name,
                                         const Shape& rhs_shape) {
  const int64_t diagonal_rank = diagonal_shape.rank();
  const int64_t rhs_rank = rhs_shape.rank();
  if (diagonal_rank != rhs_rank) {
    return InvalidArgument("%s must have same rank as rhs, but got %d and %d.",
                           diagonal_name, diagonal_rank, rhs_rank);
  }
  for (int64_t i = 0; i < rhs_rank - 2; i++) {
    const int64_t diagonal_dimension =
        ShapeUtil::GetDimension(diagonal_shape, i);
    const int64_t rhs_dimension = ShapeUtil::GetDimension(rhs_shape, i);
    if (diagonal_dimension != rhs_dimension) {
      return InvalidArgument(
          "%s must have same outer dimensions as rhs, but for index %d, got %d "
          "and %d.",
          diagonal_name, i, diagonal_dimension, rhs_dimension);
    }
  }
  if (const int64_t diagonal_second_last_dimension =
          ShapeUtil::GetDimension(diagonal_shape, rhs_rank - 2);
      diagonal_second_last_dimension != 1) {
    return InvalidArgument(
        "%s's second-to-last dimension must be 1, but got %d.", diagonal_name,
        diagonal_second_last_dimension);
  }

  const int64_t diagonal_last_dimension =
      ShapeUtil::GetDimension(diagonal_shape, rhs_rank - 1);
  const int64_t rhs_second_last_dimension =
      ShapeUtil::GetDimension(rhs_shape, rhs_rank - 2);
  if (diagonal_last_dimension != rhs_second_last_dimension) {
    return InvalidArgument(
        "%s's last dimension size must be rhs's second-to-last dimension size, "
        "but got %d and %d.",
        diagonal_name, diagonal_last_dimension, rhs_second_last_dimension);
  }
  return OkStatus();
}

StatusOr<TridiagonalMatMulShapeParams> CheckMatMulSystemAndReturnShapeParams(
    XlaOp upper_diagonal, XlaOp main_diagonal, XlaOp lower_diagonal,
    XlaOp rhs) {
  XlaBuilder* builder = upper_diagonal.builder();

  TF_ASSIGN_OR_RETURN(const Shape upper_diagonal_shape,
                      builder->GetShape(upper_diagonal));
  TF_ASSIGN_OR_RETURN(const Shape main_diagonal_shape,
                      builder->GetShape(main_diagonal));
  TF_ASSIGN_OR_RETURN(const Shape lower_diagonal_shape,
                      builder->GetShape(lower_diagonal));
  TF_ASSIGN_OR_RETURN(const Shape rhs_shape, builder->GetShape(rhs));

  const int64_t rank = rhs_shape.rank();
  if (rank < 2) {
    return InvalidArgument("Input must have rank >= 2, but got %d.", rank);
  }

  TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(upper_diagonal_shape,
                                                       "superdiag", rhs_shape));
  TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(main_diagonal_shape,
                                                       "maindiag", rhs_shape));
  TF_RETURN_IF_ERROR(ValidateTridiagonalMatMulDiagonal(lower_diagonal_shape,
                                                       "subdiag", rhs_shape));

  const int64_t rhs_height = ShapeUtil::GetDimension(rhs_shape, rank - 2);
  const int64_t rhs_width = ShapeUtil::GetDimension(rhs_shape, rank - 1);

  TridiagonalMatMulShapeParams shape_params;
  shape_params.rank = rank;
  shape_params.m = rhs_height;
  shape_params.n = rhs_width;
  shape_params.element_type = rhs_shape.element_type();
  return shape_params;
}

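// Returns operand[..., :, i:i+1], i.e. the i-th coefficient along the
// innermost dimension, keeping that dimension with size 1.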
XlaOp Coefficient(XlaOp operand, int32_t i) {
  return DynamicSliceInMinorDims(operand,
                                 /*starts=*/{ConstantR0(operand.builder(), i)},
                                 /*sizes=*/{1});
}

XlaOp Coefficient(XlaOp operand, XlaOp i) {
  return DynamicSliceInMinorDims(operand,
                                 /*starts=*/{i}, /*sizes=*/{1});
}

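// Returns a copy of `updated` where updated[..., :, i:i+1] has been replaced
// by `update`.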
XlaOp UpdateEq(XlaOp updated, int32_t i, XlaOp update) {
  return DynamicUpdateSliceInMinorDims(
      updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
}

XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
  return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
}

template <SolverAlgorithm algo>
StatusOr<XlaOp> TridiagonalSolverImpl(XlaOp lower_diagonal, XlaOp main_diagonal,
                                      XlaOp upper_diagonal, XlaOp rhs);

// Applies the Thomas algorithm to solve a linear system where the linear
// operand is a tridiagonal matrix.
// See https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm for a simple
// reference on the Thomas algorithm.
// The three diagonals are expected to be represented as tensors of shape
// [..., 1, num_equations] where num_equations is the number of equations
// (unknowns) in each linear system.
// The first element of the innermost dimension of `lower_diagonal`
// (`lower_diagonal[..., :, 0]`) is ignored. The last element of the innermost
// dimension of `upper_diagonal` (`upper_diagonal[..., :, num_equations - 1]`)
// is ignored. The right-hand side `rhs` should have shape
// [..., num_rhs, num_equations]. The solution has shape
// [..., num_rhs, num_equations].
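//
// In scalar form, for a single system with lower diagonal a, main diagonal b,
// upper diagonal c and right-hand side d, the recurrences implemented below
// are the standard Thomas elimination and back-substitution (0-based indices;
// this is a reference sketch, not additional API):
//
//   b'[0] = b[0];            d'[0] = d[0];
//   for i = 1 .. n - 1:
//     w     = a[i] / b'[i - 1];
//     b'[i] = b[i] - w * c[i - 1];
//     d'[i] = d[i] - w * d'[i - 1];
//   x[n - 1] = d'[n - 1] / b'[n - 1];
//   for i = n - 2 .. 0:
//     x[i] = (d'[i] - c[i] * x[i + 1]) / b'[i];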
template <>
StatusOr<XlaOp> TridiagonalSolverImpl<kThomas>(XlaOp lower_diagonal,
                                               XlaOp main_diagonal,
                                               XlaOp upper_diagonal,
                                               XlaOp rhs) {
  XlaBuilder* builder = lower_diagonal.builder();

  TF_ASSIGN_OR_RETURN(int64_t num_eqs,
                      CheckSystemAndReturnNumEquations(
                          lower_diagonal, main_diagonal, upper_diagonal, rhs));

  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
  XlaOp rhs_after_elimination = ZerosLike(rhs);
  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
  XlaOp x_coeffs = ZerosLike(rhs);

  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
  main_diag_after_elimination =
      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));

  // rhs_after_elimination[:, 0] = rhs[:, 0];
  rhs_after_elimination =
      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));

  auto preparation_body_fn =
      [](XlaOp i, absl::Span<const XlaOp> values,
         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto upper_diagonal_coeffs = values[0];
    auto upper_diagonal = values[1];
    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
    upper_diagonal_coeffs =
        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
  };
  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
                                   {upper_diagonal_coeffs, upper_diagonal},
                                   "preparation", builder));
  upper_diagonal_coeffs = values_after_preparation[0];

  // Forward transformation.
  auto forward_transformation_fn =
      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto lower_diagonal = values[0];
    auto main_diagonal = values[1];
    auto rhs = values[2];
    auto main_diag_after_elimination = values[3];
    auto upper_diagonal_coeffs = values[4];
    auto rhs_after_elimination = values[5];

    auto one = ScalarLike(i_minus_one, 1);
    auto i = i_minus_one + one;
    auto lower_diagonal_i = Coefficient(lower_diagonal, i);
    auto main_diagonal_i = Coefficient(main_diagonal, i);
    auto rhs_i = Coefficient(rhs, i);

    auto w_i =
        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);

    // main_diag_after_elimination[:, i] =
    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
    main_diag_after_elimination = UpdateEq(
        main_diag_after_elimination, i,
        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
    // rhs_after_elimination[:, i] =
    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
    rhs_after_elimination =
        UpdateEq(rhs_after_elimination, i,
                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));

    return std::vector<XlaOp>{lower_diagonal,
                              main_diagonal,
                              rhs,
                              main_diag_after_elimination,
                              upper_diagonal_coeffs,
                              rhs_after_elimination};
  };
  TF_ASSIGN_OR_RETURN(
      auto values_after_fwd_transformation,
      ForEachIndex(
          num_eqs - 1, S32, forward_transformation_fn,
          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
           upper_diagonal_coeffs, rhs_after_elimination},
          "forward_transformation", builder));
  lower_diagonal = values_after_fwd_transformation[0];
  main_diagonal = values_after_fwd_transformation[1];
  rhs = values_after_fwd_transformation[2];
  main_diag_after_elimination = values_after_fwd_transformation[3];
  upper_diagonal_coeffs = values_after_fwd_transformation[4];
  rhs_after_elimination = values_after_fwd_transformation[5];

  // Backward reduction.
  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
  //                              main_diag_after_elimination[:, num_eqs - 1];
  x_coeffs =
      UpdateEq(x_coeffs, num_eqs - 1,
               Coefficient(rhs_after_elimination, num_eqs - 1) /
                   Coefficient(main_diag_after_elimination, num_eqs - 1));
  auto bwd_reduction_fn =
      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto x_coeffs = values[0];
    auto rhs_after_elimination = values[1];
    auto upper_diagonal_coeffs = values[2];
    auto main_diag_after_elimination = values[3];
    auto n = ScalarLike(j, num_eqs - 2);
    auto one = ScalarLike(j, 1);
    auto i = n - j;
    // for (int i = num_eqs - 2; i >= 0; i--)
    //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
    //     upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
    //       main_diag_after_elimination[:, i];
    x_coeffs = UpdateEq(x_coeffs, i,
                        (Coefficient(rhs_after_elimination, i) -
                         Coefficient(upper_diagonal_coeffs, i) *
                             Coefficient(x_coeffs, i + one)) /
                            Coefficient(main_diag_after_elimination, i));
    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
                              upper_diagonal_coeffs,
                              main_diag_after_elimination};
  };

  TF_ASSIGN_OR_RETURN(
      auto values_after_bwd_reduction,
      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
                    main_diag_after_elimination},
                   "backward_reduction", builder));
  x_coeffs = values_after_bwd_reduction[0];

  return x_coeffs;
}

}  // namespace

StatusOr<XlaOp> TridiagonalSolver(SolverAlgorithm algo, XlaOp lower_diagonal,
                                  XlaOp main_diagonal, XlaOp upper_diagonal,
                                  XlaOp rhs) {
  switch (algo) {
    case kThomas:
      return TridiagonalSolverImpl<kThomas>(lower_diagonal, main_diagonal,
                                            upper_diagonal, rhs);
    default:
      return Unimplemented(
          "Only algorithm kThomas (%d) is implemented, got: %d",
          static_cast<int>(kThomas), algo);
  }
}
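
// Example usage (an illustrative sketch only; the builder name and literals
// below are hypothetical):
//
//   XlaBuilder b("tridiagonal_solve");
//   // One system of three equations: diagonals shaped [1, 3], rhs [1, 3].
//   XlaOp lower = ConstantR2<float>(&b, {{0.0f, 1.0f, 1.0f}});
//   XlaOp main = ConstantR2<float>(&b, {{2.0f, 2.0f, 2.0f}});
//   XlaOp upper = ConstantR2<float>(&b, {{1.0f, 1.0f, 0.0f}});
//   XlaOp rhs = ConstantR2<float>(&b, {{1.0f, 2.0f, 3.0f}});
//   StatusOr<XlaOp> x = TridiagonalSolver(kThomas, lower, main, upper, rhs);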

// Solves a linear system where the linear operand is a tridiagonal matrix.
// The three diagonals are expected to be stacked into a tensor of shape
// [..., 3, num_equations] where num_equations is the number of equations in
// each system.
// diagonals[..., 0, :] represents the upper diagonal, whose last element is
// ignored.
// diagonals[..., 1, :] represents the main diagonal.
// diagonals[..., 2, :] represents the lower diagonal, whose first element is
// ignored.
// The right-hand side `rhs` is expected to have shape
// [..., num_equations, num_rhs]; its last two dimensions are transposed
// internally before and after applying the Thomas solver.
// The solution has shape [..., num_equations, num_rhs].
StatusOr<XlaOp> TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals,
                                  XlaOp rhs) {
  XlaBuilder* builder = diagonals.builder();
  TF_ASSIGN_OR_RETURN(Shape diagonals_shape, builder->GetShape(diagonals));
  const int64_t rank = diagonals_shape.rank();

  auto upper_diagonal =
      SliceInDim(diagonals, /*start_index=*/0, /*limit_index=*/1,
                 /*stride=*/1, /*dimno=*/rank - 2);
  auto main_diagonal =
      SliceInDim(diagonals, /*start_index=*/1, /*limit_index=*/2,
                 /*stride=*/1, /*dimno=*/rank - 2);
  auto lower_diagonal =
      SliceInDim(diagonals, /*start_index=*/2, /*limit_index=*/3,
                 /*stride=*/1, /*dimno=*/rank - 2);

  // TODO(belletti): Get rid of the transposes here.
  std::vector<int64_t> transpose_order(rank);
  std::iota(transpose_order.begin(), transpose_order.end(), 0);
  transpose_order[rank - 2] = rank - 1;
  transpose_order[rank - 1] = rank - 2;
  // Swap the last two dimensions.
  rhs = Transpose(rhs, transpose_order);

  switch (algo) {
    case kThomas: {
      TF_ASSIGN_OR_RETURN(
          XlaOp x, TridiagonalSolverImpl<kThomas>(lower_diagonal, main_diagonal,
                                                  upper_diagonal, rhs));
      return Transpose(x, transpose_order);
    }
    default:
      return Unimplemented(
          "Only algorithm kThomas (%d) is implemented, got: %d",
          static_cast<int>(kThomas), algo);
  }
}
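
// Example usage (an illustrative sketch only; the builder name and literals
// below are hypothetical):
//
//   XlaBuilder b("tridiagonal_solve_compact");
//   // Rows are the upper, main and lower diagonals; shape [3, 3].
//   XlaOp diagonals = ConstantR2<float>(&b, {{1.0f, 1.0f, 0.0f},
//                                            {2.0f, 2.0f, 2.0f},
//                                            {0.0f, 1.0f, 1.0f}});
//   // One right-hand side with three equations; shape [3, 1].
//   XlaOp rhs = ConstantR2<float>(&b, {{1.0f}, {2.0f}, {3.0f}});
//   StatusOr<XlaOp> x = TridiagonalSolver(kThomas, diagonals, rhs);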

// Multiplies a tridiagonal matrix by a matrix.
// `upper_diagonal` is expected to have shape [..., 1, M]; element
// [..., M - 1] is ignored.
// `main_diagonal` is expected to have shape [..., 1, M].
// `lower_diagonal` is expected to have shape [..., 1, M]; element
// [..., 0] is ignored.
// The right-hand side `rhs` is expected to have shape [..., M, N].
// The result has shape [..., M, N].
StatusOr<XlaOp> TridiagonalMatMul(XlaOp upper_diagonal, XlaOp main_diagonal,
                                  XlaOp lower_diagonal, XlaOp rhs) {
  TF_ASSIGN_OR_RETURN(const TridiagonalMatMulShapeParams shape_params,
                      CheckMatMulSystemAndReturnShapeParams(
                          upper_diagonal, main_diagonal, lower_diagonal, rhs));
  XlaBuilder* builder = main_diagonal.builder();

  std::vector<int64_t> broadcasted_dims(shape_params.rank);
  std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
  std::vector<int64_t> transpose_dims = broadcasted_dims;
  std::swap(transpose_dims[shape_params.rank - 2],
            transpose_dims[shape_params.rank - 1]);

  // Shape [..., 1, M] -> [..., M, 1]
  main_diagonal = xla::Transpose(main_diagonal, transpose_dims);
  XlaOp diag_part = xla::Mul(main_diagonal, rhs, broadcasted_dims);

  upper_diagonal = SliceInMinorDims(upper_diagonal, /*start=*/{0},
                                    /*end=*/{shape_params.m - 1});
  upper_diagonal = xla::Transpose(upper_diagonal, transpose_dims);
  XlaOp adjusted_upper_rhs = SliceInMinorDims(
      rhs, /*start=*/{1, 0}, /*end=*/{shape_params.m, shape_params.n});
  XlaOp upper_diag_part =
      xla::Mul(upper_diagonal, adjusted_upper_rhs, broadcasted_dims);
  upper_diag_part = xla::PadInDim(
      upper_diag_part, xla::Zero(builder, shape_params.element_type),
      /*dimno=*/shape_params.rank - 2, /*pad_lo=*/0, /*pad_hi=*/1);

  lower_diagonal = SliceInMinorDims(lower_diagonal, /*start=*/{1},
                                    /*end=*/{shape_params.m});
  lower_diagonal = xla::Transpose(lower_diagonal, transpose_dims);
  XlaOp adjusted_lower_rhs = SliceInMinorDims(
      rhs, /*start=*/{0, 0}, /*end=*/{shape_params.m - 1, shape_params.n});
  XlaOp lower_diag_part =
      xla::Mul(lower_diagonal, adjusted_lower_rhs, broadcasted_dims);
  lower_diag_part = xla::PadInDim(
      lower_diag_part, xla::Zero(builder, shape_params.element_type),
      /*dimno=*/shape_params.rank - 2, /*pad_lo=*/1, /*pad_hi=*/0);

  return diag_part + upper_diag_part + lower_diag_part;
}
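
// Example usage (an illustrative sketch only; the builder name and literals
// below are hypothetical):
//
//   XlaBuilder b("tridiagonal_matmul");
//   // Diagonals shaped [1, 3] (M = 3), rhs shaped [3, 2] (N = 2).
//   XlaOp upper = ConstantR2<float>(&b, {{1.0f, 1.0f, 0.0f}});
//   XlaOp main = ConstantR2<float>(&b, {{2.0f, 2.0f, 2.0f}});
//   XlaOp lower = ConstantR2<float>(&b, {{0.0f, 1.0f, 1.0f}});
//   XlaOp rhs = ConstantR2<float>(&b, {{1.0f, 2.0f},
//                                      {3.0f, 4.0f},
//                                      {5.0f, 6.0f}});
//   StatusOr<XlaOp> product = TridiagonalMatMul(upper, main, lower, rhs);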

}  // namespace tridiagonal
}  // namespace xla