1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array3d.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/reference_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace xla {
39 namespace {
40
// Common fixture for all dot-operation tests below; holds the default
// floating-point comparison tolerance.
class DotOperationTest : public ClientLibraryTestBase {
 public:
  // Default {absolute, relative} error tolerance; individual tests widen it
  // for low-precision element types.
  ErrorSpec error_spec_{0.0001, 1e-5};
};
45
// Element-type lists for the typed test suites. Each list drops a type at
// compile time when the target backend advertises that it is unsupported.
using TypesF16F32 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
    float>;

using TypesF16F32F64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
    double,
#endif
    float>;

using TypesF16F32F64CF64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
    double, complex64,
#endif
    float>;
69
70 // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest,DotOfInputTupleElem)71 XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
72 XlaBuilder builder(TestName());
73
74 XlaOp param;
75 TF_ASSERT_OK_AND_ASSIGN(
76 auto param_data,
77 CreateParameterAndTransferLiteral(
78 0,
79 LiteralUtil::MakeTupleFromSlices(
80 {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
81 LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
82 "arg0", &builder, ¶m));
83 auto lhs = GetTupleElement(param, 0);
84 auto rhs = GetTupleElement(param, 1);
85 Dot(lhs, rhs);
86
87 ComputeAndCompareLiteral(&builder,
88 LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
89 {param_data.get()});
90 }
91
// Typed fixture instantiated for half, float, double and complex64 (subject
// to backend support).
template <typename T>
class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
95
// Dot of two zero-element vectors reduces over nothing and yields a zero
// scalar.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  auto empty_lhs = ConstantR1<T>(&builder, {});
  auto empty_rhs = ConstantR1<T>(&builder, {});
  Dot(empty_lhs, empty_rhs);

  const T expected = static_cast<T>(0.0);
  this->template ComputeAndCompareR0<T>(&builder, expected, {},
                                        this->error_spec_);
}
107
// Typed fixture instantiated for half, float and double (subject to backend
// support).
template <typename T>
class DotOperationTest_F16F32F64 : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64);
111
// 1x2 matrix times a 2-vector: a single dot product, 3*3 + 4*4 == 25.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto matrix = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}});
  auto vector = ConstantFromArray<T>(&builder, {3.0f, 4.0f});
  Dot(matrix, vector);

  this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
                                        this->error_spec_);
}
122
// Single-element vectors: the dot is just the product 2 * 3 == 6.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto a = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
  auto b = ConstantR1<T>(&builder, {static_cast<T>(3.0f)});
  Dot(a, b);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {},
                                        this->error_spec_);
}
133
// 3-vector dot product: 1*11 + 2.5*(-1) + 42*0.5 == 29.5.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto u = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f});
  auto v = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f});
  Dot(u, v);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {},
                                        this->error_spec_);
}
144
MinorToMajorForIsRowMajor(bool row_major)145 std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
146 return {row_major ? 1 : 0, row_major ? 0 : 1};
147 }
148
// Degenerate shapes: (0x2) . (2x0) produces an empty 0x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto zero_rows = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto zero_cols = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(zero_rows, zero_cols);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
                                        this->error_spec_);
}
159
// Empty lhs: (0x2) . (2x3) produces an empty 0x3 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto empty_lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto full_rhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
  Dot(empty_lhs, full_rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
                                        this->error_spec_);
}
171
// Empty rhs: (3x2) . (2x0) produces an empty 3x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto full_lhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
  auto empty_rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(full_lhs, empty_rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
                                        this->error_spec_);
}
183
// (2x0) . (0x2): contracting over zero elements yields a 2x2 of zeros.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs_2x0 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  auto rhs_0x2 = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  Dot(lhs_2x0, rhs_0x2);

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
}
194
// Dot whose lhs flows through an elementwise Exp, so a backend may fuse the
// exponent into the dot operand.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto param0 =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
  auto param1 =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
  Dot(Exp(param0), param1);

  auto arg0_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
          .ConsumeValueOrDie();
  auto arg1_data = this->client_
                       ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                           {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
                       .ConsumeValueOrDie();

  // Half precision cannot hit the default tolerance on these magnitudes.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>({{296.14560492846033f}, {0.8611737683031964f}}),
      {arg0_data.get(), arg1_data.get()}, this->error_spec_);
}
223
// Typed fixture that multiplies two constant 2x2 matrices, exercising every
// combination of row-/column-major operand layouts.
template <typename T>
class SquareMatrixDot : public DotOperationTest {
 public:
  // Transfers the two operands with the requested minor-to-major layouts and
  // runs a 2x2 . 2x2 Dot against a hand-computed expected result.
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
    auto lhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 2.0f}, {3.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(lhs_row_major))))
            .ConsumeValueOrDie();
    auto rhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 6.0f}, {7.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(rhs_row_major))))
            .ConsumeValueOrDie();
    XlaBuilder builder(TestName());
    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
    Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
        Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));

    Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}});
    ComputeAndCompareR2<T>(&builder, expected,
                           {lhs_handle.get(), rhs_handle.get()}, error_spec_);
  }
};
252
// One test per (lhs_row_major, rhs_row_major) combination; T/F in the test
// name encode the flags passed to TestImpl.
TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64);
XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); }
258
// Parameters for one ParametricDotTest case: an (m x k) lhs times a (k x n)
// rhs, with per-operand layouts and an optional (m x n) addend.
struct DotTestParam {
  int m;
  int k;
  int n;
  bool dot_lhs_row_major;  // layout of the lhs operand
  bool dot_rhs_row_major;  // layout of the rhs operand
  bool has_addend;         // if true, the dot result is summed with an addend
  bool addend_row_major;   // layout of the addend (meaningful only if present)
};
268
PrintDotTestParam(const::testing::TestParamInfo<DotTestParam> & test_param)269 string PrintDotTestParam(
270 const ::testing::TestParamInfo<DotTestParam>& test_param) {
271 const DotTestParam& param = test_param.param;
272 if (param.has_addend) {
273 return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
274 param.dot_lhs_row_major ? "T" : "F",
275 param.dot_rhs_row_major ? "T" : "F",
276 param.addend_row_major ? "T" : "F");
277 } else {
278 return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
279 param.dot_lhs_row_major ? "T" : "F",
280 param.dot_rhs_row_major ? "T" : "F");
281 }
282 }
283
// Value-parameterized dot test driven by DotTestParam (shapes and layouts).
class ParametricDotTest : public DotOperationTest,
                          public ::testing::WithParamInterface<DotTestParam> {
 protected:
  // Builds and runs the dot described by GetParam() for element type NativeT.
  template <typename NativeT>
  void TestImpl();

  // Compares with a type-appropriate tolerance; specialized below for
  // Eigen::half (looser) and int32 (exact).
  template <typename NativeT>
  void ComputeAndCompareR2WithError(XlaBuilder* builder,
                                    const Array2D<NativeT>& expected,
                                    absl::Span<GlobalData* const> arguments);
};
295
296 template <typename NativeT>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)297 void ParametricDotTest::ComputeAndCompareR2WithError(
298 XlaBuilder* builder, const Array2D<NativeT>& expected,
299 absl::Span<GlobalData* const> arguments) {
300 ErrorSpec error_spec(0.3, 3e-3);
301 ComputeAndCompareR2(builder, expected, arguments, error_spec);
302 }
303
304 template <>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<Eigen::half> & expected,absl::Span<GlobalData * const> arguments)305 void ParametricDotTest::ComputeAndCompareR2WithError<Eigen::half>(
306 XlaBuilder* builder, const Array2D<Eigen::half>& expected,
307 absl::Span<GlobalData* const> arguments) {
308 ErrorSpec error_spec(0.3, 7e-3);
309 ComputeAndCompareR2(builder, expected, arguments, error_spec);
310 }
311
// Integer results are exact, so compare without an error spec.
template <>
void ParametricDotTest::ComputeAndCompareR2WithError<int32>(
    XlaBuilder* builder, const Array2D<int32>& expected,
    absl::Span<GlobalData* const> arguments) {
  ComputeAndCompareR2(builder, expected, arguments);
}
318
// Runs one parametric dot test: transfers an (m x k) lhs and a (k x n) rhs
// (linspace-filled, with the layouts requested by GetParam()), optionally an
// (m x n) addend, executes Dot (+ Add), and checks the device result against
// the host reference matmul.
template <typename NativeT>
void ParametricDotTest::TestImpl() {
  DotTestParam param = GetParam();

  std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
  Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
      *dot_lhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
  std::unique_ptr<GlobalData> dot_lhs_handle =
      client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
  Layout rhs_layout = LayoutUtil::MakeLayout(
      MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
  Literal dot_rhs_lit =
      LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
  std::unique_ptr<GlobalData> dot_rhs_handle =
      client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();

  std::unique_ptr<Array2D<NativeT>> addend_data;
  Literal addend_lit;
  std::unique_ptr<GlobalData> addend_handle;

  if (param.has_addend) {
    addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
    addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
        *addend_data, LayoutUtil::MakeLayout(
                          MinorToMajorForIsRowMajor(param.addend_row_major)));
    addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
  }

  // Build the computation; parameter shapes carry the same layouts that were
  // used for the transferred literals.
  XlaBuilder builder(TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>();
  auto result =
      Dot(Parameter(&builder, 0,
                    ShapeUtil::MakeShapeWithLayout(
                        prim_type, {param.m, param.k},
                        MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
                    "dot_lhs"),
          Parameter(&builder, 1,
                    ShapeUtil::MakeShapeWithLayout(
                        prim_type, {param.k, param.n},
                        MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
                    "dot_rhs"));

  if (param.has_addend) {
    result =
        Add(result,
            Parameter(&builder, 2,
                      ShapeUtil::MakeShapeWithLayout(
                          prim_type, {param.m, param.n},
                          MinorToMajorForIsRowMajor(param.addend_row_major)),
                      "addend"));
  }

  // Compute the expected result on the host.
  std::unique_ptr<Array2D<NativeT>> expected;
  if (param.has_addend) {
    expected = ReferenceUtil::ApplyElementwise2D(
        std::plus<NativeT>(),
        *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
        *addend_data);
  } else {
    expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
  }

  std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
  if (param.has_addend) {
    args.push_back(addend_handle.get());
  }
  ComputeAndCompareR2WithError<NativeT>(&builder, *expected, args);
}
392
CreateDotTestParameters()393 std::vector<DotTestParam> CreateDotTestParameters() {
394 std::vector<DotTestParam> params;
395
396 auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
397 for (bool lhs_row_major : {true, false}) {
398 for (bool rhs_row_major : {true, false}) {
399 params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
400 /*dot_lhs_row_major=*/lhs_row_major,
401 /*dot_rhs_row_major=*/rhs_row_major,
402 /*has_addend=*/false, /*addend_row_major=*/true});
403 }
404 }
405 };
406
407 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
408 add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
409 add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
410 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
411 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
412 add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
413 add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
414 add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
415
416 return params;
417 }
418
// One TestImpl instantiation per element type, each guarded by the backend's
// capability macros where needed.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); }
XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); }
XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); }
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128
XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32>(); }

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);
433
// Variant of ParametricDotTest that runs with layout-related passes disabled,
// so the dot executes with whatever layouts the parameters were given.
class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
 public:
  ParametricDotTestWithoutLayoutAssignment() {
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "layout-assignment");
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "tiling-assignment");
    // Disable algebraic simplification because the pass may replace a dot
    // instruction with a layout-changing multiplication instruction.
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "algsimp");
  }
};
447
CreateNoLayoutAssignmentDotTestParameters()448 std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
449 std::vector<DotTestParam> params;
450
451 auto add_matrix_vector_dot_test = [&](int k, int n) {
452 for (bool lhs_row_major : {true, false}) {
453 for (bool rhs_row_major : {true, false}) {
454 for (bool has_addend : {true, false}) {
455 // The addend needs to be row major to match the result of the dot.
456 params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
457 /*dot_lhs_row_major=*/lhs_row_major,
458 /*dot_rhs_row_major=*/rhs_row_major,
459 /*has_addend=*/has_addend,
460 /*addend_row_major=*/true});
461 if (n != 1) {
462 params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
463 /*dot_lhs_row_major=*/lhs_row_major,
464 /*dot_rhs_row_major=*/rhs_row_major,
465 /*has_addend=*/has_addend,
466 /*addend_row_major=*/true});
467 }
468 }
469 }
470 }
471 };
472
473 add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
474 add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
475 add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
476 add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
477 add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
478 add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
479 add_matrix_vector_dot_test(/*k=*/1, /*n=*/4);
480 add_matrix_vector_dot_test(/*k=*/1, /*n=*/3);
481 add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
482 add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
483 add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
484 add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
485 add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
486 add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);
487
488 return params;
489 }
490
// Instantiations for the no-layout-assignment suite.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) {
  TestImpl<Eigen::half>();
}
#endif
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) {
  TestImpl<float>();
}
// TODO(b/147505663): Disabled for now.
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, DISABLED_TestF64) {
  TestImpl<double>();
}

