/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <algorithm>
#include <array>
#include <limits>
#include <map>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

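// Returns an m-by-n matrix with ones on the main diagonal and zeros
// elsewhere, in the given element type. For example,
// IdentityMatrix(builder, F32, 2, 3) produces
//   [[1, 0, 0],
//    [0, 1, 0]].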
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
                     int64 n) {
  auto a = Iota(builder, U32, m);
  auto b = Iota(builder, U32, n);
  auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
  return ConvertElementType(indicator, type);
}

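// Returns a boolean mask of the same shape as `x` that is true exactly on
// the `diagonal`-th diagonal of each minor-most matrix, i.e. where
// row + diagonal == col. For example, for a [3, 3] input and diagonal = 1,
// the mask is true at (0, 1) and (1, 2).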
XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    auto m = shape.dimensions(n_dims - 2);
    auto n = shape.dimensions(n_dims - 1);
    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
    auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    auto mask = Broadcast(indicator, major_dims);
    return mask;
  });
}

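// Returns the `k`-th diagonal of the minor-most matrices of `x` (k > 0 is
// above the main diagonal, k < 0 below it). For example, for
// x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]], GetMatrixDiagonal(x, 1)
// returns [2, 6].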
XlaOp GetMatrixDiagonal(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);

    if (k <= -m || k >= n) {
      auto zero_size_shape = shape;
      zero_size_shape.DeleteDimension(n_dims - 1);
      zero_size_shape.set_dimensions(n_dims - 2, 0);
      return ConstantLiteral(builder, Literal{zero_size_shape});
    }
    auto mask = GetDiagonalMask(x, k);

    int64 reduce_dim = n_dims - 1;
    if ((k == 0 && m >= n) || k < 0) {
      reduce_dim = n_dims - 2;
    }
    auto result = Reduce(
        Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
        CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
        {reduce_dim});
    // When k == 0 the reduced dimension already has the diagonal's length,
    // so we can save one slice op.
    if (k == 0) {
      return result;
    }
    return SliceInMinorDims(result, {0},
                            {k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
  });
}

XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);

    // start_indices has shape {diag_len, 2}, and each pair of values in its
    // dimension 1 represents the (row, col) of the diagonal. We set
    // index_vector_dim to 1 and make start_index_map and collapsed_slice_dims
    // contain the same two dimension indices. This makes sure that the (row,
    // col) pairs in start_indices are propagated to the indices for the two
    // collapsed dimensions in the operand indices through start_index_map.
    const int64 num_index_dims = 2;
    const int64 axis = n_dims - num_index_dims;

    // Calculate the indices of the diagonal with offset k.
    const int64 diag_len =
        std::max(std::min(m + std::min(k, 0), n - std::max(k, 0)), int64{0});
    XlaOp diag_base_indices = BroadcastInDim(Iota(builder, S32, diag_len),
                                             {diag_len, num_index_dims}, {0});
    XlaOp diag_offset =
        Broadcast(ConstantR1<int>(builder, {std::max(-k, 0), std::max(k, 0)}),
                  {diag_len});
    XlaOp start_indices = Add(diag_base_indices, diag_offset);
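    // For instance, with k = 1, m = 5, n = 4 (matching the example below):
    // diag_len = 3, diag_base_indices = [[0, 0], [1, 1], [2, 2]], and
    // diag_offset broadcasts [0, 1] to every row, so start_indices =
    // [[0, 1], [1, 2], [2, 3]].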

    // Example of a 3D diag-part extracting the diagonal with offset=1 out of
    // a tensor of shape [2,5,4].
    //
    //  operand = s32[2,5,4] parameter(0)
    //  indices = s32[3,2] parameter(1)
    //  gather = s32[2,3] gather(operand, indices),
    //       offset_dims={0},
    //       collapsed_slice_dims={1,2},
    //       start_index_map={1,2},
    //       index_vector_dim=1,
    //       slice_sizes={2, 1, 1}

    xla::GatherDimensionNumbers dim_numbers;
    std::vector<int64> slice_sizes;
    slice_sizes.reserve(n_dims);
    for (int64 i = 0; i < n_dims; i++) {
      int64 window_bound;
      if (axis <= i) {
        dim_numbers.add_collapsed_slice_dims(i);
        dim_numbers.add_start_index_map(i);
        window_bound = (shape.dimensions(i) != 0) ? 1 : 0;
      } else {
        dim_numbers.add_offset_dims(i);
        window_bound = shape.dimensions(i);
      }
      slice_sizes.push_back(window_bound);
    }

    dim_numbers.set_index_vector_dim(1);

    return Gather(x, start_indices, dim_numbers, slice_sizes,
                  /*indices_are_sorted=*/true);
  });
}

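// Returns a copy of `matrix` in which the `k`-th diagonal has been
// overwritten with the entries of `diag` (k > 0 selects a diagonal above the
// main one, k < 0 one below it).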
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
  XlaBuilder* builder = matrix.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
    TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);
    const int64 d = diag_shape.dimensions(n_dims - 2);
    std::vector<int64> broadcast_dims(n_dims - 1);
    absl::c_iota(broadcast_dims, 0);
    int64 pad_high = m - d;
    if (k < 0) {
      ++(broadcast_dims.back());
      pad_high = n - d;
    }

    if (pad_high != 0) {
      PaddingConfig padding_config;
      for (xla::int64 i = 0; i < diag_shape.rank() - 1; ++i) {
        auto* dims = padding_config.add_dimensions();
        dims->set_edge_padding_low(0);
        dims->set_interior_padding(0);
        dims->set_edge_padding_high(0);
      }
      auto* dims = padding_config.add_dimensions();
      dims->set_edge_padding_low(0);
      dims->set_interior_padding(0);
      dims->set_edge_padding_high(pad_high);
      diag = Pad(diag, ScalarLike(diag, 0), padding_config);
    }

    return Select(GetDiagonalMask(matrix, k),
                  BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
                  matrix);
  });
}

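// Returns a boolean mask that is true on and below the `diagonal`-th
// diagonal of the minor-most matrices of `x`, i.e. true where
// row + diagonal >= col.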
XlaOp TriangleMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);
    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0<int32>(builder, diagonal);
    XlaOp indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    return Broadcast(indicator, major_dims);
  });
}

XlaOp Triangle(XlaOp x, bool lower) {
  return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
               : Select(TriangleMask(x, -1), ZerosLike(x), x);
}

XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }

XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }

namespace {
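// Splits an einsum config with repeated labels into three dimension lists:
// the unique labels, the dimensions holding repeated labels (which are
// reduced away when taking a diagonal), and the dimensions of each label's
// first occurrence (the broadcast dimensions). Returns nullopt if no label
// is repeated. For example, for the config {'i', 'j', 'j'} this returns
// {{'i', 'j'}, {2}, {0, 1}}.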
absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
    absl::Span<const int64> config) {
  std::vector<int64> unique_labels;
  std::vector<int64> reduce_dims;
  std::vector<int64> broadcast_dims;
  for (auto label = config.begin(); label != config.end(); ++label) {
    auto first_label = absl::c_find(config, *label);
    auto dim = label - config.begin();
    if (first_label == label) {
      unique_labels.push_back(*label);
      broadcast_dims.push_back(dim);
    } else {
      reduce_dims.push_back(dim);
    }
  }
  if (unique_labels.size() == config.size()) {
    return absl::nullopt;
  }
  return {{unique_labels, reduce_dims, broadcast_dims}};
}

// Masks a tensor such that only the diagonal of repeated indices is non-zero.
// The result of this can be used to create a diagonal matrix with an identity
// reduction.
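// For example, with config {'i', 'j', 'i'}, the mask keeps x[a][b][c] only
// where a == c and zeroes out everything else.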
xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    Shape iota_shape = x_shape;
    iota_shape.set_element_type(S32);
    XlaOp mask = ConstantR0(builder, true);

    for (auto label = config.begin(); label != config.end(); ++label) {
      const int64 dim = label - config.begin();
      auto first_label = absl::c_find(config, *label);
      if (first_label != label) {
        const int64 first_dim = first_label - config.begin();
        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
                            Iota(builder, iota_shape, dim)));
      }
    }
    return Select(mask, x, ZerosLike(x));
  });
}

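// Takes the diagonal over the repeated labels in `config` by masking and
// then reducing away the repeated dimensions. For example, for a square
// matrix x and config {'i', 'i'}, this returns the main diagonal of x as a
// rank-1 result.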
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    auto zero = ScalarLike(x, 0);
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    return Reduce(EinsumDiagonalMask(x, config), zero,
                  CreateScalarIdentityWithZeroComputation(
                      x_shape.element_type(), builder),
                  labels->at(1));
  });
}

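// The inverse of EinsumDiagonal: broadcasts `x` back to the shape implied by
// the repeated labels in `config` and masks off the off-diagonal entries.
// For example, for a vector d and config {'i', 'i'}, this produces the
// diagonal matrix Diag(d).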
xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    std::vector<int64> broadcast_sizes;
    int64 x_dim = 0;
    for (auto label = config.begin(); label != config.end(); ++label) {
      auto first_label = absl::c_find(config, *label);
      if (first_label == label) {
        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
        ++x_dim;
      } else {
        broadcast_sizes.push_back(
            broadcast_sizes[first_label - config.begin()]);
      }
    }
    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
    return EinsumDiagonalMask(x, config);
  });
}
}  // namespace

namespace {
// Removes the dimensions in `to_delete` from `shape` and renumbers the batch
// and contracting dimension numbers that point past them; used to implement
// implicit broadcasting.
template <typename C>
void DeleteDimsFromContainer(absl::Span<const int64> to_delete, Shape* shape,
                             C* batch_dims, C* contracting_dims) {
  if (to_delete.empty()) {
    return;
  }
  for (int64 i = to_delete.size() - 1; i >= 0; --i) {
    int64 dim = to_delete[i];
    shape->DeleteDimension(dim);
    for (auto& b : *batch_dims) {
      if (b > dim) {
        --b;
      }
    }
    for (auto& c : *contracting_dims) {
      if (c > dim) {
        --c;
      }
    }
  }
}
}  // namespace

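// Numeric einsum: each int64 label names a dimension. A label appearing in
// both inputs and the output is a batch dimension; one appearing in both
// inputs but not the output is contracted; the rest are outer (or
// output-only) dimensions. For example, a batched matmul is
// Einsum(x, {0, 1, 2}, y, {0, 2, 3}, {0, 1, 3}, precision).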
xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
                  absl::Span<const int64> y_config,
                  absl::Span<const int64> output_config,
                  xla::PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
    if (x_diagonal_labels) {
      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
                    y_config, output_config, precision);
    }
    auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
    if (y_diagonal_labels) {
      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
                    y_diagonal_labels->at(0), output_config, precision);
    }
    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
    if (output_diagonal_labels) {
      return EinsumInverseDiagonal(
          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
                 precision),
          output_config);
    }

    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    const int64 x_rank = x_config.size();
    const int64 y_rank = y_config.size();
    const int64 output_rank = output_config.size();
    absl::flat_hash_set<int64> x_map;
    absl::flat_hash_set<int64> y_map;
    absl::flat_hash_set<int64> output_map;

    for (auto d : x_config) {
      x_map.insert(d);
    }

    for (auto d : y_config) {
      y_map.insert(d);
    }

    for (auto d : output_config) {
      output_map.insert(d);
    }

    DotDimensionNumbers dnums;
    auto is_batch_dim = [&](int64 d) {
      return x_map.contains(d) && y_map.contains(d) && output_map.contains(d);
    };
    auto is_contracting = [&](int64 d) {
      return x_map.contains(d) && y_map.contains(d);
    };

    auto rhs_dimension_number = [&](int64 d) {
      return absl::c_find(y_config, d) - y_config.begin();
    };

    absl::InlinedVector<int64, 8> rhs_outer_dims;
    absl::InlinedVector<int64, 8> lhs_outer_dims;
    absl::InlinedVector<int64, 8> rhs_delete_dims;
    absl::InlinedVector<int64, 8> lhs_delete_dims;
    for (int64 i = 0; i < x_rank; ++i) {
      auto dim_name = x_config[i];
      const int64 rhs_dim = rhs_dimension_number(dim_name);

      if (is_batch_dim(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_batch_dimensions(i);
          dnums.add_rhs_batch_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else if (is_contracting(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_contracting_dimensions(i);
          dnums.add_rhs_contracting_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else {
        lhs_outer_dims.push_back(i);
      }
    }

    for (int64 i = 0; i < y_rank; ++i) {
      auto dim_name = y_config[i];
      if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
        rhs_outer_dims.push_back(i);
      }
    }

    absl::c_sort(rhs_outer_dims);
    absl::InlinedVector<int64, 8> output_transpose_dims;

    auto output_dimension_number = [&](int64 d) -> absl::optional<int64> {
      auto pos = absl::c_find(output_config, d);
      if (pos == output_config.end()) {
        return absl::nullopt;
      }
      return pos - output_config.begin();
    };

    for (auto d : dnums.lhs_batch_dimensions()) {
      output_transpose_dims.push_back(*output_dimension_number(x_config[d]));
    }

    for (auto d : lhs_outer_dims) {
      if (auto output_dim = output_dimension_number(x_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      lhs_delete_dims.push_back(d);
    }

    for (auto d : rhs_outer_dims) {
      if (auto output_dim = output_dimension_number(y_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      rhs_delete_dims.push_back(d);
    }

    const int64 transpose_rank = output_transpose_dims.size();
    std::vector<int64> transpose_dims(output_rank);
    for (int64 i = 0; i < transpose_rank; ++i) {
      transpose_dims[output_transpose_dims[i]] = i;
    }

    // Remove the ones that were broadcast from the x and y shapes, and adjust
    // the dimension numbers that are more minor than those dimensions.
    absl::c_sort(lhs_delete_dims);
    DeleteDimsFromContainer(lhs_delete_dims, &x_shape,
                            dnums.mutable_lhs_batch_dimensions(),
                            dnums.mutable_lhs_contracting_dimensions());

    absl::c_sort(rhs_delete_dims);
    DeleteDimsFromContainer(rhs_delete_dims, &y_shape,
                            dnums.mutable_rhs_batch_dimensions(),
                            dnums.mutable_rhs_contracting_dimensions());
    if (!lhs_delete_dims.empty()) {
      x = Reduce(x, ScalarLike(x, 0),
                 CreateScalarAddComputation(x_shape.element_type(), builder),
                 lhs_delete_dims);
    }

    if (!rhs_delete_dims.empty()) {
      y = Reduce(y, ScalarLike(y, 0),
                 CreateScalarAddComputation(y_shape.element_type(), builder),
                 rhs_delete_dims);
    }

    PrecisionConfig precision_proto;
    precision_proto.add_operand_precision(precision);
    precision_proto.add_operand_precision(precision);
    auto dot = DotGeneral(x, y, dnums, &precision_proto);
    dot = Transpose(dot, transpose_dims);
    if (transpose_rank == output_rank) {
      return dot;
    }

    auto is_output_only = [&](int64 d) {
      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
    };

    int64 dot_dim = 0;
    std::vector<int64> new_dims;
    new_dims.reserve(output_rank);
    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
    for (auto d : output_config) {
      if (is_output_only(d)) {
        // Output-only labels get a new size-1 dimension.
        new_dims.push_back(1);
      } else {
        // Otherwise consume the next dimension of the dot result.
        new_dims.push_back(dot_shape.dimensions(dot_dim++));
      }
    }
    return Reshape(dot, new_dims);
  });
}

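// Batched matrix multiplication over the two minor-most dimensions, i.e. the
// einsum "...mk,...kn->...mn", with optional transposition of either
// operand's minor dims.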
XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
  return BatchDot(x, false, y, false, precision);
}

XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
               PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::string string("...mk,...kn->...mn");
    if (transpose_x) {
      std::swap(string[3], string[4]);
    }
    if (transpose_y) {
      std::swap(string[6 + 3], string[6 + 4]);
    }
    return Einsum(x, y, string, precision);
  });
}

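// Parses a textual einsum config such as "ab,bc->ac" into the three numeric
// label vectors used by the einsum overloads above. Alphabetic labels map to
// their character codes; the dimensions covered by an ellipsis ("...") are
// assigned negative labels. For example, ParseEinsumString("ab,bc->ac", 2, 2)
// yields {{'a', 'b'}, {'b', 'c'}, {'a', 'c'}}.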
StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
    absl::string_view einsum_config, int64 x_rank, int64 y_rank) {
  std::array<std::vector<int64>, 3> einsum_config_numeric;
  std::vector<absl::string_view> main_split =
      absl::StrSplit(einsum_config, ',');
  if (main_split.size() != 2) {
    return InvalidArgument("Expected one \",\" in einsum_config.");
  }

  auto maybe_invalid_character = [](char d) {
    if (absl::ascii_isalpha(d)) {
      return Status::OK();
    }
    if (d == '.') {
      return InvalidArgument("Unsupported \".\" in einsum config.");
    }
    return InvalidArgument("Unexpected character in einsum config.");
  };

  auto string_config_to_numeric =
      [&](absl::string_view config, bool is_input_config, int64 input_rank,
          int64 ellipsis_rank,
          std::vector<int64>* numeric_config) -> StatusOr<int64> {
    std::vector<absl::string_view> splits = absl::StrSplit(config, "...");
    if (splits.empty()) {
      return ellipsis_rank;
    }
    if (splits.size() > 2) {
      return InvalidArgument("Too many ellipses (\"...\") in einsum config.");
    }
    // There is one split if we don't have an ellipsis, and two splits if we do.
    const bool has_ellipsis = splits.size() > 1;
    // We only compute ellipsis_rank for input configs.
    if (is_input_config && has_ellipsis) {
      // ellipsis_rank is the input rank minus the number of named labels.
      ellipsis_rank =
          input_rank - static_cast<int64>(splits[0].size() + splits[1].size());
      if (ellipsis_rank < 0) {
        return InvalidArgument(
            "Too few dimensions in the input for the given einsum config.");
      }
    }
    for (char d : splits[0]) {
      TF_RETURN_IF_ERROR(maybe_invalid_character(d));
      numeric_config->push_back(static_cast<int64>(d));
    }
    if (has_ellipsis) {
      // For input configs, we use the value of ellipsis_rank we just computed.
      // For the output config, we use the existing value of ellipsis_rank.
      for (int64 i = ellipsis_rank; i > 0; --i) {
        numeric_config->push_back(-i);
      }
      for (char d : splits[1]) {
        TF_RETURN_IF_ERROR(maybe_invalid_character(d));
        numeric_config->push_back(static_cast<int64>(d));
      }
    }
    return ellipsis_rank;
  };

  TF_ASSIGN_OR_RETURN(
      const int64 x_ellipsis_rank,
      string_config_to_numeric(main_split[0],
                               /*is_input_config=*/true, x_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[0]));

  std::vector<absl::string_view> y_output_split =
      absl::StrSplit(main_split[1], "->");
  if (y_output_split.size() != 2) {
    return InvalidArgument("Expected one \"->\" in einsum_config.");
  }

  TF_ASSIGN_OR_RETURN(
      const int64 y_ellipsis_rank,
      string_config_to_numeric(y_output_split[0],
                               /*is_input_config=*/true, y_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[1]));

  // Replace the ellipsis in the output config with numeric labels of the same
  // ellipsis rank as in the inputs.
  // Note: This implementation doesn't support different-rank broadcasting.
  TF_ASSIGN_OR_RETURN(
      std::ignore,
      string_config_to_numeric(
          y_output_split[1], /*is_input_config=*/false,
          /*input_rank=*/0,
          /*ellipsis_rank=*/std::max(x_ellipsis_rank, y_ellipsis_rank),
          &einsum_config_numeric[2]));
  return einsum_config_numeric;
}

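// If `einsum_config` lacks an explicit "->", appends one along with the
// inferred output labels: any ellipsis first, then every label that appears
// exactly once, in sorted order. Returns "" if the config already has an
// explicit output. For example, "ij,jk" normalizes to "ij,jk->ik".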
std::string NormalizeEinsumString(absl::string_view einsum_config) {
  if (einsum_config.find("->") != einsum_config.npos) {
    return "";
  }
  bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
  std::map<char, int64> chars;
  for (char c : einsum_config) {
    if (absl::ascii_isalpha(c)) {
      ++chars[c];
    }
  }
  std::string new_config(einsum_config.begin(), einsum_config.end());
  new_config.append("->");
  if (has_ellipsis) {
    new_config.append("...");
  }
  for (auto p : chars) {
    if (p.second == 1) {
      new_config.push_back(p.first);
    }
  }
  return new_config;
}

XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto new_config = NormalizeEinsumString(einsum_config);
    if (!new_config.empty()) {
      return Einsum(x, y, new_config, precision);
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    TF_ASSIGN_OR_RETURN(
        auto einsum_config_numeric,
        ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
    return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
                  einsum_config_numeric[2], precision);
  });
}

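// Unary einsum, implemented by prepending a scalar 1 as a dummy left operand
// to the binary form. For example, Einsum(x, "ii->i") extracts the main
// diagonal of x.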
XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
                precision);
}

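// Swaps the two minor-most dimensions, e.g. [..., M, N] -> [..., N, M].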
XlaOp TransposeInMinorDims(XlaOp x) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    std::vector<int64> permutation(n_dims);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
    return Transpose(x, permutation);
  });
}

XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
  return transpose ? TransposeInMinorDims(x) : x;
}

}  // namespace xla