• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array3d.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/reference_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 namespace {
40 
41 class DotOperationTest : public ClientLibraryTestBase {
42  public:
43   ErrorSpec error_spec_{0.0001, 1e-5};
44 };
45 
46 using TypesF16F32 = ::testing::Types<
47 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
48     Eigen::half,
49 #endif
50     float>;
51 
52 using TypesF16F32F64 = ::testing::Types<
53 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
54     Eigen::half,
55 #endif
56 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
57     double,
58 #endif
59     float>;
60 
61 using TypesF16F32F64CF64 = ::testing::Types<
62 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
63     Eigen::half,
64 #endif
65 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
66     double, complex64,
67 #endif
68     float>;
69 
70 // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest,DotOfInputTupleElem)71 XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
72   XlaBuilder builder(TestName());
73 
74   XlaOp param;
75   TF_ASSERT_OK_AND_ASSIGN(
76       auto param_data,
77       CreateParameterAndTransferLiteral(
78           0,
79           LiteralUtil::MakeTupleFromSlices(
80               {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
81                LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
82           "arg0", &builder, &param));
83   auto lhs = GetTupleElement(param, 0);
84   auto rhs = GetTupleElement(param, 1);
85   Dot(lhs, rhs);
86 
87   ComputeAndCompareLiteral(&builder,
88                            LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
89                            {param_data.get()});
90 }
91 
92 template <typename T>
93 class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
94 TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
95 
// Dot of two empty rank-1 constants is compared against a scalar zero.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto empty_lhs = ConstantR1<T>(&builder, {});
  const auto empty_rhs = ConstantR1<T>(&builder, {});
  Dot(empty_lhs, empty_rhs);

  const T expected = static_cast<T>(0.0);
  this->template ComputeAndCompareR0<T>(&builder, expected, {},
                                        this->error_spec_);
}
107 
108 template <typename T>
109 class DotOperationTest_F16F32F64 : public DotOperationTest {};
110 TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64);
111 
// One-row matrix {3, 4} times vector {3, 4}: 3*3 + 4*4 = 25.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto matrix = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}});
  const auto vec = ConstantFromArray<T>(&builder, {3.0f, 4.0f});
  Dot(matrix, vec);

  this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
                                        this->error_spec_);
}
122 
// Dot of the single-element vectors {2} and {3} is the scalar 6.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto two = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
  const auto three = ConstantR1<T>(&builder, {static_cast<T>(3.0f)});
  Dot(two, three);

  const T expected = static_cast<T>(6.0f);
  this->template ComputeAndCompareR0<T>(&builder, expected, {},
                                        this->error_spec_);
}
133 
// Three-element vector dot: 1*11 + 2.5*(-1) + 42*0.5 = 29.5.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto left = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f});
  const auto right = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f});
  Dot(left, right);

  const T expected = static_cast<T>(29.5f);
  this->template ComputeAndCompareR0<T>(&builder, expected, {},
                                        this->error_spec_);
}
144 
MinorToMajorForIsRowMajor(bool row_major)145 std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
146   return {row_major ? 1 : 0, row_major ? 0 : 1};
147 }
148 
// [0x2] . [2x0] is compared against an empty 0x0 array.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto empty_0x2 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  const auto empty_2x0 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(empty_0x2, empty_2x0);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
                                        this->error_spec_);
}
159 
// [0x2] . [2x3] is compared against an empty 0x3 array.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto empty_0x2 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  const auto dense_2x3 = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
  Dot(empty_0x2, dense_2x3);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
                                        this->error_spec_);
}
171 
// [3x2] . [2x0] is compared against an empty 3x0 array.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto dense_3x2 = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
  const auto empty_2x0 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(dense_3x2, empty_2x0);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
                                        this->error_spec_);
}
183 
// [2x0] . [0x2] contracts over an empty dimension; the expected result is a
// 2x2 array filled with zeros.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto empty_2x0 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  const auto empty_0x2 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  Dot(empty_2x0, empty_0x2);

  const Array2D<T> zeros(2, 2, static_cast<T>(0.0f));
  this->template ComputeAndCompareR2<T>(&builder, zeros, {},
                                        this->error_spec_);
}
194 
// Builds Dot(Exp(arg0), arg1) with arg0 : [2,4] and arg1 : [4,1] and compares
// the result against precomputed values.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto lhs_param =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
  const auto rhs_param =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
  const auto lhs_exp = Exp(lhs_param);
  Dot(lhs_exp, rhs_param);

  // Upload the runtime arguments for the two parameters.
  const auto lhs_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
          .ConsumeValueOrDie();
  const auto rhs_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
          .ConsumeValueOrDie();

  // The half-precision instantiation runs with a different ErrorSpec.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }

  const Array2D<T> expected({{296.14560492846033f}, {0.8611737683031964f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected, {lhs_data.get(), rhs_data.get()},
      this->error_spec_);
}
223 
224 template <typename T>
225 class SquareMatrixDot : public DotOperationTest {
226  public:
TestImpl(bool lhs_row_major,bool rhs_row_major)227   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
228     auto lhs_handle =
229         client_
230             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
231                 {{1.0f, 2.0f}, {3.0f, -4.0f}},
232                 LayoutUtil::MakeLayout(
233                     MinorToMajorForIsRowMajor(lhs_row_major))))
234             .ConsumeValueOrDie();
235     auto rhs_handle =
236         client_
237             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
238                 {{1.0f, 6.0f}, {7.0f, -4.0f}},
239                 LayoutUtil::MakeLayout(
240                     MinorToMajorForIsRowMajor(rhs_row_major))))
241             .ConsumeValueOrDie();
242     XlaBuilder builder(TestName());
243     auto prim_type = primitive_util::NativeToPrimitiveType<T>();
244     Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
245         Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
246 
247     Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}});
248     ComputeAndCompareR2<T>(&builder, expected,
249                            {lhs_handle.get(), rhs_handle.get()}, error_spec_);
250   }
251 };
252 
TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64);