INSTANTIATE_TEST_CASE_P(
    DotTests, ParametricDotTestWithoutLayoutAssignment,
    ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()),
    PrintDotTestParam);
508
// Typed fixture for a rectangular (2x3) . (3x2) product across all
// combinations of operand layouts.
template <typename T>
class NonsquareMatrixDot : public DotOperationTest {
 public:
  // Transfers the operands with the requested minor-to-major layouts and
  // compares against a hand-computed 2x2 result.
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
    auto lhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(lhs_row_major))))
            .ConsumeValueOrDie();
    auto rhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(rhs_row_major))))
            .ConsumeValueOrDie();

    XlaBuilder builder(TestName());
    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
    Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
        Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));

    Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});

    ComputeAndCompareR2<T>(&builder, expected,
                           {lhs_handle.get(), rhs_handle.get()}, error_spec_);
  }
};
539
// One test per (lhs_row_major, rhs_row_major) combination.
TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64);
XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
545
// Complex-valued (1x4) . (4x2) product.
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  const Layout row_major_layout = LayoutUtil::MakeLayout({1, 0});
  auto lhs_data =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, row_major_layout))
          .ConsumeValueOrDie();
  auto rhs_data =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              row_major_layout))
          .ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto c64_type = primitive_util::NativeToPrimitiveType<complex64>();
  Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(c64_type, {1, 4}), "lhs"),
      Parameter(&builder, 1, ShapeUtil::MakeShape(c64_type, {4, 2}), "rhs"));

  ComputeAndCompareR2<complex64>(&builder, Array2D<complex64>({{30.0, -2.0}}),
                                 {lhs_data.get(), rhs_data.get()},
                                 error_spec_);
}
569
// Two independent matmuls feed a single Add, so the Dots have no data
// dependence on each other and may execute concurrently.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto a = ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto b = ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto ab = Dot(a, b);
  auto ba = Dot(b, a);
  Add(ab, ba);

  // a.b + b.a, computed by hand.
  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>({{42.0f, 56.0f}, {74.0f, 96.0f}}), {},
      this->error_spec_);
}
586
// Typed fixture for the manually-batched matmul regression test below.
template <typename T>
class DotOperationTestForBatchMatMul : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64);
590
// Regression test for b/32055648. The root of the graph is a kFusion of 4
// bitcasts. Although bitcasts don't map to thunks, the root should still be
// sync-dependent on bitcasts' operands.
//
// Implements a [2,2,2,2] batched matmul by hand: flatten the batch dims,
// slice out each 2x2 matrix, Dot the pairs, and concatenate the results back
// into the original shape.
XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "y");

  // Collapse the two leading batch dimensions into one of size 4.
  auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
  auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2});

  // Slice batches into individual matrices and multiply them.
  std::vector<XlaOp> out_slices;
  for (int i = 0; i < 4; ++i) {
    // Slice off individual matrices and reshape to 2D tensors.
    auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2});
    auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2});

    auto out = Dot(x_slice, y_slice);
    out = Reshape(out, {0, 1}, {1, 2, 2});
    out_slices.push_back(out);
  }
  // Reassemble the per-batch results into the original [2,2,2,2] shape.
  auto out_flat = ConcatInDim(&builder, out_slices, 0);
  Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});

  auto x_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
                        {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
                          {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
                         {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
                          {{4000.0f, 400.0f}, {40.0f, 4.0f}}}}))
                    .ConsumeValueOrDie();
  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
               {{{11.0f, 22.0f}, {33.0f, 44.0f}},
                {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
          .ConsumeValueOrDie();

  // Half precision cannot represent these magnitudes at the default
  // tolerance.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }
  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}},
        {{11400.0f, 13600.0f}, {114.0f, 136.0f}}},
       {{{42900.0f, 79200.0f}, {429.0f, 792.0f}},
        {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
648
// DotGeneral with one batch dimension: [2,2,2] x [2,2,2], batching over
// dimension 0 and contracting lhs dim 2 with rhs dim 1 (a standard batched
// matmul). The rhs batches are identity matrices, so the result equals x.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .ConsumeValueOrDie();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
          .ConsumeValueOrDie();

  this->template ComputeAndCompareR3<T>(
      &builder,
      /*expected=*/
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
684
// DotGeneral with a rank-3 lhs and rank-2 rhs: rhs dim 0 serves as both the
// batch dimension (paired with lhs dim 0) and the row index, producing a
// rank-2 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .ConsumeValueOrDie();

  auto y_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                        {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                    .ConsumeValueOrDie();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
      this->error_spec_);
}
717
// Mirror of the previous test: rank-2 lhs and rank-3 rhs, with lhs dim 0
// acting as both the batch dimension and the row index.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x");
  auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                        {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                    .ConsumeValueOrDie();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .ConsumeValueOrDie();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
      this->error_spec_);
}
750
// DotGeneral with two batch dimensions: [2,2,2,2] x [2,2,2,2], batching over
// dims {0,1} and contracting lhs dim 3 with rhs dim 2. The rhs batches are
// identity / column-swap matrices, so the expected output is easy to read off.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
               {{{9.0f, 10.0f}, {11.0f, 12.0f}},
                {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
          .ConsumeValueOrDie();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
               {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
          .ConsumeValueOrDie();

  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
       {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
792
// Exercises dots whose operands are fed through explicit Transpose ops, over
// every combination of (transpose_lhs, transpose_rhs, operand layout). The
// host-side arrays are pre-transposed so that the device-side Transpose
// restores them, keeping the expected product constant across all cases.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
  using T = TypeParam;
  for (bool transpose_lhs : {false, true}) {
    for (bool transpose_rhs : {false, true}) {
      for (bool row_major : {false, true}) {
        std::unique_ptr<Array2D<T>> lhs(
            new Array2D<T>({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}));
        std::unique_ptr<Array2D<T>> rhs(
            new Array2D<T>({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}));

        // Pre-transpose on the host; the builder-side Transpose below undoes
        // this.
        if (transpose_lhs) {
          lhs = ReferenceUtil::TransposeArray2D(*lhs);
        }
        if (transpose_rhs) {
          rhs = ReferenceUtil::TransposeArray2D(*rhs);
        }
        auto lhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *lhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();
        auto rhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *rhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .ConsumeValueOrDie();

        XlaBuilder builder(this->TestName());
        auto prim_type = primitive_util::NativeToPrimitiveType<T>();
        auto lhs_arg = Parameter(
            &builder, 0,
            ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
            "lhs");
        auto rhs_arg = Parameter(
            &builder, 1,
            ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
            "rhs");
        if (transpose_lhs) {
          lhs_arg = Transpose(lhs_arg, {1, 0});
        }
        if (transpose_rhs) {
          rhs_arg = Transpose(rhs_arg, {1, 0});
        }
        Dot(lhs_arg, rhs_arg);

        // Same product in every iteration, by construction.
        Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
        VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
                << transpose_rhs << " " << row_major;
        this->template ComputeAndCompareR2<T>(
            &builder, expected, {lhs_handle.get(), rhs_handle.get()},
            this->error_spec_);
      }
    }
  }
}
852
// Verifies dot(const_lhs, concat(rhs0, rhs1, rhs2)) where the three RHS
// pieces are concatenated along the contracting dimension (rows), exercising
// the dot-of-concat rewrite that splits the dot per concat operand.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstLHS) {
  using T = TypeParam;
  auto prim_type = primitive_util::NativeToPrimitiveType<T>();

  std::unique_ptr<Array2D<T>> constant_lhs_array(
      new Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                      {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}}));

  XlaBuilder builder(this->TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  // The concatenated RHS has shape {2+3+1, 2} = {6, 2}, matching the LHS
  // contracting dimension of size 6.
  auto rhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0");
  auto rhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1");
  auto rhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2");
  Dot(lhs_constant,
      ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  std::unique_ptr<Array2D<T>> arg_0_value_array(
      new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  std::unique_ptr<Array2D<T>> arg_1_value_array(
      new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}));
  std::unique_ptr<Array2D<T>> arg_2_value_array(new Array2D<T>({{1.0f, 2.0f}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));

  Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
898
// Verifies dot(concat(lhs0, lhs1, lhs2), const_rhs) where the three LHS
// pieces are concatenated along the contracting dimension (columns),
// exercising the dot-of-concat rewrite with the constant on the RHS.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstRHS) {
  using T = TypeParam;
  std::unique_ptr<Array2D<T>> constant_rhs_array(
      new Array2D<T>({{1.0f, 2.0f},
                      {3.0f, 4.0f},
                      {5.0f, 6.0f},
                      {6.0f, 5.0f},
                      {4.0f, 3.0f},
                      {2.0f, 1.0f}}));

  XlaBuilder builder(this->TestName());
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  // The concatenated LHS has shape {2, 2+3+1} = {2, 6}, matching the RHS
  // contracting dimension of size 6.
  auto lhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0");
  auto lhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1");
  auto lhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2");
  Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
      rhs_constant);

  std::unique_ptr<Array2D<T>> arg_0_value_array(
      new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  std::unique_ptr<Array2D<T>> arg_1_value_array(
      new Array2D<T>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}));
  std::unique_ptr<Array2D<T>> arg_2_value_array(
      new Array2D<T>({{1.0f}, {2.0f}}));

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));

  Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
