• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/lib/matrix.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/math/math_util.h"
34 
35 namespace xla {
36 
37 namespace {
38 
39 // Get the diagonal blocks of the coefficient matrix
DiagonalBlocks(XlaOp a,int64_t block_size)40 XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) {
41   XlaBuilder* builder = a.builder();
42   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
43     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
44     int ndims = shape.rank();
45     int64_t n = ShapeUtil::GetDimension(shape, -1);
46     int64_t num_blocks = n / block_size;
47     absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
48         shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
49 
50     XlaOp diag_blocks;
51 
52     // If the coefficient matrix is exactly the block size, we just add a
53     // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
54     if (n == block_size) {
55       std::vector<int64> permutation(ndims);
56       std::iota(permutation.begin(), permutation.end(), 1);
57       permutation.insert(permutation.end() - 2, 0);
58       return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
59     }
60 
61     // We can grab entire blocks using gather
62     if (n > block_size) {
63       // Construct the starting indices of the diagonal blocks
64       auto start_indices =
65           Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
66                                   ConstantR0<int32>(builder, block_size)),
67                               /*broadcast_sizes=*/{2}),
68                     /*permutation=*/{1, 0});
69 
70       PaddingConfig padding_config =
71           MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
72       start_indices =
73           Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);
74 
75       // Gather the diagonal blocks
76       std::vector<int64> slice_sizes(ndims);
77       GatherDimensionNumbers dim_numbers;
78       for (int i = 0; i < ndims - 2; ++i) {
79         dim_numbers.add_offset_dims(i);
80         dim_numbers.add_start_index_map(i);
81         slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
82       }
83       slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
84       dim_numbers.add_offset_dims(ndims - 1);
85       dim_numbers.add_offset_dims(ndims);
86       dim_numbers.add_start_index_map(ndims - 2);
87       dim_numbers.add_start_index_map(ndims - 1);
88       dim_numbers.set_index_vector_dim(1);
89       diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
90     }
91 
92     // The last block might be smaller than the block size,
93     // so we will need to pad it
94     if (n % block_size != 0) {
95       // Pad with identity matrix.
96       auto last_blocks =
97           SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
98       PaddingConfig config = MakeNoPaddingConfig(ndims);
99       int64_t padding = block_size - n % block_size;
100       config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
101       last_blocks =
102           Pad(last_blocks, Zero(builder, shape.element_type()), config);
103 
104       auto eye =
105           IdentityMatrix(builder, shape.element_type(), padding, padding);
106       config = MakeNoPaddingConfig(2);
107       config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
108       eye = Pad(eye, Zero(builder, shape.element_type()), config);
109       eye = Broadcast(eye, batch_dims);
110       last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
111 
112       // Add a singleton dimension
113       // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
114       TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
115       auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
116       auto last_blocks_dims = std::vector<int64>(ndims);
117       std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
118       last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
119       last_blocks = Reshape(last_blocks, last_blocks_dims);
120 
121       // Concatenate with the other blocks if necessary
122       if (n > block_size) {
123         diag_blocks =
124             ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
125       } else {
126         diag_blocks = last_blocks;
127       }
128     }
129 
130     return diag_blocks;
131   });
132 }
133 
SolveWithInvertedDiagonalBlocks(XlaOp a,XlaOp b,XlaOp inv_diag_blocks,bool left_side,bool lower,bool transpose_a,bool conjugate_a,PrecisionConfig::Precision precision)134 XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
135                                       bool left_side, bool lower,
136                                       bool transpose_a, bool conjugate_a,
137                                       PrecisionConfig::Precision precision) {
138   XlaBuilder* builder = a.builder();
139   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
140     TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
141     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
142     int64_t block_size = ShapeUtil::GetDimension(blocks_shape, -1);
143 
144     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
145     int64_t ndims = a_shape.rank();
146     int64_t n = ShapeUtil::GetDimension(a_shape, -1);
147     int64_t num_blocks = n / block_size + (n % block_size != 0);
148     int64_t m_dim = (left_side) ? -1 : -2;
149     int64_t m = ShapeUtil::GetDimension(b_shape, m_dim);
150 
151     std::vector<XlaOp> update_ops;
152     int bdims = b_shape.rank();
153     int64_t block_dim = (left_side) ? bdims - 2 : bdims - 1;
154 
155     // Initialize the solution
156     XlaOp x;
157 
158     // This loop is unrolled for performance reasons, but it could be expressed
159     // rolled as well since the matrices are of the same size each iteration
160     for (int i = 0; i < num_blocks; i++) {
161       // High-level intuition: We have B[i] = L[i] @ X. Since L is upper
162       // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
163       // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which
164       // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i]
165 
166       // Decide whether we go from first block to last or vice versa
167       bool backward = left_side ^ lower ^ transpose_a;
168       auto j = backward ? num_blocks - 1 - i : i;
169 
170       // Get the size of the inverse blocks (the last one might be smaller)
171       int64_t block = (n % block_size != 0 && j + 1 == num_blocks)
172                           ? n % block_size
173                           : block_size;
174       auto inv_block =
175           MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
176                                                    {j + 1, block, block}),
177                                   /*dimensions=*/{ndims - 2, ndims - 1}),
178                          conjugate_a);
179 
180       // Get the corresponding row of B
181       int64_t k = std::min((j + 1) * block_size, n);
182       std::vector<int64> start = {j * block_size, 0};
183       std::vector<int64> end = {k, m};
184       if (!left_side) {
185         std::swap(start[0], start[1]);
186         std::swap(end[0], end[1]);
187       }
188       auto b_row = SliceInMinorDims(b, start, end);
189 
190       XlaOp remainder;
191       if (i == 0) {
192         remainder = b_row;
193       } else {
194         // This matrix multiply get rid of a lot of multiplying with zero
195         // (namely, X[i * block_size:] = 0), L[i, :i] @ X[:i]
196         if (backward) {
197           start = {j * block_size,
198                    std::max(int64{0}, (num_blocks - i) * block_size)};
199           end = {k, n};
200         } else {
201           start = {j * block_size, 0};
202           end = {k, std::min(i * block_size, n)};
203         }
204 
205         if (!left_side ^ transpose_a) {
206           std::swap(start[0], start[1]);
207           std::swap(end[0], end[1]);
208         }
209         auto a_row =
210             MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
211         if (left_side) {
212           remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision);
213         } else {
214           remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision);
215         }
216       }
217 
218       XlaOp x_update;
219       if (left_side) {
220         x_update =
221             BatchDot(inv_block, transpose_a, remainder, false, precision);
222       } else {
223         x_update =
224             BatchDot(remainder, false, inv_block, transpose_a, precision);
225       }
226 
227       if (i == 0) {
228         x = x_update;
229       } else {
230         if (backward) {
231           x = ConcatInDim(builder, {x_update, x}, block_dim);
232         } else {
233           x = ConcatInDim(builder, {x, x_update}, block_dim);
234         }
235       }
236     }
237 
238     return x;
239   });
240 }
241 
242 }  // namespace
243 
InvertDiagonalBlocks(XlaOp diag_blocks,bool lower_triangular,PrecisionConfig::Precision precision)244 XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
245     XlaOp diag_blocks, bool lower_triangular,
246     PrecisionConfig::Precision precision) {
247   XlaBuilder* builder = diag_blocks.builder();
248   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
249     // Input is a batch of square lower triangular square matrices. Its shape is
250     // (..., size, size). We resize this to (num_blocks, size, size).
251     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
252     int64_t block_size = ShapeUtil::GetDimension(shape, -1);
253     int64_t num_blocks = ShapeUtil::ElementsIn(shape) /
254                          tensorflow::MathUtil::IPow(block_size, 2);
255     diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
256 
257     // The input must be triangular because we rely on that when doing
258     // multiplications later on
259     diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);
260 
261     // Rescale blocks to be unit triangular, but avoid dividing by
262     // zero (which can happen if the last block was padded) otherwise it will
263     // introduce nans which will propagate
264     auto diags = GetMatrixDiagonal(diag_blocks);
265     auto ones = FullLike(diags, 1);
266     diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
267     auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
268 
269     // We can now use the fact that for an upper triangular matrix
270     // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
271     // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
272     // have been rescaled to be unit triangular, so L22 = L22' = 1.
273 
274     // Initialize the output matrix with -1s on the diagonal. We use -1 instead
275     // of 1 because we cannot do matrix-vector multiplies with variable shapes
276     // inside of a loop, or do irregularly shaped in-place updates. Hence,
277     // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
278     // entire row i.e. we calculate
279     // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
280     // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
281     auto identity =
282         IdentityMatrix(builder, shape.element_type(), block_size, block_size);
283     auto neg_identity = -identity;
284 
285     // The first or last  diagonal element should be set to 1 instead of -1
286     // though, since we never update it
287     auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
288     auto start_index =
289         ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
290     auto output_block =
291         DynamicUpdateSlice(neg_identity, pos_one,
292                            /*start_indices=*/{start_index, start_index});
293 
294     // Broadcast diag([1, -1, -1, ...]) to every block
295     XlaOp output = Broadcast(output_block,
296                              /*broadcast_sizes=*/{num_blocks});
297 
298     // Now we construct a loop that performs matrix-vector multiplications
299     // inverting the blocks one row at a time
300     std::vector<Shape> tuple_shapes = {
301         // The loop iteration counter is a scalar, incremented each iteration.
302         ShapeUtil::MakeShape(S32, {}),
303         // The output has the shape of A, with one row updated each iteration.
304         ShapeUtil::MakeShape(shape.element_type(),
305                              {num_blocks, block_size, block_size}),
306         // The input is a loop invariant.
307         ShapeUtil::MakeShape(shape.element_type(),
308                              {num_blocks, block_size, block_size})};
309     Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
310 
311     auto init_i = One(builder, S32);
312     auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
313 
314     // Construct the loop condition function.
315     std::unique_ptr<XlaBuilder> condb =
316         builder->CreateSubBuilder("InvertDiagCond");
317     {
318       auto i = GetTupleElement(
319           Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
320       Lt(i, ConstantR0<int32>(condb.get(), block_size));
321     }
322     TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
323 
324     // Construct the loop body function.
325     std::unique_ptr<XlaBuilder> bodyb =
326         builder->CreateSubBuilder("InvertDiagBody");
327     {
328       auto input_tuple =
329           Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
330 
331       auto i = GetTupleElement(input_tuple, 0);
332       auto body_out = GetTupleElement(input_tuple, 1);
333       auto body_input = GetTupleElement(input_tuple, 2);
334 
335       auto zero = ConstantR0<int32>(bodyb.get(), 0);
336       auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
337       auto input_row =
338           DynamicSlice(body_input, {zero, j, zero},
339                        /*slice_sizes=*/{num_blocks, 1, block_size});
340 
341       // We want -L21 L11^{-1}
342       DotDimensionNumbers dnums;
343       dnums.add_lhs_batch_dimensions(0);
344       dnums.add_rhs_batch_dimensions(0);
345       dnums.add_lhs_contracting_dimensions(2);
346       dnums.add_rhs_contracting_dimensions(1);
347       PrecisionConfig precision_proto;
348       precision_proto.add_operand_precision(precision);
349       precision_proto.add_operand_precision(precision);
350       auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
351 
352       body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
353 
354       auto next_i = i + ScalarLike(i, 1);
355       Tuple(bodyb.get(), {next_i, body_out, body_input});
356     }
357     TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
358 
359     // Construct the While loop and return the result,
360     // return while_loop(cond_fun, body_fun, init)[1]
361     auto invert_while = While(cond, body, init);
362     auto inv_diag_blocks = GetTupleElement(invert_while, 1);
363     // Undo the scaling
364     inv_diag_blocks = Div(inv_diag_blocks, diags,
365                           /*broadcast_dimensions=*/{0, 1});
366 
367     // Reshape back to original batch major dimensions
368     return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
369   });
370 }
371 
SolveByInvertingDiagonalBlocks(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,PrecisionConfig::Precision precision)372 XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
373     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
374     bool conjugate_a, bool unit_diagonal,
375     PrecisionConfig::Precision precision) {
376   XlaBuilder* builder = a.builder();
377   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
378     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
379     const int64_t ndims = a_shape.rank();
380     int64_t k = ShapeUtil::GetDimension(a_shape, -1);
381 
382     // TODO(phawkins): consider pushing triangle masking into
383     // InvertDiagonalBlocks.
384     if (unit_diagonal) {
385       // Mask everything but the subdiagonal/superdiagonal elements.
386       a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
387                 : Select(TriangleMask(a, 0), ZerosLike(a), a);
388       a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
389                    /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
390     } else {
391       // Mask off the ignored elements of the triangular matrix a.
392       a = Triangle(a, lower);
393     }
394 
395     // We find the diagonal blocks of the coefficient matrix
396     int64_t block_size = std::min(block_size_, k);
397     auto diag_blocks = DiagonalBlocks(a, block_size);
398 
399     // We invert these blocks in parallel using batched matrix-vector products
400     auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);
401 
402     // We now find the solution using GEMMs
403     return SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
404                                            lower, transpose_a, conjugate_a,
405                                            precision);
406   });
407 }
408 
409 // def trsm_left_lower_leftlooking(a, b):
410 //   n = a.shape[-1]
411 //   assert a.shape == (n, n)
412 //   b = b.copy()
413 //   for j in range(n):
414 //     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
415 //   return b
SolveDirectly(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,PrecisionConfig::Precision precision)416 XlaOp TriangularSolveExpander::SolveDirectly(
417     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
418     bool conjugate_a, bool unit_diagonal,
419     PrecisionConfig::Precision precision) {
420   XlaBuilder* builder = a.builder();
421   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
422     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
423     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
424     int64_t m = ShapeUtil::GetDimension(b_shape, -2);
425     int64_t n = ShapeUtil::GetDimension(b_shape, -1);
426     const int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);
427     a = MaybeConjugate(a, conjugate_a);
428     bool backwards = transpose_a ^ lower ^ !left_side;
429     for (int64_t i = 0; i < a_size; ++i) {
430       int64_t j = backwards ? i : (a_size - i - 1);
431       std::vector<int64> b_row_start, b_row_end;
432       if (left_side) {
433         b_row_start = {j, 0};
434         b_row_end = {j + 1, n};
435       } else {
436         b_row_start = {0, j};
437         b_row_end = {m, j + 1};
438       }
439       auto b_row = SliceInMinorDims(b, b_row_start, b_row_end);
440 
441       std::vector<int64> a_start = {j, backwards ? 0 : (j + 1)};
442       std::vector<int64> a_end = {j + 1, backwards ? j : a_size};
443       if (transpose_a ^ !left_side) {
444         std::swap(a_start[0], a_start[1]);
445         std::swap(a_end[0], a_end[1]);
446       }
447       auto a_chunk = SliceInMinorDims(a, a_start, a_end);
448       if (left_side) {
449         bool which = transpose_a ^ lower;
450         auto b_chunk =
451             SliceInMinorDims(b, {which ? 0 : (j + 1), 0}, {which ? j : m, n});
452         b_row = b_row - BatchDot(a_chunk, /*transpose_x=*/transpose_a, b_chunk,
453                                  /*transpose_y=*/false, precision);
454       } else {
455         bool which = transpose_a ^ !lower;
456         auto b_chunk =
457             SliceInMinorDims(b, {0, which ? 0 : (j + 1)}, {m, which ? j : n});
458         b_row = b_row - BatchDot(b_chunk, /*transpose_x=*/false, a_chunk,
459                                  /*transpose_y=*/transpose_a, precision);
460       }
461       if (!unit_diagonal) {
462         auto a_diag = SliceInMinorDims(a, {j, j}, {j + 1, j + 1});
463         b_row = b_row / a_diag;
464       }
465 
466       b = UpdateSliceInMinorDims(b, b_row, b_row_start);
467     }
468 
469     return b;
470   });
471 }
472 
BuildTriangularSolve(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,int64_t block_size,PrecisionConfig::Precision precision)473 XlaOp TriangularSolveExpander::BuildTriangularSolve(
474     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
475     bool conjugate_a, bool unit_diagonal, int64_t block_size,
476     PrecisionConfig::Precision precision) {
477   XlaBuilder* builder = a.builder();
478   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
479     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
480     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
481     if (a_shape.rank() != b_shape.rank()) {
482       return InvalidArgument(
483           "Arguments to TriangularSolve have shapes with different ranks: "
484           "%s vs. %s",
485           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
486     }
487     const int64_t ndims = a_shape.rank();
488     if (ndims < 2) {
489       return InvalidArgument(
490           "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
491           ndims);
492     }
493     // The batch dimensions must be equal.
494     std::vector<int64> batch_dimensions;
495     int64_t batch = 1;
496     for (int i = 0; i < ndims - 2; ++i) {
497       int64_t a_size = a_shape.dimensions(i);
498       int64_t b_size = b_shape.dimensions(i);
499       if (a_size != b_size) {
500         return InvalidArgument(
501             "Batch dimensions of arguments to TriangularSolve must be equal; "
502             "shapes were %s and %s.",
503             ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
504       }
505       batch_dimensions.push_back(a_size);
506       batch *= a_size;
507     }
508 
509     if (ShapeUtil::GetDimension(a_shape, -1) !=
510         ShapeUtil::GetDimension(a_shape, -2)) {
511       return InvalidArgument(
512           "The 'a' argument to TriangularSolve must be a batched square matrix;"
513           " shape was: %s",
514           ShapeUtil::HumanString(a_shape));
515     }
516     const int64_t m = ShapeUtil::GetDimension(b_shape, -2);
517     const int64_t n = ShapeUtil::GetDimension(b_shape, -1);
518     if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
519       return InvalidArgument(
520           "Arguments to TriangularSolve have incompatible matrix shapes %s and "
521           "%s",
522           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
523     }
524 
525     int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);
526 
527     if (ShapeUtil::IsZeroElementArray(b_shape)) {
528       // The output has the same shape as 'b', and since the output has zero
529       // elements, any such array will do.
530       return b;
531     }
532 
533     // Degenerate case: 1x1 matrices.
534     if (a_size == 1) {
535       return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
536     }
537 
538     // Prefer the direct implementation whenever there is a nontrivial batch
539     // dimension and the matrix is very small.
540     if (UseDirectSolves() && batch > block_size_ / 16 &&
541         a_size < block_size_ / 4) {
542       return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
543                            unit_diagonal, precision);
544     } else {
545       return SolveByInvertingDiagonalBlocks(a, b, left_side, lower, transpose_a,
546                                             conjugate_a, unit_diagonal,
547                                             precision);
548     }
549   });
550 }
551 
TriangularSolveExpander(int64_t block_size)552 TriangularSolveExpander::TriangularSolveExpander(int64_t block_size)
553     : block_size_(block_size) {
554   CHECK_GE(block_size_, 1);
555 }
556 
InstructionMatchesPattern(HloInstruction * instruction)557 bool TriangularSolveExpander::InstructionMatchesPattern(
558     HloInstruction* instruction) {
559   return instruction->opcode() == HloOpcode::kTriangularSolve;
560 }
561 
ExpandInstruction(HloInstruction * instruction)562 StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
563     HloInstruction* instruction) {
564   const TriangularSolveOptions& options =
565       instruction->triangular_solve_options();
566   const string name = absl::StrFormat(
567       "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
568       instruction->operand(0)->shape().ToString(),
569       instruction->operand(1)->shape().ToString(),
570       options.left_side() ? "left" : "right",
571       options.lower() ? "lower" : "upper",
572       TriangularSolveOptions_Transpose_Name(options.transpose_a()),
573       options.unit_diagonal() ? "unit" : "nonunit");
574 
575   HloModule* module = instruction->parent()->parent();
576 
577   HloComputation*& computation =
578       computation_cache_.emplace(name, nullptr).first->second;
579   if (!computation) {
580     // Builds a new expansion.
581     //
582     // We do something unusual here: we build the computation using the
583     // XlaBuilder API, which is nominally an XLA client API. We do this because
584     // the external APIs for building complicated computations (XlaBuilder)
585     // are much more ergonomic than the internal ones. As it turns out,
586     // XlaBuilder isn't really a client API—what it does is build a
587     // HloModuleProto protocol buffer, that we can then deserialize and clone
588     // into our HloModule. Ideally we would avoid the protocol buffer step;
589     // that is left as an exercise for future work.
590     XlaBuilder builder(name);
591     XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
592     XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
593     bool transpose_a =
594         options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
595     bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;
596 
597     BuildTriangularSolve(a, b, options.left_side(), options.lower(),
598                          transpose_a, conjugate_a, options.unit_diagonal(),
599                          /*block_size=*/block_size_,
600                          /*precision=*/PrecisionConfig::HIGHEST);
601     TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
602 
603     TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
604                         xla_computation.GetProgramShape());
605     HloModuleConfig config(program_shape);
606     TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
607                                              xla_computation.proto(), config));
608     HloCloneContext context(module);
609     computation =
610         module->DeepCloneComputation(new_module->entry_computation(), &context);
611   }
612 
613   return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
614       instruction->shape(), instruction->operands(), computation));
615 }
616 
617 }  // namespace xla
618