/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
34
35 namespace xla {
36 namespace {
37
38 using ConcatTest = ClientLibraryTestBase;
39 using ConcatTestHlo = HloTestBase;
40 using ::testing::HasSubstr;
41
42 // Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest,Concat_Nothing)43 XLA_TEST_F(ConcatTest, Concat_Nothing) {
44 XlaBuilder builder(TestName());
45 ConcatInDim(&builder, {}, 0);
46 StatusOr<XlaComputation> computation_status = builder.Build();
47 ASSERT_FALSE(computation_status.ok());
48 EXPECT_THAT(computation_status.status().ToString(),
49 HasSubstr("Concatenate expects at least one argument"));
50 }
51
52 // Concatenate with one argument works.
XLA_TEST_F(ConcatTest,Concat_R1_With_Nothing)53 XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
54 XlaBuilder builder(TestName());
55 auto a = ConstantR1<float>(&builder, {42.0, 64.0});
56 ConcatInDim(&builder, {a}, 0);
57
58 std::vector<float> expected = {42, 64};
59 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
60 }
61
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
  // A single empty vector concatenates to an empty vector.
  XlaBuilder builder(TestName());
  const XlaOp empty = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {empty}, 0);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
70
71 // Show that we can't concatenate R0 with R0 because we can't name the dimension
72 // to concatenate on.
XLA_TEST_F(ConcatTest,CannotConcatR0WithR0)73 XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
74 XlaBuilder builder(TestName());
75 auto a = ConstantR0<float>(&builder, 42.0);
76 auto b = ConstantR0<float>(&builder, 64.0);
77 ConcatInDim(&builder, {a, b}, 0);
78 StatusOr<XlaComputation> computation_status = builder.Build();
79 ASSERT_FALSE(computation_status.ok());
80 EXPECT_THAT(computation_status.status().ToString(),
81 HasSubstr("out of bounds: 0"));
82 }
83
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
  // Two empty vectors concatenate to an empty vector.
  XlaBuilder builder(TestName());
  const XlaOp first = ConstantR1<float>(&builder, {});
  const XlaOp second = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {first, second}, 0);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
93
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
  // An empty vector followed by a one-element vector yields that element.
  XlaBuilder builder(TestName());
  const XlaOp empty = ConstantR1<float>(&builder, {});
  const XlaOp single = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {empty, single}, 0);

  const std::vector<float> expected = {256};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
103
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
  // Appending an empty vector leaves the two-element vector unchanged.
  XlaBuilder builder(TestName());
  const XlaOp pair = ConstantR1<float>(&builder, {42.0, 64.0});
  const XlaOp empty = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {pair, empty}, 0);

  const std::vector<float> expected = {42, 64};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
113
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
  // Two elements followed by one element, in operand order.
  XlaBuilder builder(TestName());
  const XlaOp pair = ConstantR1<float>(&builder, {42.0, 64.0});
  const XlaOp single = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {pair, single}, 0);

  const std::vector<float> expected = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