947
// Classic matmul (lhs dim 1 contracts with rhs dim 0) where the LHS is a
// dynamic-slice of a constant: slices row 1 of the 2x6 LHS, so the result is
// row 1 of the full 2x3 product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto one = ConstantR0<int32>(&builder, 1);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{96.0, 105.0, 114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
975
// Classic matmul where the RHS is a dynamic-slice of a constant: slices
// column 1 of the 6x3 RHS, so the result is column 1 of the full 2x3 product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{105.0}, {105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1003
// "Reverse" matmul (lhs dim 0 contracts with rhs dim 1) where the LHS is a
// dynamic-slice of a constant: slices column 1 of the 6x3 LHS, yielding one
// row of the transposed product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSReverseMM) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{105.0, 105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1033
// "Reverse" matmul (lhs dim 0 contracts with rhs dim 1) where the RHS is a
// dynamic-slice of a constant: slices row 1 of the 2x6 RHS, yielding one
// column of the transposed product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{96.0}, {105.0}, {114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1061
// Both operands contract along dim 0 (rows) with the LHS dynamic-sliced from
// a constant: slices column 1 of the 6x2 LHS, yielding one row of the
// row-contracted product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{126.0, 129.0, 132.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1094
// Both operands contract along dim 0 (rows) with the RHS dynamic-sliced from
// a constant: slices column 1 of the 6x3 RHS, yielding one column of the
// row-contracted product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{129.0}, {129.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1127
// Both operands contract along dim 1 (columns) with the LHS dynamic-sliced
// from a constant: slices row 1 of the 2x6 LHS, yielding one row of the
// column-contracted product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{56.0, 168.0, 91.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1152
// Both operands contract along dim 1 (columns) with the RHS dynamic-sliced
// from a constant: slices row 1 of the 3x6 RHS, yielding one column of the
// column-contracted product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32>(&builder, 0);
  auto one = ConstantR0<int32>(&builder, 1);
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{168.0}, {168.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1177
// Verifies DotGeneral with non-default contraction dimensions (both dim 0),
// i.e. lhs^T * rhs, on constant 2x2 operands.
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  XlaBuilder builder(TestName());

  Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);

  Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);

  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, rhs_constant, dot_dnums);

  // Expected = lhs^T * rhs.
  Array2D<float> expected({
      {26.f, 30.f},
      {38.f, 44.f},
  });

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1200
1201 using EinsumParamType =
1202 std::tuple<std::vector<int64>, std::vector<int64>, string>;
1203 class EinsumTest : public DotOperationTest,
1204 public ::testing::WithParamInterface<EinsumParamType> {};
// Runs one einsum configuration on fake F32 inputs of the parameterized
// shapes and compares against the reference interpreter.
XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
  XlaBuilder builder(TestName());
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .ValueOrDie(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .ValueOrDie(),
      &builder);
  auto config = std::get<2>(GetParam());
  // A config without a comma is a unary einsum; only x participates.
  if (config.find(',') == config.npos) {
    Einsum(x, config);
  } else {
    Einsum(x, y, config);
  }
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1223
GetEinsumTestCases()1224 std::vector<EinsumParamType> GetEinsumTestCases() {
1225 using v = std::vector<int64>;
1226 using p = EinsumParamType;
1227 std::vector<p> test_cases = {
1228 p{v{5, 6}, v{6, 7}, "mk,kn->mn"},
1229 p{v{5, 6}, v{6, 7}, "mk,kn->nm"},
1230 p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"},
1231 p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"},
1232 p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"},
1233 p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
1234 p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
1235 p{v{6}, v{6, 7}, "b,bc->c"},
1236 p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
1237 p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
1238 p{v{77}, v{77}, "a,a->a"},
1239 p{v{77}, v{77, 55}, "a,ab->ba"},
1240 p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
1241 p{v{55}, v{}, "a,->a"},
1242 p{v{11, 111}, v{11}, "ab,a->ab"},
1243 p{v{16, 34}, v{16, 34}, "ab,ab->ab"},
1244 p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"},
1245 p{v{5, 19}, v{}, "ab,->ab"},
1246 p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"},
1247 p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
1248 p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
1249 p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
1250 p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
1251 p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
1252 p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
1253 p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
1254 p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
1255 p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
1256 p{v{5, 6}, v{6, 7}, "mk,kn"},
1257 p{v{5, 6}, v{6, 7}, "mk,kn"},
1258 p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
1259 p{v{5, 6}, v{6, 7}, "ab,cd"},
1260 p{v{6}, v{6, 7}, "b,bc"},
1261 p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
1262 p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
1263 p{v{77}, v{77}, "a,a"},
1264 p{v{77}, v{77, 55}, "a,ab"},
1265 p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
1266 p{v{55}, v{}, "a"},
1267 p{v{11, 111}, v{11}, "ab,a"},
1268 p{v{16, 34}, v{16, 34}, "ab,ab"},
1269 p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
1270 p{v{5, 19}, v{}, "ab"},
1271 p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
1272 p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
1273 p{v{5, 6}, v{}, "...mk"},
1274 p{v{5, 6, 12, 13}, v{}, "...mk"},
1275 p{v{5, 6, 12, 13}, v{}, "m...k"},
1276 p{v{5, 6, 12, 13}, v{}, "mk..."},
1277 p{v{5, 6}, v{6, 7}, "...mk->km..."},
1278 p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
1279 p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
1280 p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
1281 p{v{16, 16, 16}, v{}, "iii"},
1282 p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
1283 };
1284 return test_cases;
1285 }
1286
1287 INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest,
1288 ::testing::ValuesIn(GetEinsumTestCases()));
1289
1290 using BatchDotParamType =
1291 std::tuple<std::vector<int64>, std::vector<int64>, std::vector<int64>>;
1292 class BatchDotTest : public DotOperationTest,
1293 public ::testing::WithParamInterface<BatchDotParamType> {};
// Runs BatchDot on fake F32 inputs, checks that the inferred output shape
// matches the expected broadcasted shape, and compares numeric results
// against the reference interpreter.
XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) {
  XlaBuilder builder(TestName());
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .ValueOrDie(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .ValueOrDie(),
      &builder);
  auto batch_dot = BatchDot(x, y);
  auto output_shape = builder.GetShape(batch_dot).ValueOrDie();
  EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam()));
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1309
GetBatchDotTestCases()1310 std::vector<BatchDotParamType> GetBatchDotTestCases() {
1311 using v = std::vector<int64>;
1312 using p = BatchDotParamType;
1313 std::vector<p> test_cases = {
1314 p{v{5, 6}, v{6, 7}, v{5, 7}},
1315 p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1316 p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}},
1317 p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}},
1318 p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1319 p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1320 p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}},
1321 p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}},
1322 };
1323 return test_cases;
1324 }
1325
1326 INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest,
1327 ::testing::ValuesIn(GetBatchDotTestCases()));
1328
1329 class DotOperationTextTest : public HloTestBase {};
1330
// Dot with batch dimensions listed out of order ({2,0}) and multiple
// non-contracting dims on each side.
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1345
// Dot with out-of-order batch dims AND multiple contracting dims on each
// operand.
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1360
// Dot with no dimension numbers at all: degenerates to an outer product,
// producing the concatenation of both operand shapes.
XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) {
  absl::string_view hlo_string =
      R"(
HloModule DotWithNoDnums

ENTRY %test {
  %lhs = f32[2,3]{1,0} parameter(0)
  %rhs = f32[4,5]{1,0} parameter(1)
  ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1375
// Einsum-style dot: rank-3 lhs contracting into a rank-3 rhs, yielding a
// rank-4 result with multiple non-contracting dims per operand.
XLA_TEST_F(DotOperationTextTest, Einsum) {
  absl::string_view hlo_string =
      R"(
HloModule Einsum

ENTRY %test {
  %lhs = f32[8,64,96]{2,1,0} parameter(0)
  %rhs = f32[96,32,4]{2,1,0} parameter(1)
  ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1390
// Regression test for a caching bug in the XLA CPU tiled dot emitter: two
// dots sharing an operand but contracting different rhs dims must not reuse
// a stale cached kernel.
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
  absl::string_view hlo_string =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(2)
  rhs_1 = f32[1,40] parameter(1)

  dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  ROOT result = f32[20,1] divide(dot_0, dot_1)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1411
// Second regression test for the CPU tiled dot emitter cache: two dots with
// distinct operands and contracting configurations combined via reshapes.
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
  absl::string_view hlo_string =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs_0 = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(1)
  lhs_1 = f32[1,40] parameter(2)
  rhs_1 = f32[20,40] parameter(3)

  dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  dot_0_reshaped = f32[20] reshape(dot_0)
  dot_1_reshaped = f32[20] reshape(dot_1)

  ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1436
// Exact (zero-tolerance) s32 batch dot on iota-generated operands.
XLA_TEST_F(DotOperationTextTest, S32IotaDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[5,55,8] iota(), iota_dimension=1
  arg1 = s32[5,8,200] iota(), iota_dimension=2
  ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1451
// Exact s32 dot on fourth powers of iota values, stressing larger integer
// magnitudes in the accumulator.
XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[16,2] iota(), iota_dimension=0
  a = s32[16,2] multiply(arg0, arg0)
  r = s32[16,2] multiply(a, a)
  arg1 = s32[2,98] iota(), iota_dimension=1
  b = s32[2,98] multiply(arg1, arg1)
  s = s32[2,98] multiply(b, b)
  ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1470
// Exact u16 batch dot; the result is converted to s32 so the comparison is
// done in a widely supported type.
XLA_TEST_F(DotOperationTextTest, U16IotaDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[5,55,8] parameter(0)
  arg1 = u16[5,8,200] parameter(1)
  dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  ROOT c = s32[5,55,200] convert(dot)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1486
// Exact u16 dot on fourth powers of iota values, exercising (wrapping)
// unsigned 16-bit arithmetic.
XLA_TEST_F(DotOperationTextTest, U16IotaSquaredDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[16,2] iota(), iota_dimension=0
  a = u16[16,2] multiply(arg0, arg0)
  r = u16[16,2] multiply(a, a)
  arg1 = u16[2,98] iota(), iota_dimension=1
  b = u16[2,98] multiply(arg1, arg1)
  s = u16[2,98] multiply(b, b)
  ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1505
// Exact s16 batch dot on iota-generated operands.
XLA_TEST_F(DotOperationTextTest, S16IotaDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[5,55,8] iota(), iota_dimension=1
  arg1 = s16[5,8,200] iota(), iota_dimension=2
  ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1520
// Exact s16 dot on fourth powers of iota values.
XLA_TEST_F(DotOperationTextTest, S16IotaSquaredDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[16,2] iota(), iota_dimension=0
  a = s16[16,2] multiply(arg0, arg0)
  r = s16[16,2] multiply(a, a)
  arg1 = s16[2,98] iota(), iota_dimension=1
  b = s16[2,98] multiply(arg1, arg1)
  s = s16[2,98] multiply(b, b)
  ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1539
// Exact dot over pred (boolean) operands.
XLA_TEST_F(DotOperationTextTest, PREDDot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = pred[20,2] parameter(0)
  arg1 = pred[2,20] parameter(1)
  ROOT dot = pred[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1554
// Exact dot over s8 operands.
XLA_TEST_F(DotOperationTextTest, S8Dot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s8[20,2] parameter(0)
  arg1 = s8[2,20] parameter(1)
  ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1569
// Exact dot over s32 operands with a larger contracting dimension.
XLA_TEST_F(DotOperationTextTest, S32Dot) {
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[20,55] parameter(0)
  arg1 = s32[55,20] parameter(1)
  ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1584
// Dot whose output is immediately transposed — exercises the GPU rewrite
// that folds the output transpose into the gemm.
XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) {
  absl::string_view hlo_string =
      R"(
HloModule TransposeOutput

ENTRY TransposeOutput {
  p0 = f32[32,32] parameter(0)
  p1 = f32[32,64] parameter(1)
  dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT tr = f32[64,32] transpose(dot), dimensions={1,0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1600
// Complex (c64) matrix-vector dot followed by an add; parses the module
// unverified before running since c64 support varies by backend.
XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) {
  absl::string_view hlo_string =
      R"(
HloModule MatrixVectorComplex

ENTRY MatrixVectorComplex {
  p0 = c64[5,5] parameter(0)
  p1 = c64[5,1] parameter(1)
  p2 = c64[5,1] parameter(2)
  dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT add = c64[5,1] add(dot, p2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3}));
}
1619
// bf16 vector-matrix dot followed by an add.
XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
  absl::string_view hlo_string =
      R"(
HloModule MatrixVectorBF16

ENTRY MatrixVectorBF16 {
  p0 = bf16[128] parameter(0)
  p1 = bf16[128,256] parameter(1)
  p2 = bf16[256] parameter(2)
  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT add = bf16[256] add(dot, p2)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1636
1637 // Regression test for b/138155357, where we were incorrectly creating a dot-add
1638 // fusion where the dot had a batch dimension. This isn't supported on the CPU
1639 // backend.
XLA_TEST_F(DotOperationTextTest,FusedBatchDotRegressionTest)1640 XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) {
1641 absl::string_view module_string = R"(
1642 HloModule jaxpr_computation__5.33
1643
1644 jaxpr_computation__6.8 {
1645 tuple.9 = () tuple()
1646 parameter.14 = () parameter(4)
1647 parameter.13 = (f32[2]{0}) parameter(3)
1648 get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0
1649 reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15)
1650 parameter.10 = f32[2,2]{1,0} parameter(0)
1651 reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15)
1652 dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}
1653 reshape.19 = f32[2]{0} reshape(dot.18)
1654 reshape.20 = f32[2,1]{1,0} reshape(reshape.19)
1655 dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0}
1656 reshape.22 = f32[] reshape(dot.21)
1657 parameter.11 = f32[2,1,2]{2,1,0} parameter(1)
1658 broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2}
1659 dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
1660 broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2}
1661 parameter.12 = f32[2,2,1]{2,1,0} parameter(2)
1662 dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
1663 add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26)
1664 reshape.28 = f32[2]{0} reshape(add.27)
1665 ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28)
1666 }
1667
1668 ENTRY jaxpr_computation__5.33 {
1669 constant.2 = f32[] constant(1)
1670 broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={}
1671 constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } })
1672 constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } })
1673 parameter.6 = f32[2]{0} parameter(0)
1674 tuple.7 = (f32[2]{0}) tuple(parameter.6)
1675 tuple.1 = () tuple()
1676 call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8
1677 get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0
1678 ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1
1679 })";
1680 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1681 ParseAndReturnVerifiedModule(module_string));
1682 EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/absl::nullopt));
1683 }
1684
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
  // dot(Const, reshape(transpose(P))): the parameter's contracting dimension
  // is produced through a transpose+reshape, which the compiler may instead
  // fold onto the constant operand.
  Array3D<float> param_data(2, 3, 2);
  Array2D<float> const_data(2, 6);
  param_data.FillIota(0);
  const_data.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_data), &builder);
  XlaOp transposed = Transpose(param, {1, 0, 2});
  XlaOp rhs = Reshape(transposed, {6, 2});
  XlaOp lhs = ConstantR2FromArray2D(&builder, const_data);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1701
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
  // dot(reshape(transpose(P)), Const) contracting LHS dim 0 against RHS
  // dim 1; checks reordering of contracting dims with the constant on the
  // right.
  Array3D<float> param_data(2, 3, 2);
  Array2D<float> const_data(2, 6);
  param_data.FillIota(0);
  const_data.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_data), &builder);
  XlaOp transposed = Transpose(param, {1, 0, 2});
  XlaOp lhs = Reshape(transposed, {6, 2});
  XlaOp rhs = ConstantR2FromArray2D(&builder, const_data);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1722
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
  // Three contracting dims of a rank-4 parameter are permuted and collapsed
  // into one before the dot; constant operand on the right.
  Array4D<float> param_data(2, 2, 3, 4);
  Array2D<float> const_data(24, 2);
  param_data.FillIota(0);
  const_data.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_data), &builder);
  XlaOp transposed = Transpose(param, {0, 2, 3, 1});
  XlaOp lhs = Reshape(transposed, {2, 24});
  XlaOp rhs = ConstantR2FromArray2D(&builder, const_data);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1739
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
  // Batched matrix-matrix dot whose LHS contracting dimension is built via
  // reshape -> transpose -> reshape; batch dim 0 on both sides.
  Array3D<float> param_data(2, 6, 2);
  Array3D<float> const_data(2, 6, 3);
  param_data.FillIota(0);
  const_data.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_data), &builder);
  XlaOp expanded = Reshape(param, {2, 2, 3, 2});
  XlaOp transposed = Transpose(expanded, {0, 2, 1, 3});
  XlaOp lhs = Reshape(transposed, {2, 6, 2});
  XlaOp rhs = ConstantR3FromArray3D(&builder, const_data);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1763
XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
  // The contracting dimension is built by two rounds of transpose+reshape, so
  // the reorder-contracting-dims simplification has to apply more than once.
  Array4D<float> param_data(2, 2, 3, 5);
  Array2D<float> const_data(2, 30);
  param_data.FillIota(0);
  const_data.FillIota(0);

  XlaBuilder builder(TestName());
  XlaOp param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_data), &builder);
  XlaOp pass1 = Transpose(param, {0, 2, 1, 3});
  XlaOp merged = Reshape(pass1, {2, 6, 5});
  XlaOp pass2 = Transpose(merged, {0, 2, 1});
  XlaOp lhs = Reshape(pass2, {2, 30});
  XlaOp rhs = ConstantR2FromArray2D(&builder, const_data);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Constant folding is disabled by default in unit tests; re-enable all HLO
  // passes here so algsimp can run repeatedly and fold the transpose/reshape
  // that get moved onto the constant side of the dot.
  mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompare(&builder, {}, error_spec_);
}
1790
XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
  // Mixed narrow integral inputs (s8 x s16) accumulating into a wider s32
  // result; also uses a non-default layout on p1.
  constexpr absl::string_view kHloText =
      R"(
HloModule WiderIntegralAccumulation

ENTRY MatrixVectorComplex {
  p0 = s8[5,5]{1,0} parameter(0)
  p1 = s16[5,1]{0,1} parameter(1)
  ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                        rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1806
// This benchmark is to show the performance impact of the following
// transformation:
//   dot(reshape(transpose(A)), Const) ==>
//   dot(reshape(A), reshape(transpose(reshape(Const)))),
// and then fold the reshape and transpose on the Const side.
// We can compare performance with and without algsimp pass to see the impact.
void DOT_ReorderContracting(::testing::benchmark::State& state) {
  // Set up a local client and allocator on the default platform.
  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  se::StreamExecutorMemoryAllocator allocator(platform, executors);

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform);
  auto client =
      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();

  int device_ordinal = client->default_device_ordinal();

  // Problem sizes: A is [d0, d1, d2], the constant is [d1 * d2, d3].
  const int64 d0 = 128;
  const int64 d1 = 128;
  const int64 d2 = 128;
  const int64 d3 = 128;

  Array3D<float> input_arr(d0, d1, d2);
  Array2D<float> const_arr(d1 * d2, d3);
  input_arr.FillIota(0);
  const_arr.FillIota(0);
  // Build dot(reshape(transpose(param)), Const) — the pattern the
  // transformation above targets.
  XlaBuilder builder("ReorderContracting");
  auto t0 =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
  auto t1 = Transpose(t0, {0, 2, 1});
  auto lhs = Reshape(t1, {d0, d2 * d1});
  auto rhs = ConstantR2FromArray2D(&builder, const_arr);
  Dot(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Transfer the parameter to the device once, outside the timed loop.
  auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
  ScopedShapedBuffer buffer0 =
      client->LiteralToShapedBuffer(input_literal, device_ordinal)
          .ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(
      auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
                                        ExecutableBuildOptions()));
  auto executable = std::move(executables[0]);

  se::Stream stream(executors[device_ordinal]);
  stream.Init();

  ExecutableRunOptions options;
  options.set_allocator(&allocator);

  // Untimed warm-up runs so one-time costs don't pollute the measurement.
  const int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    ASSERT_IS_OK(executable->Run({&buffer0}, options));
  }

  // Elements touched per iteration: input + constant + output.
  const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
  for (auto s : state) {
    ASSERT_IS_OK(executable->Run({&buffer0}, options));
  }
  state.SetBytesProcessed(state.iterations() * total_bytes * sizeof(float));
}

BENCHMARK(DOT_ReorderContracting)->UseRealTime();
1872
1873 } // namespace
1874 } // namespace xla
1875