// The two-letter suffix spells (lhs_row_major, rhs_row_major) as T/F.
XLA_TYPED_TEST(SquareMatrixDot, TypesFF) {
  this->TestImpl(/*lhs_row_major=*/false, /*rhs_row_major=*/false);
}
XLA_TYPED_TEST(SquareMatrixDot, TypesFT) {
  this->TestImpl(/*lhs_row_major=*/false, /*rhs_row_major=*/true);
}
XLA_TYPED_TEST(SquareMatrixDot, TypesTF) {
  this->TestImpl(/*lhs_row_major=*/true, /*rhs_row_major=*/false);
}
XLA_TYPED_TEST(SquareMatrixDot, TypesTT) {
  this->TestImpl(/*lhs_row_major=*/true, /*rhs_row_major=*/true);
}
258 
// One configuration for the value-parameterized dot tests: an [m x k] by
// [k x n] product, optionally followed by adding an [m x n] addend.
// Field order matters: the parameter generators fill this struct with
// positional brace initialization.
struct DotTestParam {
  int m;                   // Rows of the lhs operand (and of the result).
  int k;                   // Contracted size: lhs columns / rhs rows.
  int n;                   // Columns of the rhs operand (and of the result).
  bool dot_lhs_row_major;  // Layout selector for the lhs operand.
  bool dot_rhs_row_major;  // Layout selector for the rhs operand.
  bool has_addend;         // Whether an addend is added to the dot result.
  bool addend_row_major;   // Layout selector for the addend, when present.
};
268 
PrintDotTestParam(const::testing::TestParamInfo<DotTestParam> & test_param)269 string PrintDotTestParam(
270     const ::testing::TestParamInfo<DotTestParam>& test_param) {
271   const DotTestParam& param = test_param.param;
272   if (param.has_addend) {
273     return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
274                         param.dot_lhs_row_major ? "T" : "F",
275                         param.dot_rhs_row_major ? "T" : "F",
276                         param.addend_row_major ? "T" : "F");
277   } else {
278     return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
279                         param.dot_lhs_row_major ? "T" : "F",
280                         param.dot_rhs_row_major ? "T" : "F");
281   }
282 }
283 
284 class ParametricDotTest : public DotOperationTest,
285                           public ::testing::WithParamInterface<DotTestParam> {
286  protected:
287   template <typename NativeT>
288   void TestImpl();
289 
290   template <typename NativeT>
291   void ComputeAndCompareR2WithError(XlaBuilder* builder,
292                                     const Array2D<NativeT>& expected,
293                                     absl::Span<GlobalData* const> arguments);
294 };
295 
296 template <typename NativeT>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)297 void ParametricDotTest::ComputeAndCompareR2WithError(
298     XlaBuilder* builder, const Array2D<NativeT>& expected,
299     absl::Span<GlobalData* const> arguments) {
300   ErrorSpec error_spec(0.3, 3e-3);
301   ComputeAndCompareR2(builder, expected, arguments, error_spec);
302 }
303 
304 template <>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<Eigen::half> & expected,absl::Span<GlobalData * const> arguments)305 void ParametricDotTest::ComputeAndCompareR2WithError<Eigen::half>(
306     XlaBuilder* builder, const Array2D<Eigen::half>& expected,
307     absl::Span<GlobalData* const> arguments) {
308   ErrorSpec error_spec(0.3, 7e-3);
309   ComputeAndCompareR2(builder, expected, arguments, error_spec);
310 }
311 
312 template <>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<int32> & expected,absl::Span<GlobalData * const> arguments)313 void ParametricDotTest::ComputeAndCompareR2WithError<int32>(
314     XlaBuilder* builder, const Array2D<int32>& expected,
315     absl::Span<GlobalData* const> arguments) {
316   ComputeAndCompareR2(builder, expected, arguments);
317 }
318 
319 template <typename NativeT>
TestImpl()320 void ParametricDotTest::TestImpl() {
321   DotTestParam param = GetParam();
322 
323   std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
324       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
325   Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
326       *dot_lhs_data, LayoutUtil::MakeLayout(
327                          MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
328   std::unique_ptr<GlobalData> dot_lhs_handle =
329       client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
330 
331   std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
332       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
333   Layout rhs_layout = LayoutUtil::MakeLayout(
334       MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
335   Literal dot_rhs_lit =
336       LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
337   std::unique_ptr<GlobalData> dot_rhs_handle =
338       client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
339 
340   std::unique_ptr<Array2D<NativeT>> addend_data;
341   Literal addend_lit;
342   std::unique_ptr<GlobalData> addend_handle;
343 
344   if (param.has_addend) {
345     addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
346     addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
347         *addend_data, LayoutUtil::MakeLayout(
348                           MinorToMajorForIsRowMajor(param.addend_row_major)));
349     addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
350   }
351 
352   XlaBuilder builder(TestName());
353   auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>();
354   auto result =
355       Dot(Parameter(&builder, 0,
356                     ShapeUtil::MakeShapeWithLayout(
357                         prim_type, {param.m, param.k},
358                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
359                     "dot_lhs"),
360           Parameter(&builder, 1,
361                     ShapeUtil::MakeShapeWithLayout(
362                         prim_type, {param.k, param.n},
363                         MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
364                     "dot_rhs"));
365 
366   if (param.has_addend) {
367     result =
368         Add(result,
369             Parameter(&builder, 2,
370                       ShapeUtil::MakeShapeWithLayout(
371                           prim_type, {param.m, param.n},
372                           MinorToMajorForIsRowMajor(param.addend_row_major)),
373                       "addend"));
374   }
375 
376   std::unique_ptr<Array2D<NativeT>> expected;
377   if (param.has_addend) {
378     expected = ReferenceUtil::ApplyElementwise2D(
379         std::plus<NativeT>(),
380         *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
381         *addend_data);
382   } else {
383     expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
384   }
385 
386   std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
387   if (param.has_addend) {
388     args.push_back(addend_handle.get());
389   }
390   ComputeAndCompareR2WithError<NativeT>(&builder, *expected, args);
391 }
392 
CreateDotTestParameters()393 std::vector<DotTestParam> CreateDotTestParameters() {
394   std::vector<DotTestParam> params;
395 
396   auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
397     for (bool lhs_row_major : {true, false}) {
398       for (bool rhs_row_major : {true, false}) {
399         params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
400                           /*dot_lhs_row_major=*/lhs_row_major,
401                           /*dot_rhs_row_major=*/rhs_row_major,
402                           /*has_addend=*/false, /*addend_row_major=*/true});
403       }
404     }
405   };
406 
407   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
408   add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
409   add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
410   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
411   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
412   add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
413   add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
414   add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
415 
416   return params;
417 }
418 
// One parameterized test per element type; each runs TestImpl for every
// DotTestParam produced by CreateDotTestParameters().
// NOTE(review): TestF64 is not wrapped in an
// XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 guard, unlike `double` in the typed
// test lists earlier in this file -- confirm this is intended.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
XLA_TEST_P(ParametricDotTest, TestF16) {
  TestImpl<Eigen::half>();
}
#endif
XLA_TEST_P(ParametricDotTest, TestF32) {
  TestImpl<float>();
}
XLA_TEST_P(ParametricDotTest, TestF64) {
  TestImpl<double>();
}
XLA_TEST_P(ParametricDotTest, TestC64) {
  TestImpl<std::complex<float>>();
}
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128)
XLA_TEST_P(ParametricDotTest, TestC128) {
  TestImpl<std::complex<double>>();
}
#endif
XLA_TEST_P(ParametricDotTest, TestS32) {
  TestImpl<int32>();
}

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);
433 
434 class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
435  public:
ParametricDotTestWithoutLayoutAssignment()436   ParametricDotTestWithoutLayoutAssignment() {
437     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
438         "layout-assignment");
439     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
440         "tiling-assignment");
441     // Disable algebraic simplification because the pass may replace a dot
442     // instruction with a layout-changing multiplication instruction.
443     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
444         "algsimp");
445   }
446 };
447 
CreateNoLayoutAssignmentDotTestParameters()448 std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
449   std::vector<DotTestParam> params;
450 
451   auto add_matrix_vector_dot_test = [&](int k, int n) {
452     for (bool lhs_row_major : {true, false}) {
453       for (bool rhs_row_major : {true, false}) {
454         for (bool has_addend : {true, false}) {
455           // The addend needs to be row major to match the result of the dot.
456           params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
457                             /*dot_lhs_row_major=*/lhs_row_major,
458                             /*dot_rhs_row_major=*/rhs_row_major,
459                             /*has_addend=*/has_addend,
460                             /*addend_row_major=*/true});
461           if (n != 1) {
462             params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
463                               /*dot_lhs_row_major=*/lhs_row_major,
464                               /*dot_rhs_row_major=*/rhs_row_major,
465                               /*has_addend=*/has_addend,
466                               /*addend_row_major=*/true});
467           }
468         }
469       }
470     }
471   };
472 
473   add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
474   add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
475   add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
476   add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
477   add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
478   add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
479   add_matrix_vector_dot_test(/*k=*/1, /*n=*/4);
480   add_matrix_vector_dot_test(/*k=*/1, /*n=*/3);
481   add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
482   add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
483   add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
484   add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
485   add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
486   add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);
487 
488   return params;
489 }
490 
491 #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,TestF16)492 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) {
493   TestImpl<Eigen::half>();
494 }
495 #endif
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,TestF32)496 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) {
497   TestImpl<float>();
498 }
499 // TODO(b/147505663): Disabled for now.
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,DISABLED_TestF64)500 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, DISABLED_TestF64) {
501   TestImpl<double>();
502 }
503 
504 INSTANTIATE_TEST_CASE_P(
505     DotTests, ParametricDotTestWithoutLayoutAssignment,
506     ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()),
507     PrintDotTestParam);
508 
509 template <typename T>
510 class NonsquareMatrixDot : public DotOperationTest {
511  public:
TestImpl(bool lhs_row_major,bool rhs_row_major)512   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
513     auto lhs_handle =
514         client_
515             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
516                 {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
517                 LayoutUtil::MakeLayout(
518                     MinorToMajorForIsRowMajor(lhs_row_major))))
519             .ConsumeValueOrDie();
520     auto rhs_handle =
521         client_
522             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
523                 {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
524                 LayoutUtil::MakeLayout(
525                     MinorToMajorForIsRowMajor(rhs_row_major))))
526             .ConsumeValueOrDie();
527 
528     XlaBuilder builder(TestName());
529     auto prim_type = primitive_util::NativeToPrimitiveType<T>();
530     Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
531         Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
532 
533     Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
534 
535     ComputeAndCompareR2<T>(&builder, expected,
536                            {lhs_handle.get(), rhs_handle.get()}, error_spec_);
537   }
538 };
539 
TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64);

