1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/array3d.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/reference_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/test_helpers.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/core/platform/test.h"
34 
35 namespace xla {
36 namespace {
37 
// Fixture alias for the tests in this file that build computations with
// XlaBuilder and check results through the ComputeAndCompare* helpers.
using ConcatTest = ClientLibraryTestBase;
// Fixture alias for the test in this file that parses an HLO module from text.
using ConcatTestHlo = HloTestBase;
// Matcher used below to check for substrings in error messages.
using ::testing::HasSubstr;
41 
42 // Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest,Concat_Nothing)43 XLA_TEST_F(ConcatTest, Concat_Nothing) {
44   XlaBuilder builder(TestName());
45   ConcatInDim(&builder, {}, 0);
46   StatusOr<XlaComputation> computation_status = builder.Build();
47   ASSERT_FALSE(computation_status.ok());
48   EXPECT_THAT(computation_status.status().ToString(),
49               HasSubstr("Concatenate expects at least one argument"));
50 }
51 
52 // Concatenate with one argument works.
XLA_TEST_F(ConcatTest,Concat_R1_With_Nothing)53 XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
54   XlaBuilder builder(TestName());
55   auto a = ConstantR1<float>(&builder, {42.0, 64.0});
56   ConcatInDim(&builder, {a}, 0);
57 
58   std::vector<float> expected = {42, 64};
59   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
60 }
61 
// Concatenating a single zero-length vector is expected to produce a
// zero-length vector.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
  XlaBuilder xb(TestName());
  const auto empty_operand = ConstantR1<float>(&xb, {});
  ConcatInDim(&xb, {empty_operand}, 0);

  const std::vector<float> want;
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
70 
71 // Show that we can't concatenate R0 with R0 because we can't name the dimension
72 // to concatenate on.
XLA_TEST_F(ConcatTest,CannotConcatR0WithR0)73 XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
74   XlaBuilder builder(TestName());
75   auto a = ConstantR0<float>(&builder, 42.0);
76   auto b = ConstantR0<float>(&builder, 64.0);
77   ConcatInDim(&builder, {a, b}, 0);
78   StatusOr<XlaComputation> computation_status = builder.Build();
79   ASSERT_FALSE(computation_status.ok());
80   EXPECT_THAT(computation_status.status().ToString(),
81               HasSubstr("out of bounds: 0"));
82 }
83 
// Two zero-length vectors concatenate to a zero-length vector.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
  XlaBuilder xb(TestName());
  const auto empty_lhs = ConstantR1<float>(&xb, {});
  const auto empty_rhs = ConstantR1<float>(&xb, {});
  ConcatInDim(&xb, {empty_lhs, empty_rhs}, 0);

  const std::vector<float> want;
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
93 
// A zero-length vector followed by a one-element vector yields just that
// element.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
  XlaBuilder xb(TestName());
  const auto empty_lhs = ConstantR1<float>(&xb, {});
  const auto single_rhs = ConstantR1<float>(&xb, {256.0});
  ConcatInDim(&xb, {empty_lhs, single_rhs}, 0);

  const std::vector<float> want{256.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
103 
// A two-element vector followed by a zero-length vector yields the two
// elements unchanged.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
  XlaBuilder xb(TestName());
  const auto pair_lhs = ConstantR1<float>(&xb, {42.0, 64.0});
  const auto empty_rhs = ConstantR1<float>(&xb, {});
  ConcatInDim(&xb, {pair_lhs, empty_rhs}, 0);

  const std::vector<float> want{42.0f, 64.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
113 
// A two-element vector followed by a one-element vector yields all three
// elements in operand order.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
  XlaBuilder xb(TestName());
  const auto pair_lhs = ConstantR1<float>(&xb, {42.0, 64.0});
  const auto single_rhs = ConstantR1<float>(&xb, {256.0});
  ConcatInDim(&xb, {pair_lhs, single_rhs}, 0);

  const std::vector<float> want{42.0f, 64.0f, 256.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
123 
// Concatenates a 253-element vector with a 7-element vector. The expected
// result holds the values 1..260; the operands are its first 253 and last 7
// elements respectively.
XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
  constexpr int kLhsSize = 253;
  constexpr int kRhsSize = 7;

  std::vector<float> expected;
  expected.reserve(kLhsSize + kRhsSize);
  for (int value = 1; value <= kLhsSize + kRhsSize; ++value) {
    expected.push_back(static_cast<float>(value));
  }
  const std::vector<float> lhs(expected.begin(), expected.begin() + kLhsSize);
  const std::vector<float> rhs(expected.begin() + kLhsSize, expected.end());

  XlaBuilder xb(TestName());
  const auto lhs_op = ConstantR1<float>(&xb, lhs);
  const auto rhs_op = ConstantR1<float>(&xb, rhs);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 0);

  ComputeAndCompareR1<float>(&xb, expected, {}, ErrorSpec(0.0001));
}
142 
// Concatenating two 0x0 matrices along either dimension is expected to give a
// 0x0 matrix. A fresh builder is used for each dimension.
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
  const Array2D<float> empty_matrix(0, 0);
  for (int dim = 0; dim < 2; ++dim) {
    XlaBuilder xb(TestName());
    const auto lhs_op = ConstantR2FromArray2D(&xb, empty_matrix);
    const auto rhs_op = ConstantR2FromArray2D(&xb, empty_matrix);
    ConcatInDim(&xb, {lhs_op, rhs_op}, dim);

    ComputeAndCompareR2<float>(&xb, empty_matrix, {}, ErrorSpec(0.0001));
  }
}
154 
// Stacks two 1x1 patterned matrices along dimension 0, giving a 2x1 result.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
  XlaBuilder xb(TestName());
  const auto lhs_matrix = CreatePatternedMatrix(1, 1);
  const auto rhs_matrix = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  const auto lhs_op = ConstantR2FromArray2D(&xb, *lhs_matrix);
  const auto rhs_op = ConstantR2FromArray2D(&xb, *rhs_matrix);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 0);

  const Array2D<float> want({
      {0},
      {64},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, want, {}, tolerance);
}
169 
// Joins two 1x1 patterned matrices along dimension 1, giving a 1x2 result.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
  XlaBuilder xb(TestName());
  const auto lhs_matrix = CreatePatternedMatrix(1, 1);
  const auto rhs_matrix = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  const auto lhs_op = ConstantR2FromArray2D(&xb, *lhs_matrix);
  const auto rhs_op = ConstantR2FromArray2D(&xb, *rhs_matrix);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 1);

  const Array2D<float> want({
      {0, 64},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, want, {}, tolerance);
}
183 
// A 2x0 matrix joined with a 2x5 matrix along dimension 1 is expected to equal
// the 2x5 matrix.
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
  XlaBuilder xb(TestName());
  const auto rhs_matrix = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  const auto empty_lhs_op = ConstantR2FromArray2D(&xb, Array2D<float>(2, 0));
  const auto rhs_op = ConstantR2FromArray2D(&xb, *rhs_matrix);
  ConcatInDim(&xb, {empty_lhs_op, rhs_op}, 1);

  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, *rhs_matrix, {}, tolerance);
}
193 
// Joins a 2x3 and a 2x5 patterned matrix along dimension 1, giving a 2x8
// result whose rows are the operands' rows laid side by side.
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
  XlaBuilder xb(TestName());
  const auto lhs_matrix = CreatePatternedMatrix(2, 3);
  const auto rhs_matrix = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  const auto lhs_op = ConstantR2FromArray2D(&xb, *lhs_matrix);
  const auto rhs_op = ConstantR2FromArray2D(&xb, *rhs_matrix);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 1);

  const Array2D<float> want({
      {0, 1, 2, 64, 65, 66, 67, 68},
      {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, want, {}, tolerance);
}
208 
// A 3x2 matrix stacked with a 0x2 matrix along dimension 0 is expected to
// equal the 3x2 matrix.
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
  XlaBuilder xb(TestName());
  const auto lhs_matrix = CreatePatternedMatrix(3, 2);
  const auto lhs_op = ConstantR2FromArray2D(&xb, *lhs_matrix);
  const auto empty_rhs_op = ConstantR2FromArray2D(&xb, Array2D<float>(0, 2));
  ConcatInDim(&xb, {lhs_op, empty_rhs_op}, 0);

  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, *lhs_matrix, {}, tolerance);
}
218 
// Stacks a 3x2 and a 5x2 patterned matrix along dimension 0, giving an 8x2
// result: the three rows of the first operand followed by the five rows of the
// second.
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
  XlaBuilder xb(TestName());
  const auto lhs_matrix = CreatePatternedMatrix(3, 2);
  const auto rhs_matrix = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
  const auto lhs_op = ConstantR2FromArray2D(&xb, *lhs_matrix);
  const auto rhs_op = ConstantR2FromArray2D(&xb, *rhs_matrix);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 0);

  const Array2D<float> want({
      // Rows contributed by the 3x2 operand.
      {0, 1},
      {1000, 1001},
      {2000, 2001},
      // Rows contributed by the 5x2 operand.
      {64, 65},
      {1064, 1065},
      {2064, 2065},
      {3064, 3065},
      {4064, 4065},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&xb, want, {}, tolerance);
}
239 
// Joins a 3x0x2 and a 3x0x1 array along dimension 2; the expected result is an
// (element-free) 3x0x3 array.
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
  XlaBuilder xb(TestName());
  const auto lhs_op = ConstantR3FromArray3D(&xb, Array3D<float>(3, 0, 2));
  const auto rhs_op = ConstantR3FromArray3D(&xb, Array3D<float>(3, 0, 1));
  ConcatInDim(&xb, {lhs_op, rhs_op}, 2);

  const Array3D<float> want(3, 0, 3);
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR3<float>(&xb, want, {}, tolerance);
}
248 
// Joins a 3x1x2 and a 3x1x1 array along the minor-most dimension (2), giving a
// 3x1x3 result.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
  XlaBuilder xb(TestName());

  // Shape 3x1x2.
  const Array3D<float> lhs_values({
      {{0, 1}},
      {{2, 3}},
      {{4, 5}},
  });
  // Shape 3x1x1.
  const Array3D<float> rhs_values({
      {{6}},
      {{7}},
      {{8}},
  });
  const auto lhs_op = ConstantR3FromArray3D(&xb, lhs_values);
  const auto rhs_op = ConstantR3FromArray3D(&xb, rhs_values);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 2);

  // Shape 3x1x3.
  const Array3D<float> want({
      {{0, 1, 6}},
      {{2, 3, 7}},
      {{4, 5, 8}},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR3<float>(&xb, want, {}, tolerance);
}
274 
// Concatenates three one-element vectors in a single ConcatInDim call.
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
  XlaBuilder xb(TestName());
  const auto first = ConstantR1<float>(&xb, {42.0});
  const auto second = ConstantR1<float>(&xb, {64.0});
  const auto third = ConstantR1<float>(&xb, {256.0});
  ConcatInDim(&xb, {first, second, third}, 0);

  const std::vector<float> want{42.0f, 64.0f, 256.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
285 
// Joins three rank-3 arrays (3x1x2, 3x1x1, 3x1x1) along dimension 2 in a single
// ConcatInDim call, giving a 3x1x4 result.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
  XlaBuilder xb(TestName());

  // Shape 3x1x2.
  const Array3D<float> first_values({
      {{0, 1}},
      {{4, 5}},
      {{8, 9}},
  });
  // Shape 3x1x1.
  const Array3D<float> second_values({
      {{2}},
      {{6}},
      {{10}},
  });
  // Shape 3x1x1.
  const Array3D<float> third_values({
      {{3}},
      {{7}},
      {{11}},
  });
  const auto first = ConstantR3FromArray3D(&xb, first_values);
  const auto second = ConstantR3FromArray3D(&xb, second_values);
  const auto third = ConstantR3FromArray3D(&xb, third_values);
  ConcatInDim(&xb, {first, second, third}, 2);

  // Shape 3x1x4.
  const Array3D<float> want({
      {{0, 1, 2, 3}},
      {{4, 5, 6, 7}},
      {{8, 9, 10, 11}},
  });
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR3<float>(&xb, want, {}, tolerance);
}
318 
// Feeds the result of one concatenation in as the left operand of another:
// (a concat b) concat c.
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
  XlaBuilder xb(TestName());
  const auto first = ConstantR1<float>(&xb, {42.0});
  const auto second = ConstantR1<float>(&xb, {64.0});
  const auto third = ConstantR1<float>(&xb, {256.0});

  const auto first_and_second = ConcatInDim(&xb, {first, second}, 0);
  ConcatInDim(&xb, {first_and_second, third}, 0);

  const std::vector<float> want{42.0f, 64.0f, 256.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
330 
// Feeds the result of one concatenation in as the right operand of another:
// a concat (b concat c).
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
  XlaBuilder xb(TestName());
  const auto first = ConstantR1<float>(&xb, {42.0});
  const auto second = ConstantR1<float>(&xb, {64.0});
  const auto third = ConstantR1<float>(&xb, {256.0});

  const auto second_and_third = ConcatInDim(&xb, {second, third}, 0);
  ConcatInDim(&xb, {first, second_and_third}, 0);

  const std::vector<float> want{42.0f, 64.0f, 256.0f};
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR1<float>(&xb, want, {}, tolerance);
}
342 
// Stacks two 1x1024 rows along dimension 0. The first row holds 0..1023 and the
// second 1024..2047, so the expected 2x1024 result has them as rows 0 and 1.
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
  constexpr int kWidth = 1024;

  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(2, kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
    expected(0, col) = col;
    expected(1, col) = col + kWidth;
  }

  XlaBuilder xb(TestName());
  const auto lhs_op = ConstantR2FromArray2D<float>(&xb, lhs);
  const auto rhs_op = ConstantR2FromArray2D<float>(&xb, rhs);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 0);

  ComputeAndCompareR2<float>(&xb, expected, {}, ErrorSpec(0.0001));
}
363 
// Joins two 1x1024 rows along dimension 1. The first row holds 0..1023 and the
// second 1024..2047, so the expected 1x2048 result holds 0..2047 in order.
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
  constexpr int kWidth = 1024;

  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
  }
  Array2D<float> expected(1, 2 * kWidth);
  for (int col = 0; col < 2 * kWidth; ++col) {
    expected(0, col) = col;
  }

  XlaBuilder xb(TestName());
  const auto lhs_op = ConstantR2FromArray2D<float>(&xb, lhs);
  const auto rhs_op = ConstantR2FromArray2D<float>(&xb, rhs);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 1);

  ComputeAndCompareR2<float>(&xb, expected, {}, ErrorSpec(0.0001));
}
384 
// Joins a 64x64 and a 64x2 matrix along dimension 1. Every element is
// (row << 10) | column, where "column" is the element's column in the 64x66
// result, so the operands are the left 64 and right 2 columns of the expected
// matrix.
XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
  constexpr int kRows = 64;
  constexpr int kLhsCols = 64;
  constexpr int kRhsCols = 2;

  Array2D<float> lhs(kRows, kLhsCols);
  Array2D<float> rhs(kRows, kRhsCols);
  Array2D<float> expected(kRows, kLhsCols + kRhsCols);
  for (int row = 0; row < kRows; ++row) {
    for (int col = 0; col < kLhsCols + kRhsCols; ++col) {
      const float value = (row << 10) | col;
      expected(row, col) = value;
      if (col < kLhsCols) {
        lhs(row, col) = value;
      } else {
        rhs(row, col - kLhsCols) = value;
      }
    }
  }

  XlaBuilder xb(TestName());
  const auto lhs_op = ConstantR2FromArray2D<float>(&xb, lhs);
  const auto rhs_op = ConstantR2FromArray2D<float>(&xb, rhs);
  ConcatInDim(&xb, {lhs_op, rhs_op}, 1);

  ComputeAndCompareR2<float>(&xb, expected, {}, ErrorSpec(0.0001));
}
410 
411 // Show that we can't concatenate with an opaques.
XLA_TEST_F(ConcatTest,CannotConcatOpaques)412 XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
413   XlaBuilder builder(TestName());
414   auto opaque_shape = ShapeUtil::MakeOpaqueShape();
415   auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
416   auto x = Parameter(&builder, 0, r1f32, "x");
417   auto y = Parameter(&builder, 1, opaque_shape, "y");
418   ConcatInDim(&builder, {x, y}, 0);
419   StatusOr<XlaComputation> computation_status = builder.Build();
420   ASSERT_FALSE(computation_status.ok());
421   EXPECT_THAT(
422       computation_status.status().ToString(),
423       HasSubstr("Expected array argument for operand of concatenation"));
424 }
425 
426 // Show that we can't concatenate with tokens.
XLA_TEST_F(ConcatTest,CannotConcatTokens)427 XLA_TEST_F(ConcatTest, CannotConcatTokens) {
428   XlaBuilder builder(TestName());
429   auto token_shape = ShapeUtil::MakeTokenShape();
430   auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
431   auto x = Parameter(&builder, 0, r1f32, "x");
432   auto y = Parameter(&builder, 1, token_shape, "y");
433   ConcatInDim(&builder, {x, y}, 0);
434   StatusOr<XlaComputation> computation_status = builder.Build();
435   ASSERT_FALSE(computation_status.ok());
436   EXPECT_THAT(
437       computation_status.status().ToString(),
438       HasSubstr("Expected array argument for operand of concatenation"));
439 }
440 
// Concatenates three one-element boolean vectors and compares the result
// without an error tolerance.
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
  XlaBuilder xb(TestName());
  const auto first = ConstantR1<bool>(&xb, {true});
  const auto second = ConstantR1<bool>(&xb, {false});
  const auto third = ConstantR1<bool>(&xb, {true});
  ConcatInDim(&xb, {first, second, third}, 0);

  const bool want[] = {true, false, true};
  ComputeAndCompareR1<bool>(&xb, want, {});
}
451 
// Concatenates four int32 vectors of lengths 1, 2, 3 and 4 whose elements are
// the consecutive integers 1..10, and compares the result without an error
// tolerance.
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
  XlaBuilder builder(TestName());
  auto a0 = ConstantR1<int32_t>(&builder, {1});
  auto a1 = ConstantR1<int32_t>(&builder, {2, 3});
  auto a2 = ConstantR1<int32_t>(&builder, {4, 5, 6});
  auto a3 = ConstantR1<int32_t>(&builder, {7, 8, 9, 10});
  ConcatInDim(&builder, {a0, a1, a2, a3}, 0);

  // Spelled out instead of generated with std::iota: this file does not include
  // <numeric>, and the literal list mirrors the operands above.
  const std::vector<int32_t> expected = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
