• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17 
18 #include <algorithm>
19 #include <array>
20 #include <limits>
21 #include <numeric>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
36 #include "tensorflow/compiler/xla/client/lib/constants.h"
37 #include "tensorflow/compiler/xla/client/lib/slicing.h"
38 #include "tensorflow/compiler/xla/client/xla_builder.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/shape_util.h"
42 #include "tensorflow/compiler/xla/status.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/statusor.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 
48 namespace xla {
49 
IdentityMatrix(XlaBuilder * builder,PrimitiveType type,int64_t m,int64_t n)50 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m,
51                      int64_t n) {
52   auto a = Iota(builder, U32, m);
53   auto b = Iota(builder, U32, n);
54   auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
55   return ConvertElementType(indicator, type);
56 }
57 
GetDiagonalMask(XlaOp x,int diagonal)58 XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
59   XlaBuilder* builder = x.builder();
60   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
61     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
62     auto n_dims = static_cast<int32_t>(shape.rank());
63     TF_RET_CHECK(n_dims >= 2);
64     auto m = shape.dimensions(n_dims - 2);
65     auto n = shape.dimensions(n_dims - 1);
66     absl::Span<const int64_t> major_dims =
67         shape.dimensions().subspan(/*pos=*/0, /*len=*/n_dims - 2);
68     auto a = Iota(builder, S32, n);
69     auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
70     auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
71     auto mask = Broadcast(indicator, major_dims);
72     return mask;
73   });
74 }
75 
GetMatrixDiagonal(XlaOp x,int k)76 XlaOp GetMatrixDiagonal(XlaOp x, int k) {
77   XlaBuilder* builder = x.builder();
78   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
79     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
80     auto n_dims = static_cast<int32_t>(shape.rank());
81     TF_RET_CHECK(n_dims >= 2);
82     const int64_t m = shape.dimensions(n_dims - 2);
83     const int64_t n = shape.dimensions(n_dims - 1);
84 
85     if (k <= -m || k >= n) {
86       auto zero_size_shape = shape;
87       zero_size_shape.DeleteDimension(n_dims - 1);
88       zero_size_shape.set_dimensions(n_dims - 2, 0);
89       return ConstantLiteral(builder, Literal{zero_size_shape});
90     }
91     auto mask = GetDiagonalMask(x, k);
92 
93     int64_t reduce_dim = n_dims - 1;
94     if ((k == 0 && m >= n) || k < 0) {
95       reduce_dim = n_dims - 2;
96     }
97     auto result = Reduce(
98         Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
99         CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
100         {reduce_dim});
101     // k == 0, we can save one slice op.
102     if (k == 0) {
103       return result;
104     }
105     return SliceInMinorDims(result, {0},
106                             {k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
107   });
108 }
109 
GetMatrixDiagonalViaGather(XlaOp x,int k)110 XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) {
111   XlaBuilder* builder = x.builder();
112   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
113     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
114     auto n_dims = static_cast<int32_t>(shape.rank());
115     TF_RET_CHECK(n_dims >= 2);
116     const int64_t m = shape.dimensions(n_dims - 2);
117     const int64_t n = shape.dimensions(n_dims - 1);
118 
119     // The start_indices has a shape of {diag_len, 2}, and each pair of value in
120     // its dimension 1 represents the (row, col) of the diagonal. We set
121     // index_vector_dim to 1 and make start_index_map and collapsed_slice_dims
122     // contain the same two dimension indices. This makes sure that the (row,
123     // col) pairs in start_indices are propagated to the indices for the two
124     // collapsed dimensions in the operand indices through start_index_map.
125     const int64_t num_index_dims = 2;
126     const int64_t axis = n_dims - num_index_dims;
127 
128     // Calculate the indices of diagonal part with offset k.
129     const int64_t diag_len =
130         std::max(std::min(m + std::min(k, 0), n - std::max(k, 0)), int64_t{0});
131     XlaOp diag_base_indices = BroadcastInDim(Iota(builder, S32, diag_len),
132                                              {diag_len, num_index_dims}, {0});
133     XlaOp diag_offset =
134         Broadcast(ConstantR1<int>(builder, {std::max(-k, 0), std::max(k, 0)}),
135                   {diag_len});
136     XlaOp start_indices = Add(diag_base_indices, diag_offset);
137 
138     // Example of a 3D diag-part extracting diagonal part with offset=1 out of a
139     // tensor of shape [2,5,4].
140     //
141     //  operand = s32[2,5,4] parameter(0)
142     //  indices = s32[3,2] parameter(1)
143     //  gather = s32[2,3] gather(operand, indices),
144     //       offset_dims={0},
145     //       collapsed_slice_dims={1,2},
146     //       start_index_map={1,2},
147     //       index_vector_dim=1,
148     //       slice_sizes={2, 1, 1}
149 
150     xla::GatherDimensionNumbers dim_numbers;
151     std::vector<int64_t> slice_sizes;
152     slice_sizes.reserve(n_dims);
153     for (int64_t i = 0; i < n_dims; i++) {
154       int64_t window_bound;
155       if (axis <= i) {
156         dim_numbers.add_collapsed_slice_dims(i);
157         dim_numbers.add_start_index_map(i);
158         window_bound = (shape.dimensions(i) != 0) ? 1 : 0;
159       } else {
160         dim_numbers.add_offset_dims(i);
161         window_bound = shape.dimensions(i);
162       }
163       slice_sizes.push_back(window_bound);
164     }
165 
166     dim_numbers.set_index_vector_dim(1);
167 
168     return Gather(x, start_indices, dim_numbers, slice_sizes,
169                   /*indices_are_sorted=*/true);
170   });
171 }
172 
SetMatrixDiagonal(XlaOp matrix,XlaOp diag,int k)173 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
174   XlaBuilder* builder = matrix.builder();
175   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
176     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
177     TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
178     auto n_dims = static_cast<int32_t>(shape.rank());
179     TF_RET_CHECK(n_dims >= 2);
180     const int64_t m = shape.dimensions(n_dims - 2);
181     const int64_t n = shape.dimensions(n_dims - 1);
182     const int64_t d = diag_shape.dimensions(n_dims - 2);
183     std::vector<int64_t> broadcast_dims(n_dims - 1);
184     absl::c_iota(broadcast_dims, 0);
185     int64_t pad_high = m - d;
186     if (k < 0) {
187       ++(broadcast_dims.back());
188       pad_high = n - d;
189     }
190 
191     if (pad_high != 0) {
192       PaddingConfig padding_config;
193       for (int64_t i = 0; i < diag_shape.rank() - 1; ++i) {
194         auto* dims = padding_config.add_dimensions();
195         dims->set_edge_padding_low(0);
196         dims->set_interior_padding(0);
197         dims->set_edge_padding_high(0);
198       }
199       auto* dims = padding_config.add_dimensions();
200       dims->set_edge_padding_low(0);
201       dims->set_interior_padding(0);
202       dims->set_edge_padding_high(pad_high);
203       diag = Pad(diag, ScalarLike(diag, 0), padding_config);
204     }
205 
206     return Select(GetDiagonalMask(matrix, k),
207                   BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
208                   matrix);
209   });
210 }
211 
TriangleMask(XlaOp x,int diagonal)212 XlaOp TriangleMask(XlaOp x, int diagonal) {
213   XlaBuilder* builder = x.builder();
214   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
215     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
216     const int64_t n_dims = shape.rank();
217     TF_RET_CHECK(n_dims >= 2);
218     const int64_t m = shape.dimensions(n_dims - 2);
219     const int64_t n = shape.dimensions(n_dims - 1);
220     absl::Span<const int64_t> major_dims =
221         shape.dimensions().subspan(/*pos=*/0, /*len=*/n_dims - 2);
222     auto a = Iota(builder, S32, n);
223     auto b = Iota(builder, S32, m) + ConstantR0<int32_t>(builder, diagonal);
224     XlaOp indicator;
225     indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
226     return Broadcast(indicator, major_dims);
227   });
228 }
229 
Triangle(XlaOp x,bool lower)230 XlaOp Triangle(XlaOp x, bool lower) {
231   return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
232                : Select(TriangleMask(x, -1), ZerosLike(x), x);
233 }
234 
UpperTriangle(XlaOp x)235 XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
236 
LowerTriangle(XlaOp x)237 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
238 
Symmetrize(XlaOp x,bool lower)239 XlaOp Symmetrize(XlaOp x, bool lower) {
240   XlaBuilder* builder = x.builder();
241   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
242     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
243     if (shape.rank() < 2) {
244       return InvalidArgument(
245           "Argument to symmetrize must have >= 2 dimensions, got %s",
246           shape.ToString());
247     }
248     const int64_t m = ShapeUtil::GetDimension(shape, -2);
249     const int64_t n = ShapeUtil::GetDimension(shape, -1);
250     if (m != n) {
251       return InvalidArgument(
252           "The two most minor dimensions of the argument to symmetrize must be "
253           "equal size, got %s",
254           shape.ToString());
255     }
256     auto mask = lower ? TriangleMask(x, 0) : Not(TriangleMask(x, -1));
257     if (primitive_util::IsComplexType(shape.element_type())) {
258       auto re = Select(mask, Real(x), TransposeInMinorDims(Real(x)));
259       auto im_mask = lower ? TriangleMask(x, -1) : Not(TriangleMask(x, 0));
260       auto im = Select(im_mask, Imag(x), ZerosLike(Imag(x)));
261       im = Select(mask, im, -TransposeInMinorDims(im));
262       return Complex(re, im);
263     } else {
264       return Select(mask, x, TransposeInMinorDims(x));
265     }
266   });
267 }
268 
269 namespace {
EinsumDiagonalLabels(absl::Span<const int64_t> config)270 std::optional<std::array<std::vector<int64_t>, 3>> EinsumDiagonalLabels(
271     absl::Span<const int64_t> config) {
272   std::vector<int64_t> unique_labels;
273   std::vector<int64_t> reduce_dims;
274   std::vector<int64_t> broadcast_dims;
275   for (auto label = config.begin(); label != config.end(); ++label) {
276     auto first_label = absl::c_find(config, *label);
277     auto dim = label - config.begin();
278     if (first_label == label) {
279       unique_labels.push_back(*label);
280       broadcast_dims.push_back(dim);
281     } else {
282       reduce_dims.push_back(dim);
283     }
284   }
285   if (unique_labels.size() == config.size()) {
286     return std::nullopt;
287   }
288   return {{unique_labels, reduce_dims, broadcast_dims}};
289 }
290 
291 // Masks a tensor such that only the diagonal of repeated indices are non-zero.
292 // The result of this can be used to create a diagonal matrix with an identity
293 // reduction.
EinsumDiagonalMask(XlaOp x,absl::Span<const int64_t> config)294 xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64_t> config) {
295   XlaBuilder* builder = x.builder();
296   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
297     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
298     Shape iota_shape = x_shape;
299     iota_shape.set_element_type(S32);
300     XlaOp mask = ConstantR0(builder, true);
301 
302     for (auto label = config.begin(); label != config.end(); ++label) {
303       const int64_t dim = label - config.begin();
304       auto first_label = absl::c_find(config, *label);
305       if (first_label != label) {
306         const int64_t first_dim = first_label - config.begin();
307         mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
308                             Iota(builder, iota_shape, dim)));
309       }
310     }
311     return Select(mask, x, ZerosLike(x));
312   });
313 }
314 
EinsumDiagonal(XlaOp x,absl::Span<const int64_t> config)315 xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64_t> config) {
316   XlaBuilder* builder = x.builder();
317   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
318     auto labels = EinsumDiagonalLabels(config);
319     if (!labels) {
320       return x;
321     }
322     auto zero = ScalarLike(x, 0);
323     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
324     return Reduce(EinsumDiagonalMask(x, config), zero,
325                   CreateScalarIdentityWithZeroComputation(
326                       x_shape.element_type(), builder),
327                   labels->at(1));
328   });
329 }
330 
EinsumInverseDiagonal(XlaOp x,absl::Span<const int64_t> config)331 xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64_t> config) {
332   XlaBuilder* builder = x.builder();
333   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
334     auto labels = EinsumDiagonalLabels(config);
335     if (!labels) {
336       return x;
337     }
338     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
339     std::vector<int64_t> broadcast_sizes;
340     int64_t x_dim = 0;
341     for (auto label = config.begin(); label != config.end(); ++label) {
342       auto first_label = absl::c_find(config, *label);
343       if (first_label == label) {
344         broadcast_sizes.push_back(x_shape.dimensions(x_dim));
345         ++x_dim;
346       } else {
347         broadcast_sizes.push_back(
348             broadcast_sizes[first_label - config.begin()]);
349       }
350     }
351     x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
352     return EinsumDiagonalMask(x, config);
353   });
354 }
355 }  // namespace
356 
357 namespace {
358 // Helper method to remove dimensions from a shape and dot dimension numbers
359 // used to implement implicit broadcasting.
360 template <typename C>
DeleteDimsFromContainer(absl::Span<const int64_t> to_delete,Shape * shape,C * batch_dims,C * contracting_dims)361 void DeleteDimsFromContainer(absl::Span<const int64_t> to_delete, Shape* shape,
362                              C* batch_dims, C* contracting_dims) {
363   if (to_delete.empty()) {
364     return;
365   }
366   for (int64_t i = to_delete.size() - 1; i >= 0; --i) {
367     int64_t dim = to_delete[i];
368     shape->DeleteDimension(dim);
369     for (auto& b : *batch_dims) {
370       if (b > dim) {
371         --b;
372       }
373     }
374     for (auto& c : *contracting_dims) {
375       if (c > dim) {
376         --c;
377       }
378     }
379   }
380 }
381 }  // namespace
382 
Einsum(xla::XlaOp x,absl::Span<const int64_t> x_config,xla::XlaOp y,absl::Span<const int64_t> y_config,absl::Span<const int64_t> output_config,xla::PrecisionConfig::Precision precision,std::optional<PrimitiveType> preferred_element_type)383 xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64_t> x_config,
384                   xla::XlaOp y, absl::Span<const int64_t> y_config,
385                   absl::Span<const int64_t> output_config,
386                   xla::PrecisionConfig::Precision precision,
387                   std::optional<PrimitiveType> preferred_element_type) {
388   XlaBuilder* builder = x.builder();
389   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
390     auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
391     if (x_diagonal_labels) {
392       return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
393                     y_config, output_config, precision, preferred_element_type);
394     }
395     auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
396     if (y_diagonal_labels) {
397       return Einsum(x, x_config, EinsumDiagonal(y, y_config),
398                     y_diagonal_labels->at(0), output_config, precision,
399                     preferred_element_type);
400     }
401     auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
402     if (output_diagonal_labels) {
403       return EinsumInverseDiagonal(
404           Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
405                  precision, preferred_element_type),
406           output_config);
407     }
408 
409     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
410     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
411     const int64_t x_rank = x_config.size();
412     const int64_t y_rank = y_config.size();
413     const int64_t output_rank = output_config.size();
414     absl::flat_hash_set<int64_t> x_map;
415     absl::flat_hash_set<int64_t> y_map;
416     absl::flat_hash_set<int64_t> output_map;
417 
418     for (auto d : x_config) {
419       x_map.insert(d);
420     }
421 
422     for (auto d : y_config) {
423       y_map.insert(d);
424     }
425 
426     for (auto d : output_config) {
427       output_map.insert(d);
428     }
429 
430     DotDimensionNumbers dnums;
431     auto is_batch_dim = [&](int64_t d) {
432       return x_map.contains(d) && y_map.contains(d) && output_map.contains(d);
433     };
434     auto is_contracting = [&](int64_t d) {
435       return x_map.contains(d) && y_map.contains(d);
436     };
437 
438     auto rhs_dimension_number = [&](int64_t d) {
439       return absl::c_find(y_config, d) - y_config.begin();
440     };
441 
442     absl::InlinedVector<int64_t, 8> rhs_outer_dims;
443     absl::InlinedVector<int64_t, 8> lhs_outer_dims;
444     absl::InlinedVector<int64_t, 8> rhs_delete_dims;
445     absl::InlinedVector<int64_t, 8> lhs_delete_dims;
446     for (int64_t i = 0; i < x_rank; ++i) {
447       auto dim_name = x_config[i];
448       const int64_t rhs_dim = rhs_dimension_number(dim_name);
449 
450       if (is_batch_dim(dim_name)) {
451         if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
452           dnums.add_lhs_batch_dimensions(i);
453           dnums.add_rhs_batch_dimensions(rhs_dim);
454         } else if (x_shape.dimensions(i) == 1) {
455           rhs_outer_dims.push_back(rhs_dim);
456           lhs_delete_dims.push_back(i);
457         } else {
458           lhs_outer_dims.push_back(i);
459           rhs_delete_dims.push_back(rhs_dim);
460         }
461       } else if (is_contracting(dim_name)) {
462         if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
463           dnums.add_lhs_contracting_dimensions(i);
464           dnums.add_rhs_contracting_dimensions(rhs_dim);
465         } else if (x_shape.dimensions(i) == 1) {
466           rhs_outer_dims.push_back(rhs_dim);
467           lhs_delete_dims.push_back(i);
468         } else {
469           lhs_outer_dims.push_back(i);
470           rhs_delete_dims.push_back(rhs_dim);
471         }
472       } else {
473         lhs_outer_dims.push_back(i);
474       }
475     }
476 
477     for (int64_t i = 0; i < y_rank; ++i) {
478       auto dim_name = y_config[i];
479       if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
480         rhs_outer_dims.push_back(i);
481       }
482     }
483 
484     absl::c_sort(rhs_outer_dims);
485     absl::InlinedVector<int64_t, 8> output_transpose_dims;
486 
487     auto output_dimension_number = [&](int64_t d) -> std::optional<int64_t> {
488       auto pos = absl::c_find(output_config, d);
489       if (pos == output_config.end()) {
490         return std::nullopt;
491       }
492       return pos - output_config.begin();
493     };
494 
495     for (auto d : dnums.lhs_batch_dimensions()) {
496       output_transpose_dims.push_back(*output_dimension_number(x_config[d]));
497     }
498 
499     for (auto d : lhs_outer_dims) {
500       if (auto output_dim = output_dimension_number(x_config[d])) {
501         output_transpose_dims.push_back(*output_dim);
502         continue;
503       }
504       lhs_delete_dims.push_back(d);
505     }
506 
507     for (auto d : rhs_outer_dims) {
508       if (auto output_dim = output_dimension_number(y_config[d])) {
509         output_transpose_dims.push_back(*output_dim);
510         continue;
511       }
512       rhs_delete_dims.push_back(d);
513     }
514 
515     const int64_t transpose_rank = output_transpose_dims.size();
516     std::vector<int64_t> transpose_dims(output_rank);
517     for (int64_t i = 0; i < transpose_rank; ++i) {
518       transpose_dims[output_transpose_dims[i]] = i;
519     }
520 
521     // Remove ones that where broadcasted from the x and the y shape and adjust
522     // the dimension numbers that are more minor than those dimensions.
523     absl::c_sort(lhs_delete_dims);
524     DeleteDimsFromContainer(lhs_delete_dims, &x_shape,
525                             dnums.mutable_lhs_batch_dimensions(),
526                             dnums.mutable_lhs_contracting_dimensions());
527 
528     absl::c_sort(rhs_delete_dims);
529     DeleteDimsFromContainer(rhs_delete_dims, &y_shape,
530                             dnums.mutable_rhs_batch_dimensions(),
531                             dnums.mutable_rhs_contracting_dimensions());
532     if (!lhs_delete_dims.empty()) {
533       x = Reduce(x, ScalarLike(x, 0),
534                  CreateScalarAddComputation(x_shape.element_type(), builder),
535                  lhs_delete_dims);
536     }
537 
538     if (!rhs_delete_dims.empty()) {
539       y = Reduce(y, ScalarLike(y, 0),
540                  CreateScalarAddComputation(y_shape.element_type(), builder),
541                  rhs_delete_dims);
542     }
543 
544     PrecisionConfig precision_proto;
545     precision_proto.add_operand_precision(precision);
546     precision_proto.add_operand_precision(precision);
547     auto dot =
548         DotGeneral(x, y, dnums, &precision_proto, preferred_element_type);
549     dot = Transpose(dot, transpose_dims);
550     if (transpose_rank == output_rank) {
551       return dot;
552     }
553 
554     auto is_output_only = [&](int64_t d) {
555       return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
556     };
557 
558     int64_t dot_dim = 0;
559     std::vector<int64_t> new_dims;
560     new_dims.reserve(output_rank);
561     TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
562     for (auto d : output_config) {
563       if (is_output_only(d)) {
564         new_dims.push_back(1);
565       } else {
566         new_dims.push_back(dot_shape.dimensions(dot_dim));
567       }
568     }
569     return Reshape(dot, new_dims);
570   });
571 }
572 
BatchDot(XlaOp x,XlaOp y,PrecisionConfig::Precision precision,std::optional<PrimitiveType> preferred_element_type)573 XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision,
574                std::optional<PrimitiveType> preferred_element_type) {
575   return BatchDot(x, false, y, false, precision, preferred_element_type);
576 }
577 
BatchDot(XlaOp x,bool transpose_x,XlaOp y,bool transpose_y,PrecisionConfig::Precision precision,std::optional<PrimitiveType> preferred_element_type)578 XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
579                PrecisionConfig::Precision precision,
580                std::optional<PrimitiveType> preferred_element_type) {
581   XlaBuilder* builder = x.builder();
582   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
583     std::string string("...mk,...kn->...mn");
584     if (transpose_x) {
585       std::swap(string[3], string[4]);
586     }
587     if (transpose_y) {
588       std::swap(string[6 + 3], string[6 + 4]);
589     }
590     return Einsum(x, y, string, precision, preferred_element_type);
591   });
592 }
593 
ParseEinsumString(absl::string_view einsum_config,int64_t x_rank,int64_t y_rank)594 StatusOr<std::array<std::vector<int64_t>, 3>> ParseEinsumString(
595     absl::string_view einsum_config, int64_t x_rank, int64_t y_rank) {
596   std::array<std::vector<int64_t>, 3> einsum_config_numeric;
597   std::vector<absl::string_view> main_split =
598       absl::StrSplit(einsum_config, ',');
599   if (main_split.size() != 2) {
600     return InvalidArgument("Expected one \",\" in einsum_config.");
601   }
602 
603   auto maybe_invalid_character = [](char d) {
604     if (absl::ascii_isalpha(d)) {
605       return OkStatus();
606     }
607     if (d == '.') {
608       return InvalidArgument("Unsupported \".\" in einsum config.");
609     }
610     return InvalidArgument("Unexpected character in einsum config.");
611   };
612 
613   auto string_config_to_numeric =
614       [&](absl::string_view config, bool is_input_config, int64_t input_rank,
615           int64_t ellipsis_rank,
616           std::vector<int64_t>* numeric_config) -> StatusOr<int64_t> {
617     std::vector<absl::string_view> splits = absl::StrSplit(config, "...");
618     if (splits.empty()) {
619       return ellipsis_rank;
620     }
621     if (splits.size() > 2) {
622       return InvalidArgument("Too many ellipses (\"...\") in einsum config.");
623     }
624     // There is one split if we don't have an ellipsis, and two splits if we do.
625     const bool has_ellipsis = splits.size() > 1;
626     // We only compute ellipsis_rank for input configs.
627     if (is_input_config && has_ellipsis) {
628       // ellipsis_rank is input rank minus the number of named labels.
629       ellipsis_rank = input_rank -
630                       static_cast<int64_t>(splits[0].size() + splits[1].size());
631       if (ellipsis_rank < 0) {
632         return InvalidArgument(
633             "Too few dimensions in the input for the given einsum config.");
634       }
635     }
636     for (char d : splits[0]) {
637       TF_RETURN_IF_ERROR(maybe_invalid_character(d));
638       numeric_config->push_back(static_cast<int64_t>(d));
639     }
640     if (has_ellipsis) {
641       // For input configs, we use the value of ellipsis_rank we just computed.
642       // For output config, we use the existing value of ellipsis_rank.
643       for (int64_t i = ellipsis_rank; i > 0; --i) {
644         numeric_config->push_back(-i);
645       }
646       for (char d : splits[1]) {
647         TF_RETURN_IF_ERROR(maybe_invalid_character(d));
648         numeric_config->push_back(static_cast<int64_t>(d));
649       }
650     }
651     return ellipsis_rank;
652   };
653 
654   TF_ASSIGN_OR_RETURN(
655       const int64_t x_ellipsis_rank,
656       string_config_to_numeric(main_split[0],
657                                /*is_input_config=*/true, x_rank,
658                                /*ellipsis_rank=*/0, &einsum_config_numeric[0]));
659 
660   std::vector<absl::string_view> y_output_split =
661       absl::StrSplit(main_split[1], "->");
662   if (y_output_split.size() != 2) {
663     return InvalidArgument("Expected one \"->\" in einsum_config.");
664   }
665 
666   TF_ASSIGN_OR_RETURN(
667       const int64_t y_ellipsis_rank,
668       string_config_to_numeric(y_output_split[0],
669                                /*is_input_config=*/true, y_rank,
670                                /*ellipsis_rank=*/0, &einsum_config_numeric[1]));
671 
672   // Replace ellipsis in output_config with numeric labels with the same
673   // ellipsis rank as in the inputs.
674   // Note: This implementation doesn't support different-rank broadcasting.
675   TF_ASSIGN_OR_RETURN(
676       std::ignore,
677       string_config_to_numeric(
678           y_output_split[1], /*is_input_config=*/false,
679           /*input_rank=*/0,
680           /*ellipsis_rank=*/std::max(x_ellipsis_rank, y_ellipsis_rank),
681           &einsum_config_numeric[2]));
682   return einsum_config_numeric;
683 }
684 
NormalizeEinsumString(absl::string_view einsum_config)685 std::string NormalizeEinsumString(absl::string_view einsum_config) {
686   if (einsum_config.find("->") != einsum_config.npos) {
687     return "";
688   }
689   bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
690   std::map<char, int64_t> chars;
691   for (char c : einsum_config) {
692     if (absl::ascii_isalpha(c)) {
693       ++chars[c];
694     }
695   }
696   std::string new_config(einsum_config.begin(), einsum_config.end());
697   new_config.append("->");
698   if (has_ellipsis) {
699     new_config.append("...");
700   }
701   for (auto p : chars) {
702     if (p.second == 1) {
703       new_config.push_back(p.first);
704     }
705   }
706   return new_config;
707 }
708 
Einsum(XlaOp x,XlaOp y,absl::string_view einsum_config,PrecisionConfig::Precision precision,std::optional<PrimitiveType> preferred_element_type)709 XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
710              PrecisionConfig::Precision precision,
711              std::optional<PrimitiveType> preferred_element_type) {
712   XlaBuilder* builder = x.builder();
713   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
714     auto new_config = NormalizeEinsumString(einsum_config);
715     if (!new_config.empty()) {
716       return Einsum(x, y, new_config, precision, preferred_element_type);
717     }
718     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
719     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
720     TF_ASSIGN_OR_RETURN(
721         auto einsum_config_numeric,
722         ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
723     return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
724                   einsum_config_numeric[2], precision, preferred_element_type);
725   });
726 }
727 
Einsum(XlaOp x,absl::string_view einsum_config,PrecisionConfig::Precision precision)728 XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
729              PrecisionConfig::Precision precision) {
730   return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
731                 precision);
732 }
733 
TransposeInMinorDims(XlaOp x)734 XlaOp TransposeInMinorDims(XlaOp x) {
735   XlaBuilder* builder = x.builder();
736   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
737     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
738     const int64_t n_dims = shape.rank();
739     TF_RET_CHECK(n_dims >= 2);
740     std::vector<int64_t> permutation(n_dims);
741     std::iota(permutation.begin(), permutation.end(), 0);
742     std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
743     return Transpose(x, permutation);
744   });
745 }
746 
MaybeTransposeInMinorDims(XlaOp x,bool transpose)747 XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
748   return transpose ? TransposeInMinorDims(x) : x;
749 }
750 
751 }  // namespace xla
752