// The two-letter suffix spells (lhs_row_major, rhs_row_major) as T/F.
XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) {
  this->TestImpl(/*lhs_row_major=*/false, /*rhs_row_major=*/false);
}
XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) {
  this->TestImpl(/*lhs_row_major=*/false, /*rhs_row_major=*/true);
}
XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) {
  this->TestImpl(/*lhs_row_major=*/true, /*rhs_row_major=*/false);
}
XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) {
  this->TestImpl(/*lhs_row_major=*/true, /*rhs_row_major=*/true);
}
545 
// complex64 [1x4] . [4x2] product; both argument literals use the {1, 0}
// minor-to-major layout.
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  const auto lhs_data =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();
  const auto rhs_data =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              LayoutUtil::MakeLayout({1, 0})))
          .ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  const auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
  const auto lhs_param =
      Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs");
  const auto rhs_param =
      Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs");
  Dot(lhs_param, rhs_param);

  // Column 0: 1 + 4 + 9 + 16 = 30; column 1: 1 + 4 + 9 - 16 = -2.
  const Array2D<complex64> expected({{30.0, -2.0}});
  ComputeAndCompareR2<complex64>(
      &builder, expected, {lhs_data.get(), rhs_data.get()}, error_spec_);
}
569 
// Two dots that do not depend on each other (A.B and B.A) feed a single Add.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto a =
      ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  const auto b =
      ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}});
  const auto a_times_b = Dot(a, b);
  const auto b_times_a = Dot(b, a);
  Add(a_times_b, b_times_a);

  const Array2D<T> expected({{42.0f, 56.0f}, {74.0f, 96.0f}});
  this->template ComputeAndCompareR2<T>(&builder, expected, {},
                                        this->error_spec_);
}
586 
587 template <typename T>
588 class DotOperationTestForBatchMatMul : public DotOperationTest {};
589 TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64);
590 
591 // Regression test for b/32055648. The root of the graph is a kFusion of 4
592 // bitcasts. Although bitcasts don't map to thunks, the root should still be
593 // sync-dependent on bitcasts' operands.
XLA_TYPED_TEST(DotOperationTestForBatchMatMul,Types)594 XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
595   using T = TypeParam;
596   XlaBuilder builder(this->TestName());
597   auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
598                      "x");
599   auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
600                      "y");
601 
602   auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
603   auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
604 
605   // Slice batches into individual matrices and multiply them.
606   std::vector<XlaOp> out_slices;
607   for (int i = 0; i < 4; ++i) {
608     // Slice off individual matrices and reshape to 2D tensors.
609     auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
610     x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2});
611     auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
612     y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2});
613 
614     auto out = Dot(x_slice, y_slice);
615     out = Reshape(out, {0, 1}, {1, 2, 2});
616     out_slices.push_back(out);
617   }
618   auto out_flat = ConcatInDim(&builder, out_slices, 0);
619   Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
620 
621   auto x_data = this->client_
622                     ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
623                         {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
624                           {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
625                          {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
626                           {{4000.0f, 400.0f}, {40.0f, 4.0f}}}}))
627                     .ConsumeValueOrDie();
628   auto y_data =
629       this->client_
630           ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
631               {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
632                {{{11.0f, 22.0f}, {33.0f, 44.0f}},
633                 {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
634           .ConsumeValueOrDie();
635 
636   if (std::is_same<Eigen::half, T>::value) {
637     this->error_spec_ = ErrorSpec{0.0001, 1e-3};
638   }
639   this->template ComputeAndCompareR4<T>(
640       &builder,
641       /*expected=*/
642       {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}},
643         {{11400.0f, 13600.0f}, {114.0f, 136.0f}}},
644        {{{42900.0f, 79200.0f}, {429.0f, 792.0f}},
645         {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}},
646       {x_data.get(), y_data.get()}, this->error_spec_);
647 }
648 
// DotGeneral on two [2,2,2] operands with batch dimension 0 on both sides,
// contracting lhs dimension 2 with rhs dimension 1. Each batch entry of y is
// the 2x2 identity, so the expected output equals x.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  const auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  DotDimensionNumbers dim_numbers;
  dim_numbers.add_lhs_contracting_dimensions(2);
  dim_numbers.add_rhs_contracting_dimensions(1);
  dim_numbers.add_lhs_batch_dimensions(0);
  dim_numbers.add_rhs_batch_dimensions(0);
  DotGeneral(x, y, dim_numbers);

  const auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .ConsumeValueOrDie();
  const auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
          .ConsumeValueOrDie();

  this->template ComputeAndCompareR3<T>(
      &builder,
      /*expected=*/
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
684 
// DotGeneral with a rank-3 lhs [2,2,2] and a rank-2 rhs [2,2]: batch dimension
// 0 on both sides, contracting dimension 1 on both sides. With y the 2x2
// identity, the expected rows are x[0][0] and x[1][1].
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  const auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  const auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y");

  DotDimensionNumbers dim_numbers;
  dim_numbers.add_lhs_contracting_dimensions(1);
  dim_numbers.add_rhs_contracting_dimensions(1);
  dim_numbers.add_lhs_batch_dimensions(0);
  dim_numbers.add_rhs_batch_dimensions(0);
  DotGeneral(x, y, dim_numbers);

  const auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .ConsumeValueOrDie();
  const auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f, 0.0f}, {0.0f, 1.0f}}))
          .ConsumeValueOrDie();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
