/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <algorithm>
#include <array>
#include <limits>
#include <map>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
36 #include "tensorflow/compiler/xla/client/lib/constants.h"
37 #include "tensorflow/compiler/xla/client/lib/slicing.h"
38 #include "tensorflow/compiler/xla/client/xla_builder.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/shape_util.h"
42 #include "tensorflow/compiler/xla/status.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/statusor.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47
48 namespace xla {
49
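// Builds an m x n identity matrix of the given element type by comparing a
// row iota against a broadcast column iota, e.g. IdentityMatrix(b, F32, 2, 3)
// yields [[1, 0, 0], [0, 1, 0]].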
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m,
                     int64_t n) {
  auto a = Iota(builder, U32, m);
  auto b = Iota(builder, U32, n);
  auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
  return ConvertElementType(indicator, type);
}

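// Returns a boolean (PRED) mask with the same dimensions as `x` that is true
// exactly on the `diagonal`-th diagonal of the two minor dimensions (positive
// values select superdiagonals, negative values subdiagonals) and false
// elsewhere.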
XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32_t>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    auto m = shape.dimensions(n_dims - 2);
    auto n = shape.dimensions(n_dims - 1);
    absl::Span<const int64_t> major_dims =
        shape.dimensions().subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
    auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    auto mask = Broadcast(indicator, major_dims);
    return mask;
  });
}

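// Extracts the k-th diagonal of the two minor dimensions of `x`: off-diagonal
// entries are masked to zero, the masked matrix is reduced along one of the
// two minor dimensions with an identity-with-zero reduction, and the result is
// sliced down to the diagonal's length. Returns a zero-sized result when k is
// out of range.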
XlaOp GetMatrixDiagonal(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32_t>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64_t m = shape.dimensions(n_dims - 2);
    const int64_t n = shape.dimensions(n_dims - 1);

    if (k <= -m || k >= n) {
      auto zero_size_shape = shape;
      zero_size_shape.DeleteDimension(n_dims - 1);
      zero_size_shape.set_dimensions(n_dims - 2, 0);
      return ConstantLiteral(builder, Literal{zero_size_shape});
    }
    auto mask = GetDiagonalMask(x, k);

    int64_t reduce_dim = n_dims - 1;
    if ((k == 0 && m >= n) || k < 0) {
      reduce_dim = n_dims - 2;
    }
    auto result = Reduce(
        Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
        CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
        {reduce_dim});
    // When k == 0 the result already has the diagonal's length, so we can
    // save one slice op.
    if (k == 0) {
      return result;
    }
    return SliceInMinorDims(result, {0},
                            {k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
  });
}

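// Variant of GetMatrixDiagonal that extracts the k-th diagonal with a single
// Gather op instead of a mask, reduce, and slice.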
XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto n_dims = static_cast<int32_t>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64_t m = shape.dimensions(n_dims - 2);
    const int64_t n = shape.dimensions(n_dims - 1);

    // The start_indices tensor has shape {diag_len, 2}, and each pair of
    // values in its dimension 1 represents the (row, col) of a diagonal
    // element. We set index_vector_dim to 1 and make start_index_map and
    // collapsed_slice_dims contain the same two dimension indices. This makes
    // sure that the (row, col) pairs in start_indices are propagated to the
    // indices for the two collapsed dimensions in the operand indices through
    // start_index_map.
    const int64_t num_index_dims = 2;
    const int64_t axis = n_dims - num_index_dims;

    // Calculate the indices of the diagonal part with offset k.
    const int64_t diag_len =
        std::max(std::min(m + std::min(k, 0), n - std::max(k, 0)), int64_t{0});
    XlaOp diag_base_indices = BroadcastInDim(Iota(builder, S32, diag_len),
                                             {diag_len, num_index_dims}, {0});
    XlaOp diag_offset =
        Broadcast(ConstantR1<int>(builder, {std::max(-k, 0), std::max(k, 0)}),
                  {diag_len});
    XlaOp start_indices = Add(diag_base_indices, diag_offset);

    // Example of a 3D diag-part extracting the diagonal part with offset=1 out
    // of a tensor of shape [2,5,4]:
    //
    //   operand = s32[2,5,4] parameter(0)
    //   indices = s32[3,2] parameter(1)
    //   gather = s32[2,3] gather(operand, indices),
    //       offset_dims={0},
    //       collapsed_slice_dims={1,2},
    //       start_index_map={1,2},
    //       index_vector_dim=1,
    //       slice_sizes={2, 1, 1}

    xla::GatherDimensionNumbers dim_numbers;
    std::vector<int64_t> slice_sizes;
    slice_sizes.reserve(n_dims);
    for (int64_t i = 0; i < n_dims; i++) {
      int64_t window_bound;
      if (axis <= i) {
        dim_numbers.add_collapsed_slice_dims(i);
        dim_numbers.add_start_index_map(i);
        window_bound = (shape.dimensions(i) != 0) ? 1 : 0;
      } else {
        dim_numbers.add_offset_dims(i);
        window_bound = shape.dimensions(i);
      }
      slice_sizes.push_back(window_bound);
    }

    dim_numbers.set_index_vector_dim(1);

    return Gather(x, start_indices, dim_numbers, slice_sizes,
                  /*indices_are_sorted=*/true);
  });
}

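// Writes `diag` onto the k-th diagonal of `matrix`: the diagonal values are
// zero-padded up to the length of the corresponding matrix dimension,
// broadcast along the other minor dimension, and selected in under the mask
// from GetDiagonalMask.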
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
  XlaBuilder* builder = matrix.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
    TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
    auto n_dims = static_cast<int32_t>(shape.rank());
    TF_RET_CHECK(n_dims >= 2);
    const int64_t m = shape.dimensions(n_dims - 2);
    const int64_t n = shape.dimensions(n_dims - 1);
    const int64_t d = diag_shape.dimensions(n_dims - 2);
    std::vector<int64_t> broadcast_dims(n_dims - 1);
    absl::c_iota(broadcast_dims, 0);
    int64_t pad_high = m - d;
    if (k < 0) {
      ++(broadcast_dims.back());
      pad_high = n - d;
    }

    if (pad_high != 0) {
      PaddingConfig padding_config;
      for (int64_t i = 0; i < diag_shape.rank() - 1; ++i) {
        auto* dims = padding_config.add_dimensions();
        dims->set_edge_padding_low(0);
        dims->set_interior_padding(0);
        dims->set_edge_padding_high(0);
      }
      auto* dims = padding_config.add_dimensions();
      dims->set_edge_padding_low(0);
      dims->set_interior_padding(0);
      dims->set_edge_padding_high(pad_high);
      diag = Pad(diag, ScalarLike(diag, 0), padding_config);
    }

    return Select(GetDiagonalMask(matrix, k),
                  BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
                  matrix);
  });
}

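// Returns a boolean mask that is true on and below the `diagonal`-th diagonal
// of the two minor dimensions (i.e. where col <= row + diagonal), broadcast
// over any major (batch) dimensions.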
XlaOp TriangleMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    const int64_t m = shape.dimensions(n_dims - 2);
    const int64_t n = shape.dimensions(n_dims - 1);
    absl::Span<const int64_t> major_dims =
        shape.dimensions().subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0<int32_t>(builder, diagonal);
    XlaOp indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    return Broadcast(indicator, major_dims);
  });
}

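// Zeroes out one triangle of `x`: with lower=true everything above the main
// diagonal is cleared, with lower=false everything strictly below it is
// cleared; the main diagonal is kept in both cases.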
XlaOp Triangle(XlaOp x, bool lower) {
  return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
               : Select(TriangleMask(x, -1), ZerosLike(x), x);
}

XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }

XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }

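// Symmetrizes a square matrix (or batch of matrices) by reflecting the chosen
// triangle across the main diagonal. For complex element types the result is
// made Hermitian: the imaginary part of the diagonal is zeroed and the
// mirrored entries are conjugated.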
XlaOp Symmetrize(XlaOp x, bool lower) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    if (shape.rank() < 2) {
      return InvalidArgument(
          "Argument to symmetrize must have >= 2 dimensions, got %s",
          shape.ToString());
    }
    const int64_t m = ShapeUtil::GetDimension(shape, -2);
    const int64_t n = ShapeUtil::GetDimension(shape, -1);
    if (m != n) {
      return InvalidArgument(
          "The two most minor dimensions of the argument to symmetrize must be "
          "equal size, got %s",
          shape.ToString());
    }
    auto mask = lower ? TriangleMask(x, 0) : Not(TriangleMask(x, -1));
    if (primitive_util::IsComplexType(shape.element_type())) {
      auto re = Select(mask, Real(x), TransposeInMinorDims(Real(x)));
      auto im_mask = lower ? TriangleMask(x, -1) : Not(TriangleMask(x, 0));
      auto im = Select(im_mask, Imag(x), ZerosLike(Imag(x)));
      im = Select(mask, im, -TransposeInMinorDims(im));
      return Complex(re, im);
    } else {
      return Select(mask, x, TransposeInMinorDims(x));
    }
  });
}

namespace {
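// For an einsum operand config, returns {the labels with duplicates removed,
// the dims of the repeated (non-first) occurrences to reduce away, the dims of
// the first occurrences to broadcast along}, or nullopt if no label is
// repeated.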
std::optional<std::array<std::vector<int64_t>, 3>> EinsumDiagonalLabels(
    absl::Span<const int64_t> config) {
  std::vector<int64_t> unique_labels;
  std::vector<int64_t> reduce_dims;
  std::vector<int64_t> broadcast_dims;
  for (auto label = config.begin(); label != config.end(); ++label) {
    auto first_label = absl::c_find(config, *label);
    auto dim = label - config.begin();
    if (first_label == label) {
      unique_labels.push_back(*label);
      broadcast_dims.push_back(dim);
    } else {
      reduce_dims.push_back(dim);
    }
  }
  if (unique_labels.size() == config.size()) {
    return std::nullopt;
  }
  return {{unique_labels, reduce_dims, broadcast_dims}};
}

// Masks a tensor such that only the elements on the diagonals of repeated
// indices are non-zero. The result of this can be used to create a diagonal
// matrix with an identity reduction.
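// For example, with a config whose two labels are equal (the numeric form of
// "ii"), only the entries x[d][d] survive and everything else is zeroed.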
xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64_t> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    Shape iota_shape = x_shape;
    iota_shape.set_element_type(S32);
    XlaOp mask = ConstantR0(builder, true);

    for (auto label = config.begin(); label != config.end(); ++label) {
      const int64_t dim = label - config.begin();
      auto first_label = absl::c_find(config, *label);
      if (first_label != label) {
        const int64_t first_dim = first_label - config.begin();
        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
                            Iota(builder, iota_shape, dim)));
      }
    }
    return Select(mask, x, ZerosLike(x));
  });
}

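// Extracts the generalized diagonal named by repeated labels in `config`
// (e.g. "ii->i"): the masked tensor is reduced over the repeated-label
// dimensions with an identity-with-zero reduction.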
xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64_t> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    auto zero = ScalarLike(x, 0);
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    return Reduce(EinsumDiagonalMask(x, config), zero,
                  CreateScalarIdentityWithZeroComputation(
                      x_shape.element_type(), builder),
                  labels->at(1));
  });
}

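// Inverse of EinsumDiagonal: broadcasts a reduced (diagonal) tensor back to
// the full shape implied by `config` and masks it so that only the diagonal
// entries are non-zero.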
xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64_t> config) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto labels = EinsumDiagonalLabels(config);
    if (!labels) {
      return x;
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    std::vector<int64_t> broadcast_sizes;
    int64_t x_dim = 0;
    for (auto label = config.begin(); label != config.end(); ++label) {
      auto first_label = absl::c_find(config, *label);
      if (first_label == label) {
        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
        ++x_dim;
      } else {
        broadcast_sizes.push_back(
            broadcast_sizes[first_label - config.begin()]);
      }
    }
    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
    return EinsumDiagonalMask(x, config);
  });
}
}  // namespace

namespace {
// Helper method to remove dimensions from a shape and dot dimension numbers
// used to implement implicit broadcasting.
template <typename C>
void DeleteDimsFromContainer(absl::Span<const int64_t> to_delete, Shape* shape,
                             C* batch_dims, C* contracting_dims) {
  if (to_delete.empty()) {
    return;
  }
  for (int64_t i = to_delete.size() - 1; i >= 0; --i) {
    int64_t dim = to_delete[i];
    shape->DeleteDimension(dim);
    for (auto& b : *batch_dims) {
      if (b > dim) {
        --b;
      }
    }
    for (auto& c : *contracting_dims) {
      if (c > dim) {
        --c;
      }
    }
  }
}
}  // namespace

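// Einsum over numeric label configs: each int64_t is an arbitrary dimension
// label; a label shared by x and y is contracted, and a label shared by x, y,
// and the output becomes a batch dimension. For example,
// Einsum(x, {0, 1}, y, {1, 2}, {0, 2}) is a matrix product. ParseEinsumString
// produces these configs from the usual string notation (letters map to their
// character codes, ellipsis dimensions to negative labels).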
xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64_t> x_config,
                  xla::XlaOp y, absl::Span<const int64_t> y_config,
                  absl::Span<const int64_t> output_config,
                  xla::PrecisionConfig::Precision precision,
                  std::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
    if (x_diagonal_labels) {
      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
                    y_config, output_config, precision, preferred_element_type);
    }
    auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
    if (y_diagonal_labels) {
      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
                    y_diagonal_labels->at(0), output_config, precision,
                    preferred_element_type);
    }
    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
    if (output_diagonal_labels) {
      return EinsumInverseDiagonal(
          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
                 precision, preferred_element_type),
          output_config);
    }

    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    const int64_t x_rank = x_config.size();
    const int64_t y_rank = y_config.size();
    const int64_t output_rank = output_config.size();
    absl::flat_hash_set<int64_t> x_map;
    absl::flat_hash_set<int64_t> y_map;
    absl::flat_hash_set<int64_t> output_map;

    for (auto d : x_config) {
      x_map.insert(d);
    }

    for (auto d : y_config) {
      y_map.insert(d);
    }

    for (auto d : output_config) {
      output_map.insert(d);
    }

    DotDimensionNumbers dnums;
    auto is_batch_dim = [&](int64_t d) {
      return x_map.contains(d) && y_map.contains(d) && output_map.contains(d);
    };
    auto is_contracting = [&](int64_t d) {
      return x_map.contains(d) && y_map.contains(d);
    };

    auto rhs_dimension_number = [&](int64_t d) {
      return absl::c_find(y_config, d) - y_config.begin();
    };

    absl::InlinedVector<int64_t, 8> rhs_outer_dims;
    absl::InlinedVector<int64_t, 8> lhs_outer_dims;
    absl::InlinedVector<int64_t, 8> rhs_delete_dims;
    absl::InlinedVector<int64_t, 8> lhs_delete_dims;
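    // Classify each lhs dimension. Labels shared with y (and the output) whose
    // sizes match become contracting (or batch) dimensions of the dot; when a
    // size-1 dimension is paired with a larger one, the size-1 side is
    // recorded for deletion to implement implicit broadcasting and the
    // surviving side is treated as an outer dimension.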
    for (int64_t i = 0; i < x_rank; ++i) {
      auto dim_name = x_config[i];
      const int64_t rhs_dim = rhs_dimension_number(dim_name);

      if (is_batch_dim(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_batch_dimensions(i);
          dnums.add_rhs_batch_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else if (is_contracting(dim_name)) {
        if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
          dnums.add_lhs_contracting_dimensions(i);
          dnums.add_rhs_contracting_dimensions(rhs_dim);
        } else if (x_shape.dimensions(i) == 1) {
          rhs_outer_dims.push_back(rhs_dim);
          lhs_delete_dims.push_back(i);
        } else {
          lhs_outer_dims.push_back(i);
          rhs_delete_dims.push_back(rhs_dim);
        }
      } else {
        lhs_outer_dims.push_back(i);
      }
    }

    for (int64_t i = 0; i < y_rank; ++i) {
      auto dim_name = y_config[i];
      if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
        rhs_outer_dims.push_back(i);
      }
    }

    absl::c_sort(rhs_outer_dims);
    absl::InlinedVector<int64_t, 8> output_transpose_dims;

    auto output_dimension_number = [&](int64_t d) -> std::optional<int64_t> {
      auto pos = absl::c_find(output_config, d);
      if (pos == output_config.end()) {
        return std::nullopt;
      }
      return pos - output_config.begin();
    };

    for (auto d : dnums.lhs_batch_dimensions()) {
      output_transpose_dims.push_back(*output_dimension_number(x_config[d]));
    }

    for (auto d : lhs_outer_dims) {
      if (auto output_dim = output_dimension_number(x_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      lhs_delete_dims.push_back(d);
    }

    for (auto d : rhs_outer_dims) {
      if (auto output_dim = output_dimension_number(y_config[d])) {
        output_transpose_dims.push_back(*output_dim);
        continue;
      }
      rhs_delete_dims.push_back(d);
    }

    const int64_t transpose_rank = output_transpose_dims.size();
    std::vector<int64_t> transpose_dims(output_rank);
    for (int64_t i = 0; i < transpose_rank; ++i) {
      transpose_dims[output_transpose_dims[i]] = i;
    }

    // Remove the size-1 dimensions that were broadcast from the x and y shapes
    // and adjust the dimension numbers that are more minor than those
    // dimensions.
    absl::c_sort(lhs_delete_dims);
    DeleteDimsFromContainer(lhs_delete_dims, &x_shape,
                            dnums.mutable_lhs_batch_dimensions(),
                            dnums.mutable_lhs_contracting_dimensions());

    absl::c_sort(rhs_delete_dims);
    DeleteDimsFromContainer(rhs_delete_dims, &y_shape,
                            dnums.mutable_rhs_batch_dimensions(),
                            dnums.mutable_rhs_contracting_dimensions());
    if (!lhs_delete_dims.empty()) {
      x = Reduce(x, ScalarLike(x, 0),
                 CreateScalarAddComputation(x_shape.element_type(), builder),
                 lhs_delete_dims);
    }

    if (!rhs_delete_dims.empty()) {
      y = Reduce(y, ScalarLike(y, 0),
                 CreateScalarAddComputation(y_shape.element_type(), builder),
                 rhs_delete_dims);
    }

    PrecisionConfig precision_proto;
    precision_proto.add_operand_precision(precision);
    precision_proto.add_operand_precision(precision);
    auto dot =
        DotGeneral(x, y, dnums, &precision_proto, preferred_element_type);
    dot = Transpose(dot, transpose_dims);
    if (transpose_rank == output_rank) {
      return dot;
    }

    auto is_output_only = [&](int64_t d) {
      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
    };

    int64_t dot_dim = 0;
    std::vector<int64_t> new_dims;
    new_dims.reserve(output_rank);
    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
    for (auto d : output_config) {
      if (is_output_only(d)) {
        new_dims.push_back(1);
      } else {
        // Consume the dot's dimensions in order as we walk the output labels.
        new_dims.push_back(dot_shape.dimensions(dot_dim));
        ++dot_dim;
      }
    }
    return Reshape(dot, new_dims);
  });
}

XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision,
               std::optional<PrimitiveType> preferred_element_type) {
  return BatchDot(x, false, y, false, precision, preferred_element_type);
}

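// Batched matrix multiplication expressed as an einsum over the two minor
// dimensions; the transpose flags swap the corresponding labels in the
// equation, e.g. transpose_x turns "...mk,...kn->...mn" into
// "...km,...kn->...mn".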
XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y,
               PrecisionConfig::Precision precision,
               std::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    std::string string("...mk,...kn->...mn");
    if (transpose_x) {
      std::swap(string[3], string[4]);
    }
    if (transpose_y) {
      std::swap(string[6 + 3], string[6 + 4]);
    }
    return Einsum(x, y, string, precision, preferred_element_type);
  });
}

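// Converts a string einsum config into the three numeric configs used by the
// label-based Einsum overload, e.g. "ab,bc->ac" becomes {{97, 98}, {98, 99},
// {97, 99}} (character codes); an ellipsis expands into negative labels shared
// between the inputs and the output.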
StatusOr<std::array<std::vector<int64_t>, 3>> ParseEinsumString(
    absl::string_view einsum_config, int64_t x_rank, int64_t y_rank) {
  std::array<std::vector<int64_t>, 3> einsum_config_numeric;
  std::vector<absl::string_view> main_split =
      absl::StrSplit(einsum_config, ',');
  if (main_split.size() != 2) {
    return InvalidArgument("Expected one \",\" in einsum_config.");
  }

  auto maybe_invalid_character = [](char d) {
    if (absl::ascii_isalpha(d)) {
      return OkStatus();
    }
    if (d == '.') {
      return InvalidArgument("Unsupported \".\" in einsum config.");
    }
    return InvalidArgument("Unexpected character in einsum config.");
  };

  auto string_config_to_numeric =
      [&](absl::string_view config, bool is_input_config, int64_t input_rank,
          int64_t ellipsis_rank,
          std::vector<int64_t>* numeric_config) -> StatusOr<int64_t> {
    std::vector<absl::string_view> splits = absl::StrSplit(config, "...");
    if (splits.empty()) {
      return ellipsis_rank;
    }
    if (splits.size() > 2) {
      return InvalidArgument("Too many ellipses (\"...\") in einsum config.");
    }
    // There is one split if we don't have an ellipsis, and two splits if we do.
    const bool has_ellipsis = splits.size() > 1;
    // We only compute ellipsis_rank for input configs.
    if (is_input_config && has_ellipsis) {
      // ellipsis_rank is input rank minus the number of named labels.
      ellipsis_rank = input_rank -
                      static_cast<int64_t>(splits[0].size() + splits[1].size());
      if (ellipsis_rank < 0) {
        return InvalidArgument(
            "Too few dimensions in the input for the given einsum config.");
      }
    }
    for (char d : splits[0]) {
      TF_RETURN_IF_ERROR(maybe_invalid_character(d));
      numeric_config->push_back(static_cast<int64_t>(d));
    }
    if (has_ellipsis) {
      // For input configs, we use the value of ellipsis_rank we just computed.
      // For output config, we use the existing value of ellipsis_rank.
      for (int64_t i = ellipsis_rank; i > 0; --i) {
        numeric_config->push_back(-i);
      }
      for (char d : splits[1]) {
        TF_RETURN_IF_ERROR(maybe_invalid_character(d));
        numeric_config->push_back(static_cast<int64_t>(d));
      }
    }
    return ellipsis_rank;
  };

  TF_ASSIGN_OR_RETURN(
      const int64_t x_ellipsis_rank,
      string_config_to_numeric(main_split[0],
                               /*is_input_config=*/true, x_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[0]));

  std::vector<absl::string_view> y_output_split =
      absl::StrSplit(main_split[1], "->");
  if (y_output_split.size() != 2) {
    return InvalidArgument("Expected one \"->\" in einsum_config.");
  }

  TF_ASSIGN_OR_RETURN(
      const int64_t y_ellipsis_rank,
      string_config_to_numeric(y_output_split[0],
                               /*is_input_config=*/true, y_rank,
                               /*ellipsis_rank=*/0, &einsum_config_numeric[1]));

  // Replace ellipsis in output_config with numeric labels with the same
  // ellipsis rank as in the inputs.
  // Note: This implementation doesn't support different-rank broadcasting.
  TF_ASSIGN_OR_RETURN(
      std::ignore,
      string_config_to_numeric(
          y_output_split[1], /*is_input_config=*/false,
          /*input_rank=*/0,
          /*ellipsis_rank=*/std::max(x_ellipsis_rank, y_ellipsis_rank),
          &einsum_config_numeric[2]));
  return einsum_config_numeric;
}

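// If `einsum_config` has no explicit output ("->"), appends one consisting of
// the labels that appear exactly once, in sorted order (preceded by "..." if
// the inputs use an ellipsis), e.g. "ij,jk" becomes "ij,jk->ik". Returns an
// empty string if the config already names an output.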
std::string NormalizeEinsumString(absl::string_view einsum_config) {
  if (einsum_config.find("->") != einsum_config.npos) {
    return "";
  }
  bool has_ellipsis = einsum_config.find("...") != einsum_config.npos;
  std::map<char, int64_t> chars;
  for (char c : einsum_config) {
    if (absl::ascii_isalpha(c)) {
      ++chars[c];
    }
  }
  std::string new_config(einsum_config.begin(), einsum_config.end());
  new_config.append("->");
  if (has_ellipsis) {
    new_config.append("...");
  }
  for (auto p : chars) {
    if (p.second == 1) {
      new_config.push_back(p.first);
    }
  }
  return new_config;
}

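// Binary einsum with a string config: implicit configs are normalized first
// (e.g. "ij,jk" -> "ij,jk->ik") and the result is dispatched to the numeric
// overload via ParseEinsumString.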
XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
             PrecisionConfig::Precision precision,
             std::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    auto new_config = NormalizeEinsumString(einsum_config);
    if (!new_config.empty()) {
      return Einsum(x, y, new_config, precision, preferred_element_type);
    }
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
    TF_ASSIGN_OR_RETURN(
        auto einsum_config_numeric,
        ParseEinsumString(einsum_config, x_shape.rank(), y_shape.rank()));
    return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
                  einsum_config_numeric[2], precision, preferred_element_type);
  });
}

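// Unary einsum (e.g. "ij->ji" for a transpose, or "ii" for a trace),
// implemented by prepending a scalar 1 operand so the config becomes
// ",ij->ji" and the binary path above is reused.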
XlaOp Einsum(XlaOp x, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  return Einsum(ScalarLike(x, 1), x, absl::StrCat(",", einsum_config),
                precision);
}

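// Swaps the two minor-most dimensions, e.g. [..., m, n] -> [..., n, m].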
XlaOp TransposeInMinorDims(XlaOp x) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    std::vector<int64_t> permutation(n_dims);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
    return Transpose(x, permutation);
  });
}

XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
  return transpose ? TransposeInMinorDims(x) : x;
}

}  // namespace xla