123
XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
  // Operands of odd lengths 253 and 7. Values run 1, 2, ..., 260 across the
  // two operands, so the expected output is that same sequence.
  constexpr int kLhsLen = 253;
  constexpr int kRhsLen = 7;
  std::vector<float> lhs(kLhsLen);
  std::vector<float> rhs(kRhsLen);
  std::vector<float> expected(kLhsLen + kRhsLen);
  for (int pos = 0; pos < kLhsLen + kRhsLen; ++pos) {
    const float value = pos + 1;
    expected[pos] = value;
    if (pos < kLhsLen) {
      lhs[pos] = value;
    } else {
      rhs[pos - kLhsLen] = value;
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp lhs_op = ConstantR1<float>(&builder, lhs);
  const XlaOp rhs_op = ConstantR1<float>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, 0);

  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
142
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
  // Two 0x0 matrices concatenate to a 0x0 matrix along either dimension.
  for (const int concat_dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty(0, 0);
    const XlaOp lhs = ConstantR2FromArray2D(&builder, empty);
    const XlaOp rhs = ConstantR2FromArray2D(&builder, empty);
    ConcatInDim(&builder, {lhs, rhs}, concat_dim);

    ComputeAndCompareR2<float>(&builder, empty, {}, ErrorSpec(0.0001));
  }
}
154
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
  XlaBuilder builder(TestName());
  const auto top = CreatePatternedMatrix(1, 1);
  const auto bottom = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  ConcatInDim(&builder,
              {ConstantR2FromArray2D(&builder, *top),
               ConstantR2FromArray2D(&builder, *bottom)},
              0);

  // Stacking along dimension 0 gives a 2x1 result.
  const Array2D<float> expected({{0}, {64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
169
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
  XlaBuilder builder(TestName());
  const auto left = CreatePatternedMatrix(1, 1);
  const auto right = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  ConcatInDim(&builder,
              {ConstantR2FromArray2D(&builder, *left),
               ConstantR2FromArray2D(&builder, *right)},
              1);

  // Joining along dimension 1 gives a 1x2 result.
  const Array2D<float> expected({{0, 64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
183
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
  // Prepending a 2x0 matrix along dimension 1 leaves the 2x5 matrix as-is.
  XlaBuilder builder(TestName());
  const auto right = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  const XlaOp empty_op =
      ConstantR2FromArray2D(&builder, Array2D<float>(2, 0));
  const XlaOp right_op = ConstantR2FromArray2D(&builder, *right);
  ConcatInDim(&builder, {empty_op, right_op}, 1);

  ComputeAndCompareR2<float>(&builder, *right, {}, ErrorSpec(0.0001));
}
193
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
  XlaBuilder builder(TestName());
  const auto left = CreatePatternedMatrix(2, 3);
  const auto right = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  ConcatInDim(&builder,
              {ConstantR2FromArray2D(&builder, *left),
               ConstantR2FromArray2D(&builder, *right)},
              1);

  // Each row is the 3 columns of the left operand followed by the 5 columns
  // of the right operand.
  const Array2D<float> expected({
      {0, 1, 2, 64, 65, 66, 67, 68},
      {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
208
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
  // Appending a 0x2 matrix along dimension 0 leaves the 3x2 matrix as-is.
  XlaBuilder builder(TestName());
  const auto top = CreatePatternedMatrix(3, 2);
  const XlaOp top_op = ConstantR2FromArray2D(&builder, *top);
  const XlaOp empty_op =
      ConstantR2FromArray2D(&builder, Array2D<float>(0, 2));
  ConcatInDim(&builder, {top_op, empty_op}, 0);

  ComputeAndCompareR2<float>(&builder, *top, {}, ErrorSpec(0.0001));
}
218
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
  XlaBuilder builder(TestName());
  const auto top = CreatePatternedMatrix(3, 2);
  const auto bottom = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
  ConcatInDim(&builder,
              {ConstantR2FromArray2D(&builder, *top),
               ConstantR2FromArray2D(&builder, *bottom)},
              0);

  // The 3 rows of the first operand followed by the 5 rows of the second.
  const Array2D<float> expected({
      {0, 1},
      {1000, 1001},
      {2000, 2001},
      {64, 65},
      {1064, 1065},
      {2064, 2065},
      {3064, 3065},
      {4064, 4065},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
239
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
  // R3 operands with a zero-sized dimension 1 produce an empty 3x0x3 result.
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 2));
  const XlaOp rhs = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 1));
  ConcatInDim(&builder, {lhs, rhs}, 2);

  const Array3D<float> expected(3, 0, 3);
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
248
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
  XlaBuilder builder(TestName());
  // 3x1x2 operand.
  const Array3D<float> pairs({
      {{0, 1}},
      {{2, 3}},
      {{4, 5}},
  });
  // 3x1x1 operand.
  const Array3D<float> singles({
      {{6}},
      {{7}},
      {{8}},
  });
  const XlaOp pairs_op = ConstantR3FromArray3D(&builder, pairs);
  const XlaOp singles_op = ConstantR3FromArray3D(&builder, singles);
  ConcatInDim(&builder, {pairs_op, singles_op}, 2);

  // Concatenating along the last dimension appends one value to each row.
  const Array3D<float> expected({
      {{0, 1, 6}},
      {{2, 3, 7}},
      {{4, 5, 8}},
  });
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
274
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
  // Three single-element vectors joined by one three-operand concatenation.
  XlaBuilder builder(TestName());
  const XlaOp first = ConstantR1<float>(&builder, {42.0});
  const XlaOp second = ConstantR1<float>(&builder, {64.0});
  const XlaOp third = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {first, second, third}, 0);

  const std::vector<float> expected = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
285
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
  XlaBuilder builder(TestName());
  // Operand shapes are 3x1x2, 3x1x1 and 3x1x1. After concatenating along
  // dimension 2, the expected values read 0..11 in row-major order.
  const Array3D<float> first({
      {{0, 1}},
      {{4, 5}},
      {{8, 9}},
  });
  const Array3D<float> second({
      {{2}},
      {{6}},
      {{10}},
  });
  const Array3D<float> third({
      {{3}},
      {{7}},
      {{11}},
  });
  ConcatInDim(&builder,
              {ConstantR3FromArray3D(&builder, first),
               ConstantR3FromArray3D(&builder, second),
               ConstantR3FromArray3D(&builder, third)},
              2);

  const Array3D<float> expected({
      {{0, 1, 2, 3}},
      {{4, 5, 6, 7}},
      {{8, 9, 10, 11}},
  });
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
318
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
  XlaBuilder builder(TestName());
  const XlaOp a = ConstantR1<float>(&builder, {42.0});
  const XlaOp b = ConstantR1<float>(&builder, {64.0});
  const XlaOp c = ConstantR1<float>(&builder, {256.0});
  // Nest the binary concatenations as (a ++ b) ++ c.
  const XlaOp ab = ConcatInDim(&builder, {a, b}, 0);
  ConcatInDim(&builder, {ab, c}, 0);

  const std::vector<float> expected = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
330
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
  XlaBuilder builder(TestName());
  const XlaOp a = ConstantR1<float>(&builder, {42.0});
  const XlaOp b = ConstantR1<float>(&builder, {64.0});
  const XlaOp c = ConstantR1<float>(&builder, {256.0});
  // Nest the binary concatenations as a ++ (b ++ c).
  const XlaOp bc = ConcatInDim(&builder, {b, c}, 0);
  ConcatInDim(&builder, {a, bc}, 0);

  const std::vector<float> expected = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
342
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
  // Two 1x1024 rows stacked along dimension 0 into a 2x1024 matrix.
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(2, kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
    expected(0, col) = col;
    expected(1, col) = col + kWidth;
  }

  XlaBuilder builder(TestName());
  ConcatInDim(&builder,
              {ConstantR2FromArray2D<float>(&builder, lhs),
               ConstantR2FromArray2D<float>(&builder, rhs)},
              0);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
363
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
  // Two 1x1024 rows joined along dimension 1 into a single 1x2048 row.
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(1, 2 * kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
    expected(0, col) = col;
    expected(0, kWidth + col) = col + kWidth;
  }

  XlaBuilder builder(TestName());
  ConcatInDim(&builder,
              {ConstantR2FromArray2D<float>(&builder, lhs),
               ConstantR2FromArray2D<float>(&builder, rhs)},
              1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
384
XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
  // Every element holds (row << 10) | column of its position in the 64x66
  // result. Splitting those columns into 64 + 2 gives the operands, and the
  // whole 64x66 table is the expected output.
  const auto encode = [](int row, int col) { return (row << 10) | col; };
  Array2D<float> lhs(64, 64);
  Array2D<float> rhs(64, 2);
  Array2D<float> expected(64, 66);
  for (int row = 0; row < 64; ++row) {
    for (int col = 0; col < 66; ++col) {
      expected(row, col) = encode(row, col);
      if (col < 64) {
        lhs(row, col) = encode(row, col);
      } else {
        rhs(row, col - 64) = encode(row, col);
      }
    }
  }

  XlaBuilder builder(TestName());
  ConcatInDim(&builder,
              {ConstantR2FromArray2D<float>(&builder, lhs),
               ConstantR2FromArray2D<float>(&builder, rhs)},
              1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
410
411 // Show that we can't concatenate with an opaques.
XLA_TEST_F(ConcatTest,CannotConcatOpaques)412 XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
413 XlaBuilder builder(TestName());
414 auto opaque_shape = ShapeUtil::MakeOpaqueShape();
415 auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
416 auto x = Parameter(&builder, 0, r1f32, "x");
417 auto y = Parameter(&builder, 1, opaque_shape, "y");
418 ConcatInDim(&builder, {x, y}, 0);
419 StatusOr<XlaComputation> computation_status = builder.Build();
420 ASSERT_FALSE(computation_status.ok());
421 EXPECT_THAT(
422 computation_status.status().ToString(),
423 HasSubstr("Expected array argument for operand of concatenation"));
424 }
425
426 // Show that we can't concatenate with tokens.
XLA_TEST_F(ConcatTest,CannotConcatTokens)427 XLA_TEST_F(ConcatTest, CannotConcatTokens) {
428 XlaBuilder builder(TestName());
429 auto token_shape = ShapeUtil::MakeTokenShape();
430 auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
431 auto x = Parameter(&builder, 0, r1f32, "x");
432 auto y = Parameter(&builder, 1, token_shape, "y");
433 ConcatInDim(&builder, {x, y}, 0);
434 StatusOr<XlaComputation> computation_status = builder.Build();
435 ASSERT_FALSE(computation_status.ok());
436 EXPECT_THAT(
437 computation_status.status().ToString(),
438 HasSubstr("Expected array argument for operand of concatenation"));
439 }
440
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
  // Each predicate is "boxed" in its own one-element PRED vector; the
  // concatenation unboxes them back into a single vector, in order.
  XlaBuilder builder(TestName());
  std::vector<XlaOp> boxed;
  for (const bool value : {true, false, true}) {
    boxed.push_back(ConstantR1<bool>(&builder, {value}));
  }
  ConcatInDim(&builder, boxed, 0);

  const bool expected[] = {true, false, true};
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
451
// Concatenates four S32 vectors of lengths 1, 2, 3 and 4 whose contents are
// consecutive, so the result is 1, 2, ..., 10.
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
  XlaBuilder builder(TestName());
  auto a0 = ConstantR1<int32_t>(&builder, {1});
  auto a1 = ConstantR1<int32_t>(&builder, {2, 3});
  auto a2 = ConstantR1<int32_t>(&builder, {4, 5, 6});
  auto a3 = ConstantR1<int32_t>(&builder, {7, 8, 9, 10});
  ConcatInDim(&builder, {a0, a1, a2, a3}, 0);

  // std::iota is declared in <numeric>; include it directly rather than
  // relying on a transitive include from another header.
  std::vector<int32_t> expected(10);
  std::iota(expected.begin(), expected.end(), 1);
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
464
// Concatenates a 9x17x1 array of 1s with a 9x17x256 array of 2s along
// dimension 2. Both operands are passed as parameters, not constants.
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
  XlaBuilder builder(TestName());

  Array3D<float> arr0(9, 17, 1);
  arr0.Fill(1);

  Array3D<float> arr1(9, 17, 256);
  arr1.Fill(2);

  // Build the expected result by copying each operand's dimension-2 run, in
  // operand order. Iterate over pointers: a braced list of Array3D values
  // (`{arr0, arr1}`) copies both arrays into a temporary initializer_list on
  // every (i, j) iteration.
  const Array3D<float>* const operands[] = {&arr0, &arr1};
  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
  for (int64_t i = 0; i < expected.n1(); ++i) {
    for (int64_t j = 0; j < expected.n2(); ++j) {
      int64_t kk = 0;
      for (const Array3D<float>* arr : operands) {
        for (int64_t k = 0; k < arr->n3(); ++k, ++kk) {
          expected(i, j, kk) = (*arr)(i, j, k);
        }
      }
    }
  }

  XlaOp h0;
  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
                                     &builder, &h0);
  XlaOp h1;
  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
                                     &builder, &h1);

  ConcatInDim(&builder, {h0, h1}, 2);

  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
497
XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
  // Start from a one-element parameter and repeatedly concatenate the running
  // result with itself. 17 doublings give 2^17 = 131072 elements, all equal
  // to the parameter value.
  constexpr int kDoublings = 17;
  XlaBuilder builder(TestName());
  auto param_literal = LiteralUtil::CreateR1<float>({256.0});
  XlaOp running = Parameter(&builder, 0, param_literal.shape(), "x");
  for (int level = 0; level < kDoublings; ++level) {
    running = ConcatInDim(&builder, {running, running}, 0);
  }

  const std::vector<float> expected(std::size_t{1} << kDoublings, 256.0f);
  auto param_data = client_->TransferToServer(param_literal).value();
  ComputeAndCompareR1<float>(&builder, expected, {param_data.get()});
}
523
// Runs a hand-written HLO module in which one concatenate (of two reshapes of
// the same parameter) feeds many slices and reshapes, followed by a long
// elementwise chain and two reductions. The module is executed with
// RunAndCompare on a 4x2 input filled by FillUnique(1.0f). The test name
// suggests it targets reshapes lowered to bitcasts around the concatenate;
// NOTE(review): confirm against the originating bug.
XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
  constexpr char kHloText[] = R"(
HloModule jit_broken.874

primitive_computation_add.866 {
  parameter.867 = f32[] parameter(0)
  parameter.868 = f32[] parameter(1)
  ROOT add.869 = f32[] add(parameter.867, parameter.868)
}

ENTRY jit_broken.874 {
  parameter.38 = f32[4,2]{1,0} parameter(0)
  reshape.723 = f32[4,2,1]{2,1,0} reshape(parameter.38)
  reshape.724 = f32[4,2,1]{2,1,0} reshape(parameter.38)
  concatenate.42 = f32[4,2,2]{2,1,0} concatenate(reshape.723, reshape.724), dimensions={2}
  slice.351 = f32[4,1,2]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:2]}
  reshape.1058 = f32[4,2]{1,0} reshape(slice.351)
  slice.352 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [1:2]}
  reshape.1059 = f32[4]{0} reshape(slice.352)
  slice.353 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1060 = f32[4]{0} reshape(slice.353)
  add.124 = f32[4]{0} add(reshape.1059, reshape.1060)
  slice.354 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [0:1]}
  reshape.1061 = f32[4]{0} reshape(slice.354)
  slice.379 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1062 = f32[4]{0} reshape(slice.379)
  add.89 = f32[4]{0} add(reshape.1061, reshape.1062)
  subtract.126 = f32[4]{0} subtract(add.124, add.89)
  is-finite.127 = pred[4]{0} is-finite(subtract.126)
  not.128 = pred[4]{0} not(is-finite.127)
  abs.129 = f32[4]{0} abs(subtract.126)
  constant.130 = f32[] constant(inf)
  broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
  compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
  not.133 = pred[4]{0} not(compare.132)
  and.134 = pred[4]{0} and(not.128, not.133)
  add.135 = f32[4]{0} add(add.124, add.89)
  maximum.125 = f32[4]{0} maximum(add.124, add.89)
  abs.136 = f32[4]{0} abs(subtract.126)
  negate.137 = f32[4]{0} negate(abs.136)
  exponential.138 = f32[4]{0} exponential(negate.137)
  log-plus-one.139 = f32[4]{0} log-plus-one(exponential.138)
  add.140 = f32[4]{0} add(maximum.125, log-plus-one.139)
  select.141 = f32[4]{0} select(and.134, add.135, add.140)
  slice.356 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1064 = f32[4]{0} reshape(slice.356)
  add.214 = f32[4]{0} add(select.141, reshape.1064)
  slice.380 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1066 = f32[4]{0} reshape(slice.380)
  add.179 = f32[4]{0} add(select.141, reshape.1066)
  subtract.216 = f32[4]{0} subtract(add.214, add.179)
  is-finite.217 = pred[4]{0} is-finite(subtract.216)
  not.218 = pred[4]{0} not(is-finite.217)
  abs.219 = f32[4]{0} abs(subtract.216)
  constant.220 = f32[] constant(inf)
  broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
  compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
  not.223 = pred[4]{0} not(compare.222)
  and.224 = pred[4]{0} and(not.218, not.223)
  add.225 = f32[4]{0} add(add.214, add.179)
  maximum.215 = f32[4]{0} maximum(add.214, add.179)
  abs.226 = f32[4]{0} abs(subtract.216)
  negate.227 = f32[4]{0} negate(abs.226)
  exponential.228 = f32[4]{0} exponential(negate.227)
  log-plus-one.229 = f32[4]{0} log-plus-one(exponential.228)
  add.230 = f32[4]{0} add(maximum.215, log-plus-one.229)
  select.231 = f32[4]{0} select(and.224, add.225, add.230)
  slice.359 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1068 = f32[4]{0} reshape(slice.359)
  add.304 = f32[4]{0} add(select.231, reshape.1068)
  slice.381 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1070 = f32[4]{0} reshape(slice.381)
  add.269 = f32[4]{0} add(select.231, reshape.1070)
  subtract.306 = f32[4]{0} subtract(add.304, add.269)
  is-finite.307 = pred[4]{0} is-finite(subtract.306)
  not.308 = pred[4]{0} not(is-finite.307)
  abs.309 = f32[4]{0} abs(subtract.306)
  constant.310 = f32[] constant(inf)
  broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
  compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
  not.313 = pred[4]{0} not(compare.312)
  and.314 = pred[4]{0} and(not.308, not.313)
  add.315 = f32[4]{0} add(add.304, add.269)
  maximum.305 = f32[4]{0} maximum(add.304, add.269)
  abs.316 = f32[4]{0} abs(subtract.306)
  negate.317 = f32[4]{0} negate(abs.316)
  exponential.318 = f32[4]{0} exponential(negate.317)
  log-plus-one.319 = f32[4]{0} log-plus-one(exponential.318)
  add.320 = f32[4]{0} add(maximum.305, log-plus-one.319)
  select.321 = f32[4]{0} select(and.314, add.315, add.320)
  slice.362 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1072 = f32[4]{0} reshape(slice.362)
  add.394 = f32[4]{0} add(select.321, reshape.1072)
  slice.382 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1074 = f32[4]{0} reshape(slice.382)
  add.359 = f32[4]{0} add(select.321, reshape.1074)
  subtract.396 = f32[4]{0} subtract(add.394, add.359)
  is-finite.397 = pred[4]{0} is-finite(subtract.396)
  not.398 = pred[4]{0} not(is-finite.397)
  abs.399 = f32[4]{0} abs(subtract.396)
  constant.400 = f32[] constant(inf)
  broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
  compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
  not.403 = pred[4]{0} not(compare.402)
  and.404 = pred[4]{0} and(not.398, not.403)
  add.405 = f32[4]{0} add(add.394, add.359)
  maximum.395 = f32[4]{0} maximum(add.394, add.359)
  abs.406 = f32[4]{0} abs(subtract.396)
  negate.407 = f32[4]{0} negate(abs.406)
  exponential.408 = f32[4]{0} exponential(negate.407)
  log-plus-one.409 = f32[4]{0} log-plus-one(exponential.408)
  add.410 = f32[4]{0} add(maximum.395, log-plus-one.409)
  select.411 = f32[4]{0} select(and.404, add.405, add.410)
  slice.365 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1076 = f32[4]{0} reshape(slice.365)
  add.484 = f32[4]{0} add(select.411, reshape.1076)
  slice.383 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1078 = f32[4]{0} reshape(slice.383)
  add.449 = f32[4]{0} add(select.411, reshape.1078)
  subtract.486 = f32[4]{0} subtract(add.484, add.449)
  is-finite.487 = pred[4]{0} is-finite(subtract.486)
  not.488 = pred[4]{0} not(is-finite.487)
  abs.489 = f32[4]{0} abs(subtract.486)
  constant.490 = f32[] constant(inf)
  broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
  compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
  not.493 = pred[4]{0} not(compare.492)
  and.494 = pred[4]{0} and(not.488, not.493)
  add.495 = f32[4]{0} add(add.484, add.449)
  maximum.485 = f32[4]{0} maximum(add.484, add.449)
  abs.496 = f32[4]{0} abs(subtract.486)
  negate.497 = f32[4]{0} negate(abs.496)
  exponential.498 = f32[4]{0} exponential(negate.497)
  log-plus-one.499 = f32[4]{0} log-plus-one(exponential.498)
  add.500 = f32[4]{0} add(maximum.485, log-plus-one.499)
  select.501 = f32[4]{0} select(and.494, add.495, add.500)
  slice.368 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1080 = f32[4]{0} reshape(slice.368)
  add.574 = f32[4]{0} add(select.501, reshape.1080)
  slice.384 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1082 = f32[4]{0} reshape(slice.384)
  add.539 = f32[4]{0} add(select.501, reshape.1082)
  subtract.576 = f32[4]{0} subtract(add.574, add.539)
  is-finite.577 = pred[4]{0} is-finite(subtract.576)
  not.578 = pred[4]{0} not(is-finite.577)
  abs.579 = f32[4]{0} abs(subtract.576)
  constant.580 = f32[] constant(inf)
  broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
  compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
  not.583 = pred[4]{0} not(compare.582)
  and.584 = pred[4]{0} and(not.578, not.583)
  add.585 = f32[4]{0} add(add.574, add.539)
  maximum.575 = f32[4]{0} maximum(add.574, add.539)
  abs.586 = f32[4]{0} abs(subtract.576)
  negate.587 = f32[4]{0} negate(abs.586)
  exponential.588 = f32[4]{0} exponential(negate.587)
  log-plus-one.589 = f32[4]{0} log-plus-one(exponential.588)
  add.590 = f32[4]{0} add(maximum.575, log-plus-one.589)
  select.591 = f32[4]{0} select(and.584, add.585, add.590)
  slice.371 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1084 = f32[4]{0} reshape(slice.371)
  add.664 = f32[4]{0} add(select.591, reshape.1084)
  slice.385 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1086 = f32[4]{0} reshape(slice.385)
  add.629 = f32[4]{0} add(select.591, reshape.1086)
  subtract.666 = f32[4]{0} subtract(add.664, add.629)
  is-finite.667 = pred[4]{0} is-finite(subtract.666)
  not.668 = pred[4]{0} not(is-finite.667)
  abs.669 = f32[4]{0} abs(subtract.666)
  constant.670 = f32[] constant(inf)
  broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
  compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
  not.673 = pred[4]{0} not(compare.672)
  and.674 = pred[4]{0} and(not.668, not.673)
  add.675 = f32[4]{0} add(add.664, add.629)
  maximum.665 = f32[4]{0} maximum(add.664, add.629)
  abs.676 = f32[4]{0} abs(subtract.666)
  negate.677 = f32[4]{0} negate(abs.676)
  exponential.678 = f32[4]{0} exponential(negate.677)
  log-plus-one.679 = f32[4]{0} log-plus-one(exponential.678)
  add.680 = f32[4]{0} add(maximum.665, log-plus-one.679)
  select.681 = f32[4]{0} select(and.674, add.675, add.680)
  slice.374 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1088 = f32[4]{0} reshape(slice.374)
  add.754 = f32[4]{0} add(select.681, reshape.1088)
  slice.386 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1090 = f32[4]{0} reshape(slice.386)
  add.719 = f32[4]{0} add(select.681, reshape.1090)
  subtract.756 = f32[4]{0} subtract(add.754, add.719)
  is-finite.757 = pred[4]{0} is-finite(subtract.756)
  not.758 = pred[4]{0} not(is-finite.757)
  abs.759 = f32[4]{0} abs(subtract.756)
  constant.760 = f32[] constant(inf)
  broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
  compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
  not.763 = pred[4]{0} not(compare.762)
  and.764 = pred[4]{0} and(not.758, not.763)
  add.765 = f32[4]{0} add(add.754, add.719)
  maximum.755 = f32[4]{0} maximum(add.754, add.719)
  abs.766 = f32[4]{0} abs(subtract.756)
  negate.767 = f32[4]{0} negate(abs.766)
  exponential.768 = f32[4]{0} exponential(negate.767)
  log-plus-one.769 = f32[4]{0} log-plus-one(exponential.768)
  add.770 = f32[4]{0} add(maximum.755, log-plus-one.769)
  select.771 = f32[4]{0} select(and.764, add.765, add.770)
  slice.377 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
  reshape.1092 = f32[4]{0} reshape(slice.377)
  add.844 = f32[4]{0} add(select.771, reshape.1092)
  slice.387 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
  reshape.1094 = f32[4]{0} reshape(slice.387)
  add.809 = f32[4]{0} add(select.771, reshape.1094)
  subtract.846 = f32[4]{0} subtract(add.844, add.809)
  is-finite.847 = pred[4]{0} is-finite(subtract.846)
  not.848 = pred[4]{0} not(is-finite.847)
  abs.849 = f32[4]{0} abs(subtract.846)
  constant.850 = f32[] constant(inf)
  broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
  compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
  not.853 = pred[4]{0} not(compare.852)
  and.854 = pred[4]{0} and(not.848, not.853)
  add.855 = f32[4]{0} add(add.844, add.809)
  maximum.845 = f32[4]{0} maximum(add.844, add.809)
  abs.856 = f32[4]{0} abs(subtract.846)
  negate.857 = f32[4]{0} negate(abs.856)
  exponential.858 = f32[4]{0} exponential(negate.857)
  log-plus-one.859 = f32[4]{0} log-plus-one(exponential.858)
  add.860 = f32[4]{0} add(maximum.845, log-plus-one.859)
  select.861 = f32[4]{0} select(and.854, add.855, add.860)
  constant.865 = f32[] constant(0)
  reduce.2 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
  reduce.3 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
  add.77 = f32[] add(reduce.2, reduce.3)
  constant.719 = f32[] constant(0.125)
  multiply = f32[] multiply(add.77, constant.719)
  ROOT tuple.873 = (f32[]) tuple(multiply)
})";

  auto module = ParseAndReturnVerifiedModule(kHloText).value();

  Array2D<float> input_values(4, 2);
  input_values.FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(input_values);

  EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
}
766
// Describes a binary rank-2 concatenation test case: the shapes of the two
// operands and the dimension along which they are concatenated.
//
// Members are zero-initialized so a default-constructed spec never carries
// indeterminate values; aggregate initialization such as
// R2BinarySpec{4, 3, 4, 3, 0} keeps working (C++14 and later).
struct R2BinarySpec {
  int64_t lhs_dim0 = 0;          // Size of dimension 0 of the lhs operand.
  int64_t lhs_dim1 = 0;          // Size of dimension 1 of the lhs operand.
  int64_t rhs_dim0 = 0;          // Size of dimension 0 of the rhs operand.
  int64_t rhs_dim1 = 0;          // Size of dimension 1 of the rhs operand.
  int64_t concat_dimension = 0;  // Dimension (0 or 1) to concatenate along.
};
775
776 // TEST_P harness for binary rank-2 concatenation.
777 class ConcatR2BinaryTest : public ClientLibraryTestBase,
778 public ::testing::WithParamInterface<R2BinarySpec> {
779 };
780
// Concatenates two constant rank-2 int32 operands shaped by the test
// parameter and compares the result with ReferenceUtil::Concat2D.
TEST_P(ConcatR2BinaryTest, DoIt) {
  const auto& param = GetParam();

  // lhs uses FillUnique's default start; rhs starts at 1000, presumably so
  // the two operands' values are distinguishable in the output.
  Array2D<int32_t> lhs_values(param.lhs_dim0, param.lhs_dim1);
  lhs_values.FillUnique();
  Array2D<int32_t> rhs_values(param.rhs_dim0, param.rhs_dim1);
  rhs_values.FillUnique(1000);

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR2FromArray2D<int32_t>(&builder, lhs_values);
  auto rhs_op = ConstantR2FromArray2D<int32_t>(&builder, rhs_values);
  ConcatInDim(&builder, {lhs_op, rhs_op}, param.concat_dimension);

  const auto reference =
      ReferenceUtil::Concat2D(lhs_values, rhs_values, param.concat_dimension);
  ComputeAndCompareR2<int32_t>(&builder, *reference, {});
}
797
798 // Regression test for b/31944287. x*y is used (at the same index) by all
799 // operands of the concat. We should emit x*y in three incoming basic blocks of
800 // the concat because these basic blocks are not control-equivalent.
801 //
802 // x*y
803 // / | \
804 // add1 add2 add3
805 // \ | /
806 // concat
XLA_TEST_F(ConcatTest,ConcatOperandsOfSameOperand)807 XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
808 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
809 auto x_literal = LiteralUtil::CreateR0<float>(2.f);
810 auto y_literal = LiteralUtil::CreateR0<float>(3.f);
811 auto x_data = client_->TransferToServer(x_literal).value();
812 auto y_data = client_->TransferToServer(y_literal).value();
813
814 XlaBuilder builder(TestName());
815 auto x = Parameter(&builder, 0, f32_scalar, "x");
816 auto y = Parameter(&builder, 1, f32_scalar, "y");
817 auto mul = Mul(x, y);
818 auto add1 = Add(mul, ConstantR1<float>(&builder, {1.f, 2.f}));
819 auto add2 = Add(mul, ConstantR1<float>(&builder, {3.f, 4.f}));
820 auto add3 = Add(mul, ConstantR1<float>(&builder, {5.f, 6.f}));
821 ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
822
823 ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
824 {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
825 }
826
827 // Test that the HLO optimization to replace a concat of a broadcasted scalar
828 // produces the correct result in rank 1.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgument)829 XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
830 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
831 auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
832 auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
833 auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
834 auto x_data = client_->TransferToServer(x_literal).value();
835 auto y_data = client_->TransferToServer(y_literal).value();
836 auto z_data = client_->TransferToServer(z_literal).value();
837
838 XlaBuilder builder(TestName());
839 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
840 auto y = Parameter(&builder, 1, f32_scalar, "y");
841 auto z = Parameter(&builder, 2, f32_scalar, "z");
842 auto bcast = Broadcast(y, {5});
843 auto bcast2 = Broadcast(z, {3});
844 auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0);
845 ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0);
846
847 ComputeAndCompareR1<float>(
848 &builder,
849 {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
850 {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
851 }
852
853 // Test that the HLO optimization to replace a concat of a broadcasted scalar
854 // produces the correct result in rank 3 with both high and low padding in
855 // different dimensions.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgumentR3)856 XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
857 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
858 Array3D<float> x3d(3, 5, 7, 3.14f);
859 auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
860 auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
861 auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
862 auto x_data = client_->TransferToServer(x_literal).value();
863 auto y_data = client_->TransferToServer(y_literal).value();
864 auto z_data = client_->TransferToServer(z_literal).value();
865
866 XlaBuilder builder(TestName());
867 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
868 auto y = Parameter(&builder, 1, f32_scalar, "y");
869 auto z = Parameter(&builder, 2, f32_scalar, "y");
870 auto y_bcast = Broadcast(y, {1, 5, 7});
871 auto z_bcast = Broadcast(z, {4, 1, 7});
872 auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
873 ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);
874 Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
875 Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
876 auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
877 auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
878
879 ComputeAndCompareR3<float>(&builder, *concat1,
880 {x_data.get(), y_data.get(), z_data.get()},
881 ErrorSpec(1e-4));
882 }
883
884 INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
885 ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
886 R2BinarySpec{1, 1, 1, 1, 1},
887 R2BinarySpec{4, 3, 4, 3, 0},
888 R2BinarySpec{4, 3, 4, 3, 1},
889 R2BinarySpec{7, 128, 1, 128, 0},
890 R2BinarySpec{8, 127, 8, 1, 1}));
891
892 } // namespace
893 } // namespace xla
894