717 
// DotGeneral of a rank-2 LHS with a rank-3 RHS. Dimension 0 of both operands
// is a batch dimension and dimension 1 of both operands is contracted, so the
// result is a rank-2 array.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  DotDimensionNumbers dot_dims;
  dot_dims.add_lhs_batch_dimensions(0);
  dot_dims.add_rhs_batch_dimensions(0);
  dot_dims.add_lhs_contracting_dimensions(1);
  dot_dims.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dot_dims);

  const auto lhs_literal =
      LiteralUtil::CreateR2FromArray2D<T>({{1.0f, 0.0f}, {0.0f, 1.0f}});
  auto lhs_data =
      this->client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  const auto rhs_literal = LiteralUtil::CreateR3FromArray3D<T>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
  auto rhs_data =
      this->client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  // With an identity LHS, batch 0 picks rhs[0][0] and batch 1 picks rhs[1][1].
  const Array2D<T> expected({{1.0f, 2.0f}, {7.0f, 8.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected, {lhs_data.get(), rhs_data.get()}, this->error_spec_);
}
750 
// DotGeneral of two rank-4 operands with two batch dimensions (0 and 1).
// LHS dimension 3 is contracted with RHS dimension 2, i.e. an ordinary matrix
// product is computed for every (dim0, dim1) batch entry.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto lhs = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), "x");
  auto rhs = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), "y");

  DotDimensionNumbers dot_dims;
  // The order of the batch dimensions within each operand is significant.
  dot_dims.add_lhs_batch_dimensions(0);
  dot_dims.add_lhs_batch_dimensions(1);
  dot_dims.add_rhs_batch_dimensions(0);
  dot_dims.add_rhs_batch_dimensions(1);
  dot_dims.add_lhs_contracting_dimensions(3);
  dot_dims.add_rhs_contracting_dimensions(2);
  DotGeneral(lhs, rhs, dot_dims);

  const auto lhs_literal = LiteralUtil::CreateR4FromArray4D<T>(
      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
       {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}});
  auto lhs_data =
      this->client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  // Identity matrices for dim0 == 0, column-swapping matrices for dim0 == 1.
  const auto rhs_literal = LiteralUtil::CreateR4FromArray4D<T>(
      {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
       {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}});
  auto rhs_data =
      this->client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
       {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
      {lhs_data.get(), rhs_data.get()}, this->error_spec_);
}
792 
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,TransposeFolding)793 XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
794   using T = TypeParam;
795   for (bool transpose_lhs : {false, true}) {
796     for (bool transpose_rhs : {false, true}) {
797       for (bool row_major : {false, true}) {
798         std::unique_ptr<Array2D<T>> lhs(
799             new Array2D<T>({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}));
800         std::unique_ptr<Array2D<T>> rhs(
801             new Array2D<T>({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}));
802 
803         if (transpose_lhs) {
804           lhs = ReferenceUtil::TransposeArray2D(*lhs);
805         }
806         if (transpose_rhs) {
807           rhs = ReferenceUtil::TransposeArray2D(*rhs);
808         }
809         auto lhs_handle =
810             this->client_
811                 ->TransferToServer(
812                     LiteralUtil::CreateR2FromArray2DWithLayout<T>(
813                         *lhs, LayoutUtil::MakeLayout(
814                                   MinorToMajorForIsRowMajor(row_major))))
815                 .ConsumeValueOrDie();
816         auto rhs_handle =
817             this->client_
818                 ->TransferToServer(
819                     LiteralUtil::CreateR2FromArray2DWithLayout<T>(
820                         *rhs, LayoutUtil::MakeLayout(
821                                   MinorToMajorForIsRowMajor(row_major))))
822                 .ConsumeValueOrDie();
823 
824         XlaBuilder builder(this->TestName());
825         auto prim_type = primitive_util::NativeToPrimitiveType<T>();
826         auto lhs_arg = Parameter(
827             &builder, 0,
828             ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
829             "lhs");
830         auto rhs_arg = Parameter(
831             &builder, 1,
832             ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
833             "rhs");
834         if (transpose_lhs) {
835           lhs_arg = Transpose(lhs_arg, {1, 0});
836         }
837         if (transpose_rhs) {
838           rhs_arg = Transpose(rhs_arg, {1, 0});
839         }
840         Dot(lhs_arg, rhs_arg);
841 
842         Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
843         VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
844                 << transpose_rhs << " " << row_major;
845         this->template ComputeAndCompareR2<T>(
846             &builder, expected, {lhs_handle.get(), rhs_handle.get()},
847             this->error_spec_);
848       }
849     }
850   }
851 }
852 
// Dot of a constant 2x6 LHS with an RHS that is the concatenation along
// dimension 0 of three parameters of shapes 2x2, 3x2 and 1x2.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstLHS) {
  using T = TypeParam;

  // The arrays are only read locally, so they live on the stack instead of
  // being heap-allocated behind a std::unique_ptr.
  const Array2D<T> constant_lhs_array({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                                       {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "rhs_arg_0");
  auto rhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({3, 2}), "rhs_arg_1");
  auto rhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShapeWithType<T>({1, 2}), "rhs_arg_2");
  Dot(lhs_constant,
      ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  const Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  const Array2D<T> arg_1_value_array(
      {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
  const Array2D<T> arg_2_value_array({{1.0f, 2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  const Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
898 
// Dot of an LHS that is the concatenation along dimension 1 of three
// parameters of shapes 2x2, 2x3 and 2x1 with a constant 6x2 RHS.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstRHS) {
  using T = TypeParam;

  // The arrays are only read locally, so they live on the stack instead of
  // being heap-allocated behind a std::unique_ptr.
  const Array2D<T> constant_rhs_array({{1.0f, 2.0f},
                                       {3.0f, 4.0f},
                                       {5.0f, 6.0f},
                                       {6.0f, 5.0f},
                                       {4.0f, 3.0f},
                                       {2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto lhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0");
  auto lhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1");
  auto lhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2");
  Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
      rhs_constant);

  const Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  const Array2D<T> arg_1_value_array(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  const Array2D<T> arg_2_value_array({{1.0f}, {2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  const Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
947 
// Dot of a DynamicSlice of the constant 2x6 LHS (start {1, 0}, size {1, 6})
// with the constant 6x3 RHS, contracting LHS dimension 1 with RHS dimension 0.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto one = ConstantR0<int32>(&builder, 1);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  // Row 1 of the full dot result.
  const Array2D<float> expected({{96.0, 105.0, 114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
975 
// Dot of the constant 2x6 LHS with a DynamicSlice of the constant 6x3 RHS
// (start {0, 1}, size {6, 1}), contracting LHS dimension 1 with RHS
// dimension 0.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  // Column 1 of the full dot result.
  const Array2D<float> expected({{105.0}, {105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1003 
// Dot of a DynamicSlice of the constant 6x3 LHS (start {0, 1}, size {6, 1})
// with the constant 2x6 RHS, contracting LHS dimension 0 with RHS dimension 1.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSReverseMM) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  // Row 1 of the full dot result.
  const Array2D<float> expected({{105.0, 105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1033 
// Dot of the constant 6x3 LHS with a DynamicSlice of the constant 2x6 RHS
// (start {1, 0}, size {1, 6}), contracting LHS dimension 0 with RHS
// dimension 1.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  // Column 1 of the full dot result.
  const Array2D<float> expected({{96.0}, {105.0}, {114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1061 
// Dot of a DynamicSlice of the constant 6x2 LHS (start {0, 1}, size {6, 1})
// with the constant 6x3 RHS, contracting dimension 0 of both operands.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array({{1.0, 2.0},
                                           {3.0, 4.0},
                                           {5.0, 6.0},
                                           {6.0, 5.0},
                                           {4.0, 3.0},
                                           {2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  // Row 1 of the full dot result.
  const Array2D<float> expected({{126.0, 129.0, 132.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1094 
// Dot of the constant 6x2 LHS with a DynamicSlice of the constant 6x3 RHS
// (start {0, 1}, size {6, 1}), contracting dimension 0 of both operands.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array({{1.0, 2.0},
                                           {3.0, 4.0},
                                           {5.0, 6.0},
                                           {6.0, 5.0},
                                           {4.0, 3.0},
                                           {2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                           {4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0},
                                           {9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0},
                                           {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  // Column 1 of the full dot result.
  const Array2D<float> expected({{129.0}, {129.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1127 
// Dot of a DynamicSlice of the constant 2x6 LHS (start {1, 0}, size {1, 6})
// with the constant 3x6 RHS, contracting dimension 1 of both operands.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  // Row 1 of the full dot result.
  const Array2D<float> expected({{56.0, 168.0, 91.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1152 
// Dot of the constant 2x6 LHS with a DynamicSlice of the constant 3x6 RHS
// (start {1, 0}, size {1, 6}), contracting dimension 1 of both operands.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
  // Plain stack arrays: nothing here needs heap allocation or ownership
  // transfer.
  const Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  const Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                                           {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                                           {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  // Column 1 of the full dot result.
  const Array2D<float> expected({{168.0}, {168.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1177 
// DotGeneral of two constant 2x2 matrices that contracts dimension 0 of both
// operands, i.e. result[i][j] = sum_k lhs[k][i] * rhs[k][j].
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  XlaBuilder builder(TestName());

  Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);

  Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);

  // The unused local `Shape shape` that used to be declared here was removed.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, rhs_constant, dot_dnums);

  Array2D<float> expected({
      {26.f, 30.f},
      {38.f, 44.f},
  });

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1200 
1201 using EinsumParamType =
1202     std::tuple<std::vector<int64>, std::vector<int64>, string>;
1203 class EinsumTest : public DotOperationTest,
1204                    public ::testing::WithParamInterface<EinsumParamType> {};
XLA_TEST_P(EinsumTest,SimpleEinsumTest)1205 XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
1206   XlaBuilder builder(TestName());
1207   auto x = AddParam(
1208       MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
1209           .ValueOrDie(),
1210       &builder);
1211   auto y = AddParam(
1212       MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
1213           .ValueOrDie(),
1214       &builder);
1215   auto config = std::get<2>(GetParam());
1216   if (config.find(',') == config.npos) {
1217     Einsum(x, config);
1218   } else {
1219     Einsum(x, y, config);
1220   }
1221   ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
1222 }
1223 
GetEinsumTestCases()1224 std::vector<EinsumParamType> GetEinsumTestCases() {
1225   using v = std::vector<int64>;
1226   using p = EinsumParamType;
1227   std::vector<p> test_cases = {
1228       p{v{5, 6}, v{6, 7}, "mk,kn->mn"},
1229       p{v{5, 6}, v{6, 7}, "mk,kn->nm"},
1230       p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"},
1231       p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"},
1232       p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"},
1233       p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
1234       p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
1235       p{v{6}, v{6, 7}, "b,bc->c"},
1236       p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
1237       p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
1238       p{v{77}, v{77}, "a,a->a"},
1239       p{v{77}, v{77, 55}, "a,ab->ba"},
1240       p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
1241       p{v{55}, v{}, "a,->a"},
1242       p{v{11, 111}, v{11}, "ab,a->ab"},
1243       p{v{16, 34}, v{16, 34}, "ab,ab->ab"},
1244       p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"},
1245       p{v{5, 19}, v{}, "ab,->ab"},
1246       p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"},
1247       p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
1248       p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
1249       p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
1250       p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
1251       p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
1252       p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
1253       p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
1254       p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
1255       p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
1256       p{v{5, 6}, v{6, 7}, "mk,kn"},
1257       p{v{5, 6}, v{6, 7}, "mk,kn"},
1258       p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
1259       p{v{5, 6}, v{6, 7}, "ab,cd"},
1260       p{v{6}, v{6, 7}, "b,bc"},
1261       p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
1262       p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
1263       p{v{77}, v{77}, "a,a"},
1264       p{v{77}, v{77, 55}, "a,ab"},
1265       p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
1266       p{v{55}, v{}, "a"},
1267       p{v{11, 111}, v{11}, "ab,a"},
1268       p{v{16, 34}, v{16, 34}, "ab,ab"},
1269       p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
1270       p{v{5, 19}, v{}, "ab"},
1271       p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
1272       p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
1273       p{v{5, 6}, v{}, "...mk"},
1274       p{v{5, 6, 12, 13}, v{}, "...mk"},
1275       p{v{5, 6, 12, 13}, v{}, "m...k"},
1276       p{v{5, 6, 12, 13}, v{}, "mk..."},
1277       p{v{5, 6}, v{6, 7}, "...mk->km..."},
1278       p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
1279       p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
1280       p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
1281       p{v{16, 16, 16}, v{}, "iii"},
1282       p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
1283   };
1284   return test_cases;
1285 }
1286 
1287 INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest,
1288                          ::testing::ValuesIn(GetEinsumTestCases()));
1289 
1290 using BatchDotParamType =
1291     std::tuple<std::vector<int64>, std::vector<int64>, std::vector<int64>>;
1292 class BatchDotTest : public DotOperationTest,
1293                      public ::testing::WithParamInterface<BatchDotParamType> {};
XLA_TEST_P(BatchDotTest,BroadcastingBatchDotTest)1294 XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) {
1295   XlaBuilder builder(TestName());
1296   auto x = AddParam(
1297       MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
1298           .ValueOrDie(),
1299       &builder);
1300   auto y = AddParam(
1301       MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
1302           .ValueOrDie(),
1303       &builder);
1304   auto batch_dot = BatchDot(x, y);
1305   auto output_shape = builder.GetShape(batch_dot).ValueOrDie();
1306   EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam()));
1307   ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
1308 }
1309 
GetBatchDotTestCases()1310 std::vector<BatchDotParamType> GetBatchDotTestCases() {
1311   using v = std::vector<int64>;
1312   using p = BatchDotParamType;
1313   std::vector<p> test_cases = {
1314       p{v{5, 6}, v{6, 7}, v{5, 7}},
1315       p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1316       p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}},
1317       p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}},
1318       p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1319       p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1320       p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}},
1321       p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}},
1322   };
1323   return test_cases;
1324 }
1325 
1326 INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest,
1327                          ::testing::ValuesIn(GetBatchDotTestCases()));
1328 
1329 class DotOperationTextTest : public HloTestBase {};
1330 
XLA_TEST_F(DotOperationTextTest,DotReorderedDotDims)1331 XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
1332   absl::string_view hlo_string =
1333       R"(
1334 HloModule ComplexDotMultipleNonContracting
1335 
1336 ENTRY %test {
1337   %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
1338   %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
1339   ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
1340 }
1341 )";
1342 
1343   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
1344 }
1345 
// Dot with reordered batch dimensions and two contracting dimensions per
// operand (LHS {1,4}, RHS {5,3}).
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  const absl::string_view hlo_text =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  const ErrorSpec tolerance{1e-3, 1e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1360 
// Dot instruction written without any batch or contracting dimension
// attributes; the result shape combines all dimensions of both operands.
XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) {
  const absl::string_view hlo_text =
      R"(
HloModule DotWithNoDnums

ENTRY %test {
  %lhs = f32[2,3]{1,0} parameter(0)
  %rhs = f32[4,5]{1,0} parameter(1)
  ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs)
}
)";

  const ErrorSpec tolerance{1e-3, 1e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1375 
// Dot of a rank-3 LHS with a rank-3 RHS that contracts LHS dimension 2 with
// RHS dimension 0 and has no batch dimensions.
XLA_TEST_F(DotOperationTextTest, Einsum) {
  const absl::string_view hlo_text =
      R"(
HloModule Einsum

ENTRY %test {
  %lhs = f32[8,64,96]{2,1,0} parameter(0)
  %rhs = f32[96,32,4]{2,1,0} parameter(1)
  ROOT %dot = f32[8,64,32,4]{3,2,1,0}  dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1390 
// Tests for a caching bug in the XLA CPU backend. The module contains two dots
// that share the same LHS and produce the same result shape but contract
// different RHS dimensions (0 for rhs_0, 1 for rhs_1).
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
  const absl::string_view hlo_text =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(2)
  rhs_1 = f32[1,40] parameter(1)

  dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  ROOT result = f32[20,1] divide(dot_0, dot_1)
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1411 
// Tests for a caching bug in the XLA CPU backend. The module contains two dots
// over four distinct parameters; dot_0 contracts RHS dimension 0 and dot_1
// contracts RHS dimension 1, and both results are reshaped to f32[20] before
// being divided.
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
  const absl::string_view hlo_text =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs_0 = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(1)
  lhs_1 = f32[1,40] parameter(2)
  rhs_1 = f32[20,40] parameter(3)

  dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  dot_0_reshaped = f32[20] reshape(dot_0)
  dot_1_reshaped = f32[20] reshape(dot_1)

  ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1436 
// Batched s32 dot over iota-generated operands (batch dim 0, contracting
// lhs dim 2 with rhs dim 1), checked with a zero ErrorSpec.
XLA_TEST_F(DotOperationTextTest, S32IotaDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[5,55,8] iota(), iota_dimension=1
  arg1 = s32[5,8,200] iota(), iota_dimension=2
  ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1451 
// s32 dot whose operands are iotas multiplied by themselves twice (i.e. each
// element raised to the fourth power), checked with a zero ErrorSpec.
XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[16,2] iota(), iota_dimension=0
  a = s32[16,2] multiply(arg0, arg0)
  r = s32[16,2] multiply(a, a)
  arg1 = s32[2,98] iota(), iota_dimension=1
  b = s32[2,98] multiply(arg1, arg1)
  s = s32[2,98] multiply(b, b)
  ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1470 
// Batched u16 dot whose result is converted to s32 at the root.
// NOTE(review): despite the test name, the operands here are parameters, not
// iota instructions.
XLA_TEST_F(DotOperationTextTest, U16IotaDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[5,55,8] parameter(0)
  arg1 = u16[5,8,200] parameter(1)
  dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  ROOT c = s32[5,55,200] convert(dot)
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1486 
// u16 dot whose operands are iotas multiplied by themselves twice, checked
// with a zero ErrorSpec.
XLA_TEST_F(DotOperationTextTest, U16IotaSquaredDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[16,2] iota(), iota_dimension=0
  a = u16[16,2] multiply(arg0, arg0)
  r = u16[16,2] multiply(a, a)
  arg1 = u16[2,98] iota(), iota_dimension=1
  b = u16[2,98] multiply(arg1, arg1)
  s = u16[2,98] multiply(b, b)
  ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1505 
// Batched s16 dot over iota-generated operands (batch dim 0, contracting
// lhs dim 2 with rhs dim 1), checked with a zero ErrorSpec.
XLA_TEST_F(DotOperationTextTest, S16IotaDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[5,55,8] iota(), iota_dimension=1
  arg1 = s16[5,8,200] iota(), iota_dimension=2
  ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1520 
// s16 dot whose operands are iotas multiplied by themselves twice, checked
// with a zero ErrorSpec.
XLA_TEST_F(DotOperationTextTest, S16IotaSquaredDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[16,2] iota(), iota_dimension=0
  a = s16[16,2] multiply(arg0, arg0)
  r = s16[16,2] multiply(a, a)
  arg1 = s16[2,98] iota(), iota_dimension=1
  b = s16[2,98] multiply(arg1, arg1)
  s = s16[2,98] multiply(b, b)
  ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1539 
// Matrix-matrix dot on pred operands: pred[20,2] x pred[2,20] -> pred[20,20].
XLA_TEST_F(DotOperationTextTest, PREDDot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = pred[20,2] parameter(0)
  arg1 = pred[2,20] parameter(1)
  ROOT dot = pred[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1554 
// Matrix-matrix dot on s8 operands: s8[20,2] x s8[2,20] -> s8[20,20].
XLA_TEST_F(DotOperationTextTest, S8Dot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s8[20,2] parameter(0)
  arg1 = s8[2,20] parameter(1)
  ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1569 
// Matrix-matrix dot on s32 operands: s32[20,55] x s32[55,20] -> s32[20,20].
XLA_TEST_F(DotOperationTextTest, S32Dot) {
  const absl::string_view hlo_text = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[20,55] parameter(0)
  arg1 = s32[55,20] parameter(1)
  ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  const ErrorSpec exact{0, 0};
  EXPECT_TRUE(RunAndCompare(hlo_text, exact));
}
1584 
// Dot contracting dimension 0 of both operands, followed by a transpose of
// the f32[32,64] dot result into f32[64,32].
XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) {
  const absl::string_view hlo_text = R"(
HloModule TransposeOutput

ENTRY TransposeOutput {
  p0 = f32[32,32] parameter(0)
  p1 = f32[32,64] parameter(1)
  dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT tr = f32[64,32] transpose(dot), dimensions={1,0}
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1600 
// c64 matrix-vector style dot (c64[5,5] x c64[5,1]) with an addend applied to
// the result. The module is parsed with ParseAndReturnUnverifiedModule before
// being handed to RunAndCompare.
XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) {
  const absl::string_view hlo_text = R"(
HloModule MatrixVectorComplex

ENTRY MatrixVectorComplex {
  p0 = c64[5,5] parameter(0)
  p1 = c64[5,1] parameter(1)
  p2 = c64[5,1] parameter(2)
  dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT add = c64[5,1] add(dot, p2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(std::move(parsed_module), tolerance));
}
1619 
// bf16 vector-matrix dot (bf16[128] x bf16[128,256]) with an addend applied to
// the bf16[256] result.
XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
  const absl::string_view hlo_text = R"(
HloModule MatrixVectorBF16

ENTRY MatrixVectorBF16 {
  p0 = bf16[128] parameter(0)
  p1 = bf16[128,256] parameter(1)
  p2 = bf16[256] parameter(2)
  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT add = bf16[256] add(dot, p2)
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1636 
1637 // Regression test for b/138155357, where we were incorrectly creating a dot-add
1638 // fusion where the dot had a batch dimension.  This isn't supported on the CPU
1639 // backend.
XLA_TEST_F(DotOperationTextTest,FusedBatchDotRegressionTest)1640 XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) {
1641   absl::string_view module_string = R"(
1642 HloModule jaxpr_computation__5.33
1643 
1644 jaxpr_computation__6.8 {
1645   tuple.9 = () tuple()
1646   parameter.14 = () parameter(4)
1647   parameter.13 = (f32[2]{0}) parameter(3)
1648   get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0
1649   reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15)
1650   parameter.10 = f32[2,2]{1,0} parameter(0)
1651   reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15)
1652   dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}
1653   reshape.19 = f32[2]{0} reshape(dot.18)
1654   reshape.20 = f32[2,1]{1,0} reshape(reshape.19)
1655   dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0}
1656   reshape.22 = f32[] reshape(dot.21)
1657   parameter.11 = f32[2,1,2]{2,1,0} parameter(1)
1658   broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2}
1659   dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
1660   broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2}
1661   parameter.12 = f32[2,2,1]{2,1,0} parameter(2)
1662   dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
1663   add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26)
1664   reshape.28 = f32[2]{0} reshape(add.27)
1665   ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28)
1666 }
1667 
1668 ENTRY jaxpr_computation__5.33 {
1669   constant.2 = f32[] constant(1)
1670   broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={}
1671   constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } })
1672   constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } })
1673   parameter.6 = f32[2]{0} parameter(0)
1674   tuple.7 = (f32[2]{0}) tuple(parameter.6)
1675   tuple.1 = () tuple()
1676   call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8
1677   get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0
1678   ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1
1679 })";
1680   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1681                           ParseAndReturnVerifiedModule(module_string));
1682   EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/absl::nullopt));
1683 }
1684 
// Dot of a constant [2,6] LHS with an RHS built as
// reshape(transpose(param[2,3,2], {1,0,2}), {6,2}).
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
  Array3D<float> param_values(2, 3, 2);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 6);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  const XlaOp param = AddParam(
      LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  const XlaOp transposed = Transpose(param, {1, 0, 2});
  const XlaOp rhs = Reshape(transposed, {6, 2});
  const XlaOp lhs = ConstantR2FromArray2D(&builder, constant_values);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1701 
// DotGeneral of an LHS built as reshape(transpose(param[2,3,2], {1,0,2}),
// {6,2}) with a constant [2,6] RHS, contracting LHS dimension 0 with RHS
// dimension 1.
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
  Array3D<float> param_values(2, 3, 2);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 6);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  const XlaOp param = AddParam(
      LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  const XlaOp transposed = Transpose(param, {1, 0, 2});
  const XlaOp lhs = Reshape(transposed, {6, 2});
  const XlaOp rhs = ConstantR2FromArray2D(&builder, constant_values);

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dot_dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1722 
// Dot of an LHS built as reshape(transpose(param[2,2,3,4], {0,2,3,1}),
// {2,24}) with a constant [24,2] RHS.
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
  Array4D<float> param_values(2, 2, 3, 4);
  param_values.FillIota(0);
  Array2D<float> constant_values(24, 2);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  const XlaOp param = AddParam(
      LiteralUtil::CreateR4FromArray4D<float>(param_values), &builder);
  const XlaOp transposed = Transpose(param, {0, 2, 3, 1});
  const XlaOp lhs = Reshape(transposed, {2, 24});
  const XlaOp rhs = ConstantR2FromArray2D(&builder, constant_values);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1739 
// Batched DotGeneral (batch dim 0, contracting dim 1 on both sides) of an LHS
// built by reshape -> transpose -> reshape of a [2,6,2] parameter with a
// constant [2,6,3] RHS.
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
  Array3D<float> param_values(2, 6, 2);
  param_values.FillIota(0);
  Array3D<float> constant_values(2, 6, 3);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  const XlaOp param = AddParam(
      LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  const XlaOp split = Reshape(param, {2, 2, 3, 2});
  const XlaOp swapped = Transpose(split, {0, 2, 1, 3});
  const XlaOp lhs = Reshape(swapped, {2, 6, 2});
  const XlaOp rhs = ConstantR3FromArray3D(&builder, constant_values);

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dot_dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1763 
// DotGeneral (contracting dim 1 on both sides) of an LHS built by two rounds
// of transpose + reshape of a [2,2,3,5] parameter with a constant [2,30] RHS.
XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
  Array4D<float> param_values(2, 2, 3, 5);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 30);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  const XlaOp param = AddParam(
      LiteralUtil::CreateR4FromArray4D<float>(param_values), &builder);
  const XlaOp first_transpose = Transpose(param, {0, 2, 1, 3});
  const XlaOp first_reshape = Reshape(first_transpose, {2, 6, 5});
  const XlaOp second_transpose = Transpose(first_reshape, {0, 2, 1});
  const XlaOp lhs = Reshape(second_transpose, {2, 30});
  const XlaOp rhs = ConstantR2FromArray2D(&builder, constant_values);

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dot_dnums);

  // Constant folding is disabled by default in unit tests. The algsimp
  // optimization can be applied multiple times if we fold the transpose and
  // reshape that are moved to the constant side of the dot, so clear the list
  // of disabled HLO passes before running.
  mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompare(&builder, {}, error_spec_);
}
1790 
// Dot with mixed integral operand types (s8 LHS, s16 RHS) and a wider s32
// result type.
// NOTE(review): the ENTRY computation is named MatrixVectorComplex, which
// looks carried over from another test; it is left as is.
XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
  const absl::string_view hlo_text = R"(
HloModule WiderIntegralAccumulation

ENTRY MatrixVectorComplex {
  p0 = s8[5,5]{1,0} parameter(0)
  p1 = s16[5,1]{0,1} parameter(1)
  ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                        rhs_contracting_dims={0}
}
)";

  const ErrorSpec tolerance{4e-3, 4e-3};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
1806 
1807 // This benchmark is to show the performance impact of the following
1808 // transformation:
1809 //   dot(reshape(transpose(A)), Const) ==>
1810 //   dot(reshape(A), reshape(transpose(reshape(Const)))),
1811 // and then fold the reshape and transpose on the Const side.
1812 // We can compare performance with and without algsimp pass to see the impact.
DOT_ReorderContracting(::testing::benchmark::State & state)1813 void DOT_ReorderContracting(::testing::benchmark::State& state) {
1814   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1815   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1816   se::StreamExecutorMemoryAllocator allocator(platform, executors);
1817 
1818   xla::LocalClientOptions client_options;
1819   client_options.set_platform(platform);
1820   auto client =
1821       ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
1822 
1823   int device_ordinal = client->default_device_ordinal();
1824 
1825   const int64 d0 = 128;
1826   const int64 d1 = 128;
1827   const int64 d2 = 128;
1828   const int64 d3 = 128;
1829 
1830   Array3D<float> input_arr(d0, d1, d2);
1831   Array2D<float> const_arr(d1 * d2, d3);
1832   input_arr.FillIota(0);
1833   const_arr.FillIota(0);
1834   XlaBuilder builder("ReorderContracting");
1835   auto t0 =
1836       Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
1837   auto t1 = Transpose(t0, {0, 2, 1});
1838   auto lhs = Reshape(t1, {d0, d2 * d1});
1839   auto rhs = ConstantR2FromArray2D(&builder, const_arr);
1840   Dot(lhs, rhs);
1841   auto computation = builder.Build().ConsumeValueOrDie();
1842 
1843   auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
1844   ScopedShapedBuffer buffer0 =
1845       client->LiteralToShapedBuffer(input_literal, device_ordinal)
1846           .ConsumeValueOrDie();
1847 
1848   TF_ASSERT_OK_AND_ASSIGN(
1849       auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
1850                                         ExecutableBuildOptions()));
1851   auto executable = std::move(executables[0]);
1852 
1853   se::Stream stream(executors[device_ordinal]);
1854   stream.Init();
1855 
1856   ExecutableRunOptions options;
1857   options.set_allocator(&allocator);
1858 
1859   const int kWarmups = 2;
1860   for (int i = 0; i < kWarmups; ++i) {
1861     ASSERT_IS_OK(executable->Run({&buffer0}, options));
1862   }
1863 
1864   const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
1865   for (auto s : state) {
1866     ASSERT_IS_OK(executable->Run({&buffer0}, options));
1867   }
1868   state.SetBytesProcessed(state.iterations() * total_bytes * sizeof(float));
1869 }
1870 
1871 BENCHMARK(DOT_ReorderContracting)->UseRealTime();
1872 
1873 }  // namespace
1874 }  // namespace xla
1875