/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <algorithm>
#include <array>
#include <limits>
#include <map>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
                     int64 n) {
  auto a = Iota(builder, U32, m);
  auto b = Iota(builder, U32, n);
  auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
  return ConvertElementType(indicator, type);
}
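
// Illustrative usage sketch (an assumption, not part of the original file;
// the builder name is hypothetical):
//
//   XlaBuilder b("identity_example");
//   // A 2x3 rectangular identity: [[1, 0, 0],
//   //                              [0, 1, 0]]
//   XlaOp eye = IdentityMatrix(&b, F32, /*m=*/2, /*n=*/3);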

XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    auto m = shape.dimensions(n_dims - 2);
    auto n = shape.dimensions(n_dims - 1);
    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
    auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    auto mask = Broadcast(indicator, major_dims);
    return mask;
  });
}
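
// Worked example (illustrative only): for a 3x4 matrix and diagonal=1,
// GetDiagonalMask yields the boolean mask
//
//   [[0, 1, 0, 0],
//    [0, 0, 1, 0],
//    [0, 0, 0, 1]]
//
// i.e. row i is true at column i + diagonal.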

XlaOp GetMatrixDiagonal(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);

    if (k <= -m || k >= n) {
      auto zero_size_shape = shape;
      zero_size_shape.DeleteDimension(n_dims - 1);
      zero_size_shape.set_dimensions(n_dims - 2, 0);
      return ConstantLiteral(builder, Literal{zero_size_shape});
    }
    auto mask = GetDiagonalMask(x, k);

    int64 reduce_dim = n_dims - 1;
    if ((k == 0 && m >= n) || k < 0) {
      reduce_dim = n_dims - 2;
    }
    auto result = Reduce(
        Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
        CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
        {reduce_dim});
    // When k == 0, we can save one slice op.
    if (k == 0) {
      return result;
    }
    return SliceInMinorDims(result, {0},
                            {k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
  });
}
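
// Worked example (illustrative only): for x = [[1, 2, 3],
//                                              [4, 5, 6]],
// GetMatrixDiagonal(x, /*k=*/0) returns [1, 5],
// GetMatrixDiagonal(x, /*k=*/1) returns [2, 6], and
// GetMatrixDiagonal(x, /*k=*/-1) returns [4].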

XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);

    // start_indices has shape {diag_len, 2}, and each pair of values along
    // its dimension 1 is the (row, col) of one diagonal element. We set
    // index_vector_dim to 1 and make start_index_map and collapsed_slice_dims
    // contain the same two dimension indices. This ensures that the (row,
    // col) pairs in start_indices are propagated to the indices for the two
    // collapsed dimensions in the operand indices through start_index_map.
    const int64 num_index_dims = 2;
    const int64 axis = n_dims - num_index_dims;

    // Calculate the indices of the diagonal with offset k.
    const int64 diag_len =
        std::max(std::min(m + std::min(k, 0), n - std::max(k, 0)), int64{0});
    XlaOp diag_base_indices = BroadcastInDim(Iota(builder, S32, diag_len),
                                             {diag_len, num_index_dims}, {0});
    XlaOp diag_offset =
        Broadcast(ConstantR1<int>(builder, {std::max(-k, 0), std::max(k, 0)}),
                  {diag_len});
    XlaOp start_indices = Add(diag_base_indices, diag_offset);

    // Example of a 3D diag-part extracting the diagonal with offset=1 out of
    // a tensor of shape [2,5,4]:
    //
    //  operand = s32[2,5,4] parameter(0)
    //  indices = s32[3,2] parameter(1)
    //  gather = s32[2,3] gather(operand, indices),
    //       offset_dims={0},
    //       collapsed_slice_dims={1,2},
    //       start_index_map={1,2},
    //       index_vector_dim=1,
    //       slice_sizes={2, 1, 1}

    xla::GatherDimensionNumbers dim_numbers;
    std::vector<int64> slice_sizes;
    slice_sizes.reserve(n_dims);
    for (int64 i = 0; i < n_dims; i++) {
      int64 window_bound;
      if (axis <= i) {
        dim_numbers.add_collapsed_slice_dims(i);
        dim_numbers.add_start_index_map(i);
        window_bound = (shape.dimensions(i) != 0) ? 1 : 0;
      } else {
        dim_numbers.add_offset_dims(i);
        window_bound = shape.dimensions(i);
      }
      slice_sizes.push_back(window_bound);
    }

    dim_numbers.set_index_vector_dim(1);

    return Gather(x, start_indices, dim_numbers, slice_sizes,
                  /*indices_are_sorted=*/true);
  });
}

XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
  XlaBuilder* builder = matrix.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
    TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
    auto n_dims = static_cast<int32>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);
    const int64 d = diag_shape.dimensions(n_dims - 2);
    std::vector<int64> broadcast_dims(n_dims - 1);
    absl::c_iota(broadcast_dims, 0);
    int64 pad_high = m - d;
    if (k < 0) {
      ++(broadcast_dims.back());
      pad_high = n - d;
    }

    if (pad_high != 0) {
      PaddingConfig padding_config;
      for (xla::int64 i = 0; i < diag_shape.rank() - 1; ++i) {
        auto* dims = padding_config.add_dimensions();
        dims->set_edge_padding_low(0);
        dims->set_interior_padding(0);
        dims->set_edge_padding_high(0);
      }
      auto* dims = padding_config.add_dimensions();
      dims->set_edge_padding_low(0);
      dims->set_interior_padding(0);
      dims->set_edge_padding_high(pad_high);
      diag = Pad(diag, ScalarLike(diag, 0), padding_config);
    }

    return Select(GetDiagonalMask(matrix, k),
                  BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
                  matrix);
  });
}
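
// Illustrative sketch (values are hypothetical): with
// matrix = [[1, 2], [3, 4]] and diag = [9, 9],
// SetMatrixDiagonal(matrix, diag, /*k=*/0) yields [[9, 2], [3, 9]];
// with k = 1 and diag = [7] (padded internally), [[1, 7], [3, 4]].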

XlaOp TriangleMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);
    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0<int32>(builder, diagonal);
    XlaOp indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    return Broadcast(indicator, major_dims);
  });
}

XlaOp Triangle(XlaOp x, bool lower) {
  return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
               : Select(TriangleMask(x, -1), ZerosLike(x), x);
}

XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }

XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
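
// Worked example (illustrative only): for x = [[1, 2], [3, 4]],
// LowerTriangle(x) keeps entries on or below the main diagonal, giving
// [[1, 0], [3, 4]], while UpperTriangle(x) gives [[1, 2], [0, 4]].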

namespace {
// Splits an einsum label config into {unique labels, dims to reduce (repeat
// occurrences of a label), dims to keep}. Returns nullopt if no label is
// repeated.
absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
    absl::Span<const int64> config) {
  std::vector<int64> unique_labels;
  std::vector<int64> reduce_dims;
  std::vector<int64> broadcast_dims;
  for (auto label = config.begin(); label != config.end(); ++label) {
    auto first_label = absl::c_find(config, *label);
    auto dim = label - config.begin();
    if (first_label == label) {
      unique_labels.push_back(*label);
      broadcast_dims.push_back(dim);
    } else {
      reduce_dims.push_back(dim);
    }
  }
  if (unique_labels.size() == config.size()) {
    return absl::nullopt;
  }
  return {{unique_labels, reduce_dims, broadcast_dims}};
}
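
// For example (illustrative): the config for "iji", encoded as
// {105, 106, 105}, yields unique_labels = {105, 106}, reduce_dims = {2},
// and broadcast_dims = {0, 1}.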

// Masks a tensor such that only the diagonal along repeated indices is
// non-zero. The result can be used to create a diagonal matrix with an
// identity reduction.
xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    Shape iota_shape = x_shape;
    iota_shape.set_element_type(S32);
    XlaOp mask = ConstantR0(builder, true);

    for (auto label = config.begin(); label != config.end(); ++label) {
      const int64 dim = label - config.begin();
      auto first_label = absl::c_find(config, *label);
      if (first_label != label) {
        const int64 first_dim = first_label - config.begin();
        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
                            Iota(builder, iota_shape, dim)));
      }
    }
    return Select(mask, x, ZerosLike(x));
  });
}

// Extracts the generalized diagonal of x for the repeated labels in config
// (e.g. "ii->i") by masking off-diagonal entries and reducing them away
// with an identity-with-zero reduction.
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    auto zero = ScalarLike(x, 0);
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    return Reduce(EinsumDiagonalMask(x, config), zero,
                  CreateScalarIdentityWithZeroComputation(
                      x_shape.element_type(), builder),
                  labels->at(1));
  });
}

// The inverse of EinsumDiagonal: broadcasts x back to the shape with
// repeated labels and masks everything off the diagonal.
xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    std::vector<int64> broadcast_sizes;
    int64 x_dim = 0;
    for (auto label = config.begin(); label != config.end(); ++label) {
      auto first_label = absl::c_find(config, *label);
      if (first_label == label) {
        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
        ++x_dim;
      } else {
        broadcast_sizes.push_back(
            broadcast_sizes[first_label - config.begin()]);
      }
    }
    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
    return EinsumDiagonalMask(x, config);
  });
}
}  // namespace

namespace {
// Helper to remove dimensions from a shape and from the dot dimension
// numbers; used to implement implicit broadcasting.
template <typename C>
void DeleteDimsFromContainer(absl::Span<const int64> to_delete, Shape* shape,
                             C* batch_dims, C* contracting_dims) {
  if (to_delete.empty()) {
    return;
  }
  for (int64 i = to_delete.size() - 1; i >= 0; --i) {
    int64 dim = to_delete[i];
    shape->DeleteDimension(dim);
    for (auto& b : *batch_dims) {
      if (b > dim) {
        --b;
      }
    }
    for (auto& c : *contracting_dims) {
      if (c > dim) {
        --c;
      }
    }
  }
}
}  // namespace

xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
                  absl::Span<const int64> y_config,
                  absl::Span<const int64> output_config,
                  xla::PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
    if (x_diagonal_labels) {
      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
                    y_config, output_config, precision);
    }
    auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
    if (y_diagonal_labels) {
      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
                    y_diagonal_labels->at(0), output_config, precision);
    }
    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
    if (output_diagonal_labels) {
      return EinsumInverseDiagonal(
          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
                 precision),
          output_config);
    }

    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    const int64 x_rank = x_config.size();
    const int64 y_rank = y_config.size();
    const int64 output_rank = output_config.size();
    absl::flat_hash_set<int64> x_map;
    absl::flat_hash_set<int64> y_map;
    absl::flat_hash_set<int64> output_map;

    for (auto d : x_config) {
      x_map.insert(d);
    }

    for (auto d : y_config) {
      y_map.insert(d);
    }

    for (auto d : output_config) {
      output_map.insert(d);
    }

    DotDimensionNumbers dnums;
    auto is_batch_dim = [&](int64 d) {
      return x_map.contains(d) && y_map.contains(d) && output_map.contains(d);
    };
    auto is_contracting = [&](int64 d) {
      return x_map.contains(d) && y_map.contains(d);
    };

    auto rhs_dimension_number = [&](int64 d) {
      return absl::c_find(y_config, d) - y_config.begin();
    };

    absl::InlinedVector<int64, 8> rhs_outer_dims;
    absl::InlinedVector<int64, 8> lhs_outer_dims;
    absl::InlinedVector<int64, 8> rhs_delete_dims;
    absl::InlinedVector<int64, 8> lhs_delete_dims;
    for (int64 i = 0; i < x_rank; ++i) {
      auto dim_name = x_config[i];
      const int64 rhs_dim = rhs_dimension_number(dim_name);

      if (is_batch_dim(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_batch_dimensions(i);
          dnums.add_rhs_batch_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else if (is_contracting(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_contracting_dimensions(i);
          dnums.add_rhs_contracting_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else {
        lhs_outer_dims.push_back(i);
      }
    }

    for (int64 i = 0; i < y_rank; ++i) {
      auto dim_name = y_config[i];
      if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
        rhs_outer_dims.push_back(i);
      }
    }

    absl::c_sort(rhs_outer_dims);
    absl::InlinedVector<int64, 8> output_transpose_dims;

    auto output_dimension_number = [&](int64 d) -> absl::optional<int64> {
      auto pos = absl::c_find(output_config, d);
      if (pos == output_config.end()) {
        return absl::nullopt;
      }
      return pos - output_config.begin();
    };

    for (auto d : dnums.lhs_batch_dimensions()) {
      output_transpose_dims.push_back(*output_dimension_number(x_config[d]));
    }

    for (auto d : lhs_outer_dims) {
      if (auto output_dim = output_dimension_number(x_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      lhs_delete_dims.push_back(d);
    }

    for (auto d : rhs_outer_dims) {
      if (auto output_dim = output_dimension_number(y_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      rhs_delete_dims.push_back(d);
    }

    const int64 transpose_rank = output_transpose_dims.size();
    std::vector<int64> transpose_dims(output_rank);
    for (int64 i = 0; i < transpose_rank; ++i) {
      transpose_dims[output_transpose_dims[i]] = i;
    }

    // Remove the size-one dimensions that were broadcast from the x and the
    // y shapes, and adjust the dimension numbers that are more minor than
    // those dimensions.
    absl::c_sort(lhs_delete_dims);
    DeleteDimsFromContainer(lhs_delete_dims, &x_shape,
                            dnums.mutable_lhs_batch_dimensions(),
                            dnums.mutable_lhs_contracting_dimensions());

    absl::c_sort(rhs_delete_dims);
    DeleteDimsFromContainer(rhs_delete_dims, &y_shape,
                            dnums.mutable_rhs_batch_dimensions(),
                            dnums.mutable_rhs_contracting_dimensions());
    if (!lhs_delete_dims.empty()) {
      x = Reduce(x, ScalarLike(x, 0),
                 CreateScalarAddComputation(x_shape.element_type(), builder),
                 lhs_delete_dims);
    }

    if (!rhs_delete_dims.empty()) {
      y = Reduce(y, ScalarLike(y, 0),
                 CreateScalarAddComputation(y_shape.element_type(), builder),
                 rhs_delete_dims);
    }

    PrecisionConfig precision_proto;
    precision_proto.add_operand_precision(precision);
    precision_proto.add_operand_precision(precision);
    auto dot = DotGeneral(x, y, dnums, &precision_proto);
    dot = Transpose(dot, transpose_dims);
    if (transpose_rank == output_rank) {
      return dot;
    }

    auto is_output_only = [&](int64 d) {
      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
    };

    int64 dot_dim = 0;
    std::vector<int64> new_dims;
    new_dims.reserve(output_rank);
    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
    for (auto d : output_config) {
      if (is_output_only(d)) {
        new_dims.push_back(1);
      } else {
        new_dims.push_back(dot_shape.dimensions(dot_dim++));
      }
    }
    return Reshape(dot, new_dims);
  });
}
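
// Illustrative sketch (not part of the original file): a plain matrix
// multiply expressed through the numeric-config overload. The label values
// are arbitrary; only equality between them matters.
//
//   XlaOp matmul = Einsum(x, /*x_config=*/{0, 1}, y, /*y_config=*/{1, 2},
//                         /*output_config=*/{0, 2},
//                         PrecisionConfig::DEFAULT);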

XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
  return BatchDot(x, false, y, false, precision);
}

XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
               PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::string string("...mk,...kn->...mn");
    if (transpose_x) {
      std::swap(string[3], string[4]);
    }
    if (transpose_y) {
      std::swap(string[6 + 3], string[6 + 4]);
    }
    return Einsum(x, y, string, precision);
  });
}
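
// Usage sketch (illustrative): for x of shape [b, m, k] and y of shape
// [b, k, n], BatchDot contracts the k dimension per batch element:
//
//   XlaOp out = BatchDot(x, y, PrecisionConfig::HIGHEST);  // [b, m, n]
//   // Treat x as [b, k, m] instead:
//   XlaOp out_t = BatchDot(x, /*transpose_x=*/true, y,
//                          /*transpose_y=*/false, PrecisionConfig::HIGHEST);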

StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
    absl::string_view einsum_config, int64 x_rank, int64 y_rank) {
  std::array<std::vector<int64>, 3> einsum_config_numeric;
  std::vector<absl::string_view> main_split =
      absl::StrSplit(einsum_config, ',');
  if (main_split.size() != 2) {
    return InvalidArgument("Expected one \",\" in einsum_config.");
  }

  auto maybe_invalid_character = [](char d) {
    if (absl::ascii_isalpha(d)) {
      return Status::OK();
    }
    if (d == '.') {
      return InvalidArgument("Unsupported \".\" in einsum config.");
    }
    return InvalidArgument("Unexpected character in einsum config.");
  };

  auto string_config_to_numeric =
      [&](absl::string_view config, bool is_input_config, int64 input_rank,
          int64 ellipsis_rank,
          std::vector<int64>* numeric_config) -> StatusOr<int64> {
    std::vector<absl::string_view> splits = absl::StrSplit(config, "...");
    if (splits.empty()) {
      return ellipsis_rank;
    }
    if (splits.size() > 2) {
      return InvalidArgument("Too many ellipses (\"...\") in einsum config.");
    }
    // There is one split if we don't have an ellipsis, and two if we do.
    const bool has_ellipsis = splits.size() > 1;
    // We only compute ellipsis_rank for input configs.
    if (is_input_config && has_ellipsis) {
      // ellipsis_rank is the input rank minus the number of named labels.
      ellipsis_rank =
          input_rank - static_cast<int64>(splits[0].size() + splits[1].size());
      if (ellipsis_rank < 0) {
        return InvalidArgument(
            "Too few dimensions in the input for the given einsum config.");
      }
    }
    for (char d : splits[0]) {
      TF_RETURN_IF_ERROR(maybe_invalid_character(d));
      numeric_config->push_back(static_cast<int64>(d));
    }
    if (has_ellipsis) {
      // For input configs, we use the value of ellipsis_rank we just
      // computed. For the output config, we use the existing value of
      // ellipsis_rank.
      for (int64 i = ellipsis_rank; i > 0; --i) {
        numeric_config->push_back(-i);
      }
      for (char d : splits[1]) {
        TF_RETURN_IF_ERROR(maybe_invalid_character(d));
        numeric_config->push_back(static_cast<int64>(d));
      }
    }
    return ellipsis_rank;
  };

  TF_ASSIGN_OR_RETURN(
      const int64 x_ellipsis_rank,
      string_config_to_numeric(main_split[0],
                               /*is_input_config=*/true, x_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[0]));

  std::vector<absl::string_view> y_output_split =
      absl::StrSplit(main_split[1], "->");
  if (y_output_split.size() != 2) {
    return InvalidArgument("Expected one \"->\" in einsum_config.");
  }

  TF_ASSIGN_OR_RETURN(
      const int64 y_ellipsis_rank,
      string_config_to_numeric(y_output_split[0],
                               /*is_input_config=*/true, y_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[1]));

  // Replace the ellipsis in the output config with numeric labels of the
  // same ellipsis rank as in the inputs.
  // Note: this implementation doesn't support different-rank broadcasting.
  TF_ASSIGN_OR_RETURN(
      std::ignore,
      string_config_to_numeric(
          y_output_split[1], /*is_input_config=*/false,
          /*input_rank=*/0,
          /*ellipsis_rank=*/std::max(x_ellipsis_rank, y_ellipsis_rank),
          &einsum_config_numeric[2]));
  return einsum_config_numeric;
}
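
// Worked example (illustrative only):
// ParseEinsumString("...mk,...kn->...mn", /*x_rank=*/3, /*y_rank=*/3) maps
// each named label to its character code and each ellipsis dimension to a
// negative label, yielding
//   x_config      = {-1, 'm', 'k'}
//   y_config      = {-1, 'k', 'n'}
//   output_config = {-1, 'm', 'n'}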

// Rewrites an einsum config that lacks an explicit output ("ab,bc") into one
// with an output ("ab,bc->ac"): the output keeps the labels that appear
// exactly once, in sorted order, after any ellipsis. Returns "" if the
// config already contains "->".
std::string NormalizeEinsumString(absl::string_view einsum_config) {
  if (einsum_config.find("->") != einsum_config.npos) {
    return "";
  }
  bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
  std::map<char, int64> chars;
  for (char c : einsum_config) {
    if (absl::ascii_isalpha(c)) {
      ++chars[c];
    }
  }
  std::string new_config(einsum_config.begin(), einsum_config.end());
  new_config.append("->");
  if (has_ellipsis) {
    new_config.append("...");
  }
  for (auto p : chars) {
    if (p.second == 1) {
      new_config.push_back(p.first);
    }
  }
  return new_config;
}
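
// For example (illustrative): NormalizeEinsumString("ab,bc") returns
// "ab,bc->ac" (the labels that appear exactly once, in sorted order), and
// NormalizeEinsumString("ab,bc->ac") returns "" since an output is already
// present.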

XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto new_config = NormalizeEinsumString(einsum_config);
    if (!new_config.empty()) {
      return Einsum(x, y, new_config, precision);
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    TF_ASSIGN_OR_RETURN(
        auto einsum_config_numeric,
        ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
    return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
                  einsum_config_numeric[2], precision);
  });
}

XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
                precision);
}
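
// Usage sketch (illustrative): the string overloads cover both binary and
// unary contractions.
//
//   XlaOp c = Einsum(a, b, "mk,kn->mn", PrecisionConfig::DEFAULT);  // matmul
//   XlaOp t = Einsum(m, "ij->ji", PrecisionConfig::DEFAULT);  // transpose
//
// The unary form is lowered to a binary einsum against a scalar one.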

XlaOp TransposeInMinorDims(XlaOp x) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    std::vector<int64> permutation(n_dims);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
    return Transpose(x, permutation);
  });
}

XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
  return transpose ? TransposeInMinorDims(x) : x;
}

}  // namespace xla