464 
// Concatenates a 9x17x1 array (filled with 1) and a 9x17x256 array (filled
// with 2) along dimension 2. Unlike most tests in this file, the operands are
// passed as parameters rather than embedded as constants.
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
  XlaBuilder builder(TestName());

  Array3D<float> arr0(9, 17, 1);
  arr0.Fill(1);

  Array3D<float> arr1(9, 17, 256);
  arr1.Fill(2);

  // For each (i, j), lay the operands' minor-most slices end to end.
  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
  for (int64_t i = 0; i < expected.n1(); ++i) {
    for (int64_t j = 0; j < expected.n2(); ++j) {
      int64_t kk = 0;
      // Iterate over pointers: a braced list of the arrays themselves would
      // copy both arrays on every (i, j) iteration.
      for (const Array3D<float>* arr : {&arr0, &arr1}) {
        for (int64_t k = 0; k < arr->n3(); ++k, ++kk) {
          expected(i, j, kk) = (*arr)(i, j, k);
        }
      }
    }
  }

  XlaOp h0;
  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
                                     &builder, &h0);
  XlaOp h1;
  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
                                     &builder, &h1);

  ConcatInDim(&builder, {h0, h1}, 2);

  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
497 
// Builds a chain of 17 concatenations in which every step concatenates the
// previous result with itself. Starting from a one-element parameter, the
// final result therefore has 2^17 = 131072 elements, all equal to the
// parameter's single value.
XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
  constexpr int kDoublings = 17;

  XlaBuilder xb(TestName());
  const auto seed_literal = LiteralUtil::CreateR1<float>({256.0});

  XlaOp doubled = Parameter(&xb, 0, seed_literal.shape(), "x");
  for (int level = 0; level < kDoublings; ++level) {
    doubled = ConcatInDim(&xb, {doubled, doubled}, 0);
  }

  const std::vector<float> want(131072, 256.0);
  const auto seed_data = client_->TransferToServer(seed_literal).value();
  ComputeAndCompareR1<float>(&xb, want, {seed_data.get()});
}
523 
XLA_TEST_F(ConcatTestHlo,ConcatWithBitcast)524 XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
525   auto module = ParseAndReturnVerifiedModule(R"(
526 HloModule jit_broken.874
527 
528 primitive_computation_add.866 {
529   parameter.867 = f32[] parameter(0)
530   parameter.868 = f32[] parameter(1)
531   ROOT add.869 = f32[] add(parameter.867, parameter.868)
532 }
533 
534 ENTRY jit_broken.874 {
535   parameter.38 = f32[4,2]{1,0} parameter(0)
536   reshape.723 = f32[4,2,1]{2,1,0} reshape(parameter.38)
537   reshape.724 = f32[4,2,1]{2,1,0} reshape(parameter.38)
538   concatenate.42 = f32[4,2,2]{2,1,0} concatenate(reshape.723, reshape.724), dimensions={2}
539   slice.351 = f32[4,1,2]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:2]}
540   reshape.1058 = f32[4,2]{1,0} reshape(slice.351)
541   slice.352 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [1:2]}
542   reshape.1059 = f32[4]{0} reshape(slice.352)
543   slice.353 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
544   reshape.1060 = f32[4]{0} reshape(slice.353)
545   add.124 = f32[4]{0} add(reshape.1059, reshape.1060)
546   slice.354 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [0:1]}
547   reshape.1061 = f32[4]{0} reshape(slice.354)
548   slice.379 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
549   reshape.1062 = f32[4]{0} reshape(slice.379)
550   add.89 = f32[4]{0} add(reshape.1061, reshape.1062)
551   subtract.126 = f32[4]{0} subtract(add.124, add.89)
552   is-finite.127 = pred[4]{0} is-finite(subtract.126)
553   not.128 = pred[4]{0} not(is-finite.127)
554   abs.129 = f32[4]{0} abs(subtract.126)
555   constant.130 = f32[] constant(inf)
556   broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
557   compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ
558   not.133 = pred[4]{0} not(compare.132)
559   and.134 = pred[4]{0} and(not.128, not.133)
560   add.135 = f32[4]{0} add(add.124, add.89)
561   maximum.125 = f32[4]{0} maximum(add.124, add.89)
562   abs.136 = f32[4]{0} abs(subtract.126)
563   negate.137 = f32[4]{0} negate(abs.136)
564   exponential.138 = f32[4]{0} exponential(negate.137)
565   log-plus-one.139 = f32[4]{0} log-plus-one(exponential.138)
566   add.140 = f32[4]{0} add(maximum.125, log-plus-one.139)
567   select.141 = f32[4]{0} select(and.134, add.135, add.140)
568   slice.356 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
569   reshape.1064 = f32[4]{0} reshape(slice.356)
570   add.214 = f32[4]{0} add(select.141, reshape.1064)
571   slice.380 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
572   reshape.1066 = f32[4]{0} reshape(slice.380)
573   add.179 = f32[4]{0} add(select.141, reshape.1066)
574   subtract.216 = f32[4]{0} subtract(add.214, add.179)
575   is-finite.217 = pred[4]{0} is-finite(subtract.216)
576   not.218 = pred[4]{0} not(is-finite.217)
577   abs.219 = f32[4]{0} abs(subtract.216)
578   constant.220 = f32[] constant(inf)
579   broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
580   compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ
581   not.223 = pred[4]{0} not(compare.222)
582   and.224 = pred[4]{0} and(not.218, not.223)
583   add.225 = f32[4]{0} add(add.214, add.179)
584   maximum.215 = f32[4]{0} maximum(add.214, add.179)
585   abs.226 = f32[4]{0} abs(subtract.216)
586   negate.227 = f32[4]{0} negate(abs.226)
587   exponential.228 = f32[4]{0} exponential(negate.227)
588   log-plus-one.229 = f32[4]{0} log-plus-one(exponential.228)
589   add.230 = f32[4]{0} add(maximum.215, log-plus-one.229)
590   select.231 = f32[4]{0} select(and.224, add.225, add.230)
591   slice.359 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
592   reshape.1068 = f32[4]{0} reshape(slice.359)
593   add.304 = f32[4]{0} add(select.231, reshape.1068)
594   slice.381 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
595   reshape.1070 = f32[4]{0} reshape(slice.381)
596   add.269 = f32[4]{0} add(select.231, reshape.1070)
597   subtract.306 = f32[4]{0} subtract(add.304, add.269)
598   is-finite.307 = pred[4]{0} is-finite(subtract.306)
599   not.308 = pred[4]{0} not(is-finite.307)
600   abs.309 = f32[4]{0} abs(subtract.306)
601   constant.310 = f32[] constant(inf)
602   broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
603   compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ
604   not.313 = pred[4]{0} not(compare.312)
605   and.314 = pred[4]{0} and(not.308, not.313)
606   add.315 = f32[4]{0} add(add.304, add.269)
607   maximum.305 = f32[4]{0} maximum(add.304, add.269)
608   abs.316 = f32[4]{0} abs(subtract.306)
609   negate.317 = f32[4]{0} negate(abs.316)
610   exponential.318 = f32[4]{0} exponential(negate.317)
611   log-plus-one.319 = f32[4]{0} log-plus-one(exponential.318)
612   add.320 = f32[4]{0} add(maximum.305, log-plus-one.319)
613   select.321 = f32[4]{0} select(and.314, add.315, add.320)
614   slice.362 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
615   reshape.1072 = f32[4]{0} reshape(slice.362)
616   add.394 = f32[4]{0} add(select.321, reshape.1072)
617   slice.382 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
618   reshape.1074 = f32[4]{0} reshape(slice.382)
619   add.359 = f32[4]{0} add(select.321, reshape.1074)
620   subtract.396 = f32[4]{0} subtract(add.394, add.359)
621   is-finite.397 = pred[4]{0} is-finite(subtract.396)
622   not.398 = pred[4]{0} not(is-finite.397)
623   abs.399 = f32[4]{0} abs(subtract.396)
624   constant.400 = f32[] constant(inf)
625   broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
626   compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ
627   not.403 = pred[4]{0} not(compare.402)
628   and.404 = pred[4]{0} and(not.398, not.403)
629   add.405 = f32[4]{0} add(add.394, add.359)
630   maximum.395 = f32[4]{0} maximum(add.394, add.359)
631   abs.406 = f32[4]{0} abs(subtract.396)
632   negate.407 = f32[4]{0} negate(abs.406)
633   exponential.408 = f32[4]{0} exponential(negate.407)
634   log-plus-one.409 = f32[4]{0} log-plus-one(exponential.408)
635   add.410 = f32[4]{0} add(maximum.395, log-plus-one.409)
636   select.411 = f32[4]{0} select(and.404, add.405, add.410)
637   slice.365 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
638   reshape.1076 = f32[4]{0} reshape(slice.365)
639   add.484 = f32[4]{0} add(select.411, reshape.1076)
640   slice.383 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
641   reshape.1078 = f32[4]{0} reshape(slice.383)
642   add.449 = f32[4]{0} add(select.411, reshape.1078)
643   subtract.486 = f32[4]{0} subtract(add.484, add.449)
644   is-finite.487 = pred[4]{0} is-finite(subtract.486)
645   not.488 = pred[4]{0} not(is-finite.487)
646   abs.489 = f32[4]{0} abs(subtract.486)
647   constant.490 = f32[] constant(inf)
648   broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
649   compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ
650   not.493 = pred[4]{0} not(compare.492)
651   and.494 = pred[4]{0} and(not.488, not.493)
652   add.495 = f32[4]{0} add(add.484, add.449)
653   maximum.485 = f32[4]{0} maximum(add.484, add.449)
654   abs.496 = f32[4]{0} abs(subtract.486)
655   negate.497 = f32[4]{0} negate(abs.496)
656   exponential.498 = f32[4]{0} exponential(negate.497)
657   log-plus-one.499 = f32[4]{0} log-plus-one(exponential.498)
658   add.500 = f32[4]{0} add(maximum.485, log-plus-one.499)
659   select.501 = f32[4]{0} select(and.494, add.495, add.500)
660   slice.368 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
661   reshape.1080 = f32[4]{0} reshape(slice.368)
662   add.574 = f32[4]{0} add(select.501, reshape.1080)
663   slice.384 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
664   reshape.1082 = f32[4]{0} reshape(slice.384)
665   add.539 = f32[4]{0} add(select.501, reshape.1082)
666   subtract.576 = f32[4]{0} subtract(add.574, add.539)
667   is-finite.577 = pred[4]{0} is-finite(subtract.576)
668   not.578 = pred[4]{0} not(is-finite.577)
669   abs.579 = f32[4]{0} abs(subtract.576)
670   constant.580 = f32[] constant(inf)
671   broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
672   compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ
673   not.583 = pred[4]{0} not(compare.582)
674   and.584 = pred[4]{0} and(not.578, not.583)
675   add.585 = f32[4]{0} add(add.574, add.539)
676   maximum.575 = f32[4]{0} maximum(add.574, add.539)
677   abs.586 = f32[4]{0} abs(subtract.576)
678   negate.587 = f32[4]{0} negate(abs.586)
679   exponential.588 = f32[4]{0} exponential(negate.587)
680   log-plus-one.589 = f32[4]{0} log-plus-one(exponential.588)
681   add.590 = f32[4]{0} add(maximum.575, log-plus-one.589)
682   select.591 = f32[4]{0} select(and.584, add.585, add.590)
683   slice.371 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
684   reshape.1084 = f32[4]{0} reshape(slice.371)
685   add.664 = f32[4]{0} add(select.591, reshape.1084)
686   slice.385 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
687   reshape.1086 = f32[4]{0} reshape(slice.385)
688   add.629 = f32[4]{0} add(select.591, reshape.1086)
689   subtract.666 = f32[4]{0} subtract(add.664, add.629)
690   is-finite.667 = pred[4]{0} is-finite(subtract.666)
691   not.668 = pred[4]{0} not(is-finite.667)
692   abs.669 = f32[4]{0} abs(subtract.666)
693   constant.670 = f32[] constant(inf)
694   broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
695   compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ
696   not.673 = pred[4]{0} not(compare.672)
697   and.674 = pred[4]{0} and(not.668, not.673)
698   add.675 = f32[4]{0} add(add.664, add.629)
699   maximum.665 = f32[4]{0} maximum(add.664, add.629)
700   abs.676 = f32[4]{0} abs(subtract.666)
701   negate.677 = f32[4]{0} negate(abs.676)
702   exponential.678 = f32[4]{0} exponential(negate.677)
703   log-plus-one.679 = f32[4]{0} log-plus-one(exponential.678)
704   add.680 = f32[4]{0} add(maximum.665, log-plus-one.679)
705   select.681 = f32[4]{0} select(and.674, add.675, add.680)
706   slice.374 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
707   reshape.1088 = f32[4]{0} reshape(slice.374)
708   add.754 = f32[4]{0} add(select.681, reshape.1088)
709   slice.386 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
710   reshape.1090 = f32[4]{0} reshape(slice.386)
711   add.719 = f32[4]{0} add(select.681, reshape.1090)
712   subtract.756 = f32[4]{0} subtract(add.754, add.719)
713   is-finite.757 = pred[4]{0} is-finite(subtract.756)
714   not.758 = pred[4]{0} not(is-finite.757)
715   abs.759 = f32[4]{0} abs(subtract.756)
716   constant.760 = f32[] constant(inf)
717   broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
718   compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ
719   not.763 = pred[4]{0} not(compare.762)
720   and.764 = pred[4]{0} and(not.758, not.763)
721   add.765 = f32[4]{0} add(add.754, add.719)
722   maximum.755 = f32[4]{0} maximum(add.754, add.719)
723   abs.766 = f32[4]{0} abs(subtract.756)
724   negate.767 = f32[4]{0} negate(abs.766)
725   exponential.768 = f32[4]{0} exponential(negate.767)
726   log-plus-one.769 = f32[4]{0} log-plus-one(exponential.768)
727   add.770 = f32[4]{0} add(maximum.755, log-plus-one.769)
728   select.771 = f32[4]{0} select(and.764, add.765, add.770)
729   slice.377 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
730   reshape.1092 = f32[4]{0} reshape(slice.377)
731   add.844 = f32[4]{0} add(select.771, reshape.1092)
732   slice.387 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
733   reshape.1094 = f32[4]{0} reshape(slice.387)
734   add.809 = f32[4]{0} add(select.771, reshape.1094)
735   subtract.846 = f32[4]{0} subtract(add.844, add.809)
736   is-finite.847 = pred[4]{0} is-finite(subtract.846)
737   not.848 = pred[4]{0} not(is-finite.847)
738   abs.849 = f32[4]{0} abs(subtract.846)
739   constant.850 = f32[] constant(inf)
740   broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
741   compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ
742   not.853 = pred[4]{0} not(compare.852)
743   and.854 = pred[4]{0} and(not.848, not.853)
744   add.855 = f32[4]{0} add(add.844, add.809)
745   maximum.845 = f32[4]{0} maximum(add.844, add.809)
746   abs.856 = f32[4]{0} abs(subtract.846)
747   negate.857 = f32[4]{0} negate(abs.856)
748   exponential.858 = f32[4]{0} exponential(negate.857)
749   log-plus-one.859 = f32[4]{0} log-plus-one(exponential.858)
750   add.860 = f32[4]{0} add(maximum.845, log-plus-one.859)
751   select.861 = f32[4]{0} select(and.854, add.855, add.860)
752   constant.865 = f32[] constant(0)
753   reduce.2 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
754   reduce.3 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
755   add.77 = f32[] add(reduce.2, reduce.3)
756   constant.719 = f32[] constant(0.125)
757   multiply = f32[] multiply(add.77, constant.719)
758   ROOT tuple.873 = (f32[]) tuple(multiply)
759 })")
760                     .value();
761   auto input_array = std::make_unique<Array2D<float>>(4, 2);
762   input_array->FillUnique(1.0f);
763   auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
764   EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
765 }
766 
// Parameters of one binary rank-2 concatenation case: the two dimensions of
// each operand and the dimension along which the operands are joined.
//
// The struct is brace-initialized positionally by the test instantiation, so
// the order of the fields below is part of its contract.
struct R2BinarySpec {
  int64_t lhs_dim0;          // First dimension of the left-hand operand.
  int64_t lhs_dim1;          // Second dimension of the left-hand operand.
  int64_t rhs_dim0;          // First dimension of the right-hand operand.
  int64_t rhs_dim1;          // Second dimension of the right-hand operand.
  int64_t concat_dimension;  // Dimension passed to the concatenation.
};
775 
776 // TEST_P harness for binary rank-2 concatenation.
777 class ConcatR2BinaryTest : public ClientLibraryTestBase,
778                            public ::testing::WithParamInterface<R2BinarySpec> {
779 };
780 
// Concatenates two int32 rank-2 constants whose shapes come from the test
// parameter and compares the computed result against the reference
// implementation in ReferenceUtil::Concat2D.
TEST_P(ConcatR2BinaryTest, DoIt) {
  const R2BinarySpec& params = GetParam();

  // Host-side operand values. The right operand is filled starting at 1000;
  // presumably this keeps its values distinguishable from the left operand's
  // (confirm against Array2D::FillUnique).
  Array2D<int32_t> left(params.lhs_dim0, params.lhs_dim1);
  Array2D<int32_t> right(params.rhs_dim0, params.rhs_dim1);
  left.FillUnique();
  right.FillUnique(1000);

  // Reference answer, computed without going through the builder.
  const std::unique_ptr<Array2D<int32_t>> want =
      ReferenceUtil::Concat2D(left, right, params.concat_dimension);

  // Computation under test: concat(constant(left), constant(right)).
  XlaBuilder b(TestName());
  auto left_op = ConstantR2FromArray2D<int32_t>(&b, left);
  auto right_op = ConstantR2FromArray2D<int32_t>(&b, right);
  ConcatInDim(&b, {left_op, right_op}, params.concat_dimension);

  // The computation takes no runtime arguments.
  ComputeAndCompareR2<int32_t>(&b, *want, {});
}
797 
798 // Regression test for b/31944287. x*y is used (at the same index) by all
799 // operands of the concat. We should emit x*y in three incoming basic blocks of
800 // the concat because these basic blocks are not control-equivalent.
801 //
802 //      x*y
803 //    /  |   \
804 // add1 add2 add3
805 //    \  |   /
806 //     concat
XLA_TEST_F(ConcatTest,ConcatOperandsOfSameOperand)807 XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
808   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
809   auto x_literal = LiteralUtil::CreateR0<float>(2.f);
810   auto y_literal = LiteralUtil::CreateR0<float>(3.f);
811   auto x_data = client_->TransferToServer(x_literal).value();
812   auto y_data = client_->TransferToServer(y_literal).value();
813 
814   XlaBuilder builder(TestName());
815   auto x = Parameter(&builder, 0, f32_scalar, "x");
816   auto y = Parameter(&builder, 1, f32_scalar, "y");
817   auto mul = Mul(x, y);
818   auto add1 = Add(mul, ConstantR1<float>(&builder, {1.f, 2.f}));
819   auto add2 = Add(mul, ConstantR1<float>(&builder, {3.f, 4.f}));
820   auto add3 = Add(mul, ConstantR1<float>(&builder, {5.f, 6.f}));
821   ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
822 
823   ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
824                              {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
825 }
826 
827 // Test that the HLO optimization to replace a concat of a broadcasted scalar
828 // produces the correct result in rank 1.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgument)829 XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
830   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
831   auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
832   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
833   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
834   auto x_data = client_->TransferToServer(x_literal).value();
835   auto y_data = client_->TransferToServer(y_literal).value();
836   auto z_data = client_->TransferToServer(z_literal).value();
837 
838   XlaBuilder builder(TestName());
839   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
840   auto y = Parameter(&builder, 1, f32_scalar, "y");
841   auto z = Parameter(&builder, 2, f32_scalar, "z");
842   auto bcast = Broadcast(y, {5});
843   auto bcast2 = Broadcast(z, {3});
844   auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0);
845   ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0);
846 
847   ComputeAndCompareR1<float>(
848       &builder,
849       {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
850       {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
851 }
852 
853 // Test that the HLO optimization to replace a concat of a broadcasted scalar
854 // produces the correct result in rank 3 with both high and low padding in
855 // different dimensions.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgumentR3)856 XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
857   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
858   Array3D<float> x3d(3, 5, 7, 3.14f);
859   auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
860   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
861   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
862   auto x_data = client_->TransferToServer(x_literal).value();
863   auto y_data = client_->TransferToServer(y_literal).value();
864   auto z_data = client_->TransferToServer(z_literal).value();
865 
866   XlaBuilder builder(TestName());
867   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
868   auto y = Parameter(&builder, 1, f32_scalar, "y");
869   auto z = Parameter(&builder, 2, f32_scalar, "y");
870   auto y_bcast = Broadcast(y, {1, 5, 7});
871   auto z_bcast = Broadcast(z, {4, 1, 7});
872   auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
873   ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);
874   Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
875   Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
876   auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
877   auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
878 
879   ComputeAndCompareR3<float>(&builder, *concat1,
880                              {x_data.get(), y_data.get(), z_data.get()},
881                              ErrorSpec(1e-4));
882 }
883 
884 INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
885                         ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
886                                           R2BinarySpec{1, 1, 1, 1, 1},
887                                           R2BinarySpec{4, 3, 4, 3, 0},
888                                           R2BinarySpec{4, 3, 4, 3, 1},
889                                           R2BinarySpec{7, 128, 1, 128, 0},
890                                           R2BinarySpec{8, 127, 8, 1, 1}));
891 
892 }  // namespace
893 }  // namespace xla
894