1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <numeric>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array4d.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31
32 namespace xla {
33 namespace {
34
35 class BroadcastSimpleTest : public ClientLibraryTestBase {
36 public:
BuildBinOp(HloOpcode op,const XlaOp & lhs,const XlaOp & rhs,XlaBuilder * builder)37 XlaOp BuildBinOp(HloOpcode op, const XlaOp& lhs, const XlaOp& rhs,
38 XlaBuilder* builder) {
39 switch (op) {
40 case HloOpcode::kMinimum: {
41 return Min(lhs, rhs);
42 }
43 case HloOpcode::kMaximum: {
44 return Max(lhs, rhs);
45 }
46 case HloOpcode::kMultiply: {
47 return Mul(lhs, rhs);
48 }
49 default: {
50 // Default to Add
51 return Add(lhs, rhs);
52 }
53 }
54 }
55
MakeR3Data(absl::Span<const int64> bounds,absl::Span<const int64> minor_to_major,Shape * r3_shape,Array3D<float> * r3_array,float start,float end,int seed)56 std::unique_ptr<GlobalData> MakeR3Data(absl::Span<const int64> bounds,
57 absl::Span<const int64> minor_to_major,
58 Shape* r3_shape,
59 Array3D<float>* r3_array, float start,
60 float end, int seed) {
61 *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
62 r3_array->FillRandom(start, end, seed);
63 auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
64 LayoutUtil::MakeLayout(minor_to_major));
65 std::unique_ptr<GlobalData> r3_global_data =
66 client_->TransferToServer(r3_data).ConsumeValueOrDie();
67 return r3_global_data;
68 }
69
MakeR2Data(absl::Span<const int64> bounds,absl::Span<const int64> minor_to_major,Shape * r2_shape,Array2D<float> * r2_array,float start,float end,int seed)70 std::unique_ptr<GlobalData> MakeR2Data(absl::Span<const int64> bounds,
71 absl::Span<const int64> minor_to_major,
72 Shape* r2_shape,
73 Array2D<float>* r2_array, float start,
74 float end, int seed) {
75 *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
76 r2_array->FillRandom(start, end, seed);
77 auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
78 LayoutUtil::MakeLayout(minor_to_major));
79 std::unique_ptr<GlobalData> r2_global_data =
80 client_->TransferToServer(r2_data).ConsumeValueOrDie();
81 return r2_global_data;
82 }
83
ApplyOpToFloats(HloOpcode op,float lhs,float rhs)84 float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
85 switch (op) {
86 case HloOpcode::kMinimum: {
87 return std::min(lhs, rhs);
88 }
89 case HloOpcode::kMaximum: {
90 return std::max(lhs, rhs);
91 }
92 case HloOpcode::kMultiply: {
93 return lhs * rhs;
94 }
95 case HloOpcode::kAdd: {
96 return lhs + rhs;
97 }
98 default: {
99 // Default to Add
100 LOG(FATAL);
101 }
102 }
103 }
104 };
105
106 using ::testing::HasSubstr;
107
// Broadcasting a scalar with an empty list of sizes must return the scalar
// unchanged.
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
  XlaBuilder builder(TestName());
  const auto scalar = ConstantR0<float>(&builder, 1.5);
  Broadcast(scalar, {});

  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR0<float>(&builder, 1.5, {}, tolerance);
}
113
// Broadcasting the constant scalar 2.25 with sizes {2, 3} must give a 2x3
// array in which every element is 2.25.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
  XlaBuilder builder(TestName());
  const auto scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 3});

  const Array2D<float> want(2, 3, 2.25);
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&builder, want, {}, tolerance);
}
120
// Same as ScalarTo2D_2x3, but the scalar is fed in as parameter 0 instead of
// being a constant.
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
  XlaBuilder builder(TestName());

  XlaOp scalar_param;
  std::unique_ptr<GlobalData> scalar_data = CreateR0Parameter<float>(
      2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&builder,
      /*data_handle=*/&scalar_param);
  Broadcast(scalar_param, {2, 3});

  const Array2D<float> want(2, 3, 2.25);
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&builder, want, {scalar_data.get()}, tolerance);
}
133
// Broadcasting a scalar with sizes {2, 0} must give an empty 2x0 array.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
  XlaBuilder builder(TestName());
  const auto scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 0});

  const Array2D<float> empty_result(2, 0);
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&builder, empty_result, {}, tolerance);
}
140
// Broadcasting a scalar with sizes {0, 2} must give an empty 0x2 array.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
  XlaBuilder builder(TestName());
  const auto scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {0, 2});

  const Array2D<float> empty_result(0, 2);
  const ErrorSpec tolerance(0.0001);
  ComputeAndCompareR2<float>(&builder, empty_result, {}, tolerance);
}
147
148 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
149 XlaBuilder b(TestName());
150 Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {2});
151
152 Array2D<float> expected(2, 3);
153 expected(0, 0) = 1;
154 expected(0, 1) = 2;
155 expected(0, 2) = 3;
156 expected(1, 0) = 1;
157 expected(1, 1) = 2;
158 expected(1, 2) = 3;
159 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
160 }
161
162 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
163 XlaBuilder b(TestName());
164 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {1});
165
166 Array2D<float> expected(2, 2);
167 expected(0, 0) = 1;
168 expected(0, 1) = 2;
169 expected(1, 0) = 1;
170 expected(1, 1) = 2;
171
172 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
173 }
174
175 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
176 XlaBuilder b(TestName());
177 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {0});
178
179 Array2D<float> expected(2, 2);
180 expected(0, 0) = 1;
181 expected(0, 1) = 1;
182 expected(1, 0) = 2;
183 expected(1, 1) = 2;
184
185 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
186 }
187
188 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
189 XlaBuilder b(TestName());
190 BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
191 {0, 1});
192
193 Array3D<float> expected(2, 2, 2);
194 expected(0, 0, 0) = 1.0;
195 expected(1, 0, 0) = 2.0;
196 expected(0, 0, 1) = 1.0;
197 expected(1, 0, 1) = 2.0;
198 expected(0, 1, 0) = 5.0;
199 expected(1, 1, 0) = 6.0;
200 expected(1, 1, 1) = 6.0;
201 expected(0, 1, 1) = 5.0;
202
203 ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
204 }
205
206 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
207 XlaBuilder b(TestName());
208 BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
209 {0, 2});
210
211 Array3D<float> expected(2, 2, 2);
212 expected(0, 0, 0) = 1.0;
213 expected(1, 0, 0) = 2.0;
214 expected(0, 0, 1) = 5.0;
215 expected(1, 0, 1) = 6.0;
216 expected(0, 1, 0) = 1.0;
217 expected(1, 1, 0) = 2.0;
218 expected(1, 1, 1) = 6.0;
219 expected(0, 1, 1) = 5.0;
220
221 ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
222 }
223
224 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
225 XlaBuilder b(TestName());
226 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {3, 2}, {1});
227
228 Array2D<float> expected(3, 2);
229 expected(0, 0) = 1;
230 expected(0, 1) = 2;
231 expected(1, 0) = 1;
232 expected(1, 1) = 2;
233 expected(2, 0) = 1;
234 expected(2, 1) = 2;
235
236 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
237 }
238
239 // Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest,BooleanAnd2DTo3D_Pred)240 XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
241 XlaBuilder b(TestName());
242
243 Array2D<bool> x_vals(2, 1);
244 x_vals(0, 0) = true;
245 x_vals(1, 0) = false;
246 Array3D<bool> y_vals(2, 2, 1);
247 y_vals(0, 0, 0) = false;
248 y_vals(0, 1, 0) = false;
249 y_vals(1, 0, 0) = true;
250 y_vals(1, 1, 0) = true;
251
252 XlaOp x, y;
253 auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
254 auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
255 And(x, y, /*broadcast_dimensions=*/{1, 2});
256
257 Array3D<bool> expected(2, 2, 1);
258 expected(0, 0, 0) = false;
259 expected(0, 1, 0) = false;
260 expected(1, 0, 0) = true;
261 expected(1, 1, 0) = false;
262
263 ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
264 }
265
// Broadcasting an empty vector with sizes {2} must give an empty 2x0 array.
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
  XlaBuilder builder(TestName());
  const auto empty_vec = ConstantR1<float>(&builder, {});
  Broadcast(empty_vec, {2});

  const Array2D<float> empty_result(2, 0);
  ComputeAndCompareR2<float>(&builder, empty_result, {}, ErrorSpec(0.0001));
}
273
274 XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
275 XlaBuilder b(TestName());
276 Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {0});
277
278 Array2D<float> expected(0, 3);
279 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
280 }
281
// Exercises in-dimension broadcasting and degenerate-dimension broadcasting in
// a single binary op.
//
// The [1, 2] lhs is mapped onto dimensions {1, 2} of the [2, 3, 1] rhs via
// broadcast_dimensions; the remaining size-one dimensions on both sides are
// then expanded, giving a [2, 3, 2] sum.
XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(&builder, {{1.0, 5.0}});
  const auto rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  auto want =
      LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
                                    {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
303
304 struct R3ImplicitBroadcastSpec {
305 std::array<int64, 3> output_bounds;
306 std::array<int64, 3> minor2major_layout;
307 std::array<int64, 3> input_bounds;
308 HloOpcode op;
309 } kR3ImplicitBroadcastTestCases[] = {
310 {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
311 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
312 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
313 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
314 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
315 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
316 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
317 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
318 {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
319 {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
320 };
321
322 class BroadcastR3ImplicitTest
323 : public BroadcastSimpleTest,
324 public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
325
// Applies spec.op to a rank-3 operand with degenerate (size-one) dimensions
// and a full rank-3 operand, then compares the device result with a reference
// computed element by element on the host.
XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
  const R3ImplicitBroadcastSpec& spec = GetParam();
  const auto& full_bounds = spec.output_bounds;
  const auto& degenerate_bounds = spec.input_bounds;
  XlaBuilder builder(TestName());

  // Host-side copies of both operands; MakeR3Data fills them in.
  Shape full_shape, degenerate_shape;
  Array3D<float> full_values(full_bounds[0], full_bounds[1], full_bounds[2]);
  Array3D<float> degenerate_values(degenerate_bounds[0], degenerate_bounds[1],
                                   degenerate_bounds[2]);

  std::unique_ptr<GlobalData> full_data =
      MakeR3Data(full_bounds, spec.minor2major_layout, &full_shape,
                 &full_values, 1.0, 2.5, 56789);
  std::unique_ptr<GlobalData> degenerate_data =
      MakeR3Data(degenerate_bounds, spec.minor2major_layout, &degenerate_shape,
                 &degenerate_values, 1.0, 0.2, 56789);

  // Parameter 0 is the degenerate operand, parameter 1 the full one.
  auto degenerate_param = Parameter(&builder, 0, degenerate_shape, "input");
  auto full_param = Parameter(&builder, 1, full_shape, "input");
  BuildBinOp(spec.op, degenerate_param, full_param, &builder);

  // Reference result. Taking each index modulo the degenerate operand's bound
  // maps every index of a size-one dimension to 0.
  Array3D<float> expected_values(full_bounds[0], full_bounds[1],
                                 full_bounds[2]);
  for (int64 i = 0; i < expected_values.n1(); ++i) {
    for (int64 j = 0; j < expected_values.n2(); ++j) {
      for (int64 k = 0; k < expected_values.n3(); ++k) {
        const float degenerate = degenerate_values(i % degenerate_bounds[0],
                                                   j % degenerate_bounds[1],
                                                   k % degenerate_bounds[2]);
        const float full = full_values(i, j, k);
        expected_values(i, j, k) = ApplyOpToFloats(spec.op, degenerate, full);
      }
    }
  }

  auto expected = LiteralUtil::CreateR3FromArray3D(expected_values);
  ComputeAndCompareLiteral(&builder, expected,
                           {degenerate_data.get(), full_data.get()},
                           ErrorSpec(1e-7, 1e-7));
}
373
374 INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
375 BroadcastR3ImplicitTest,
376 ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
377
378 // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
XLA_TEST_F(BroadcastSimpleTest,Add3DTo3DDegenerate_1_2)379 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
380 XlaBuilder b(TestName());
381 XlaOp r1h;
382 XlaOp r3h;
383
384 Array3D<float> r1d = {{{1}}, {{2}}};
385 Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
386 auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
387 auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
388
389 Add(r3h, r1h);
390
391 auto expected =
392 LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
393
394 ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
395 ErrorSpec(0.0001));
396 }
397
// Adds a [1, 1, 2] constant to a [2, 2, 2] constant; dimensions 0 and 1 of the
// smaller operand are degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR3<float>({{{1, 2}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
410
// Adds a [1, 2, 1] constant to a [2, 2, 2] constant; dimensions 0 and 2 of the
// smaller operand are degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR3<float>({{{1}, {2}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
423
// Adds a [1, 2, 2] constant to a [2, 2, 2] constant; only dimension 0 of the
// smaller operand is degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
437
// Adds a [2, 1, 2] constant to a [2, 2, 2] constant; only dimension 1 of the
// smaller operand is degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
451
// Adds a [2, 2, 1] constant to a [2, 2, 2] constant; only dimension 2 of the
// smaller operand is degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
  XlaBuilder builder(TestName());
  auto degenerate_literal =
      LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
465
// Adds a [1, 1, 1] constant to a [2, 2, 2] constant; all three dimensions of
// the smaller operand are degenerate.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR3<float>({{{1}}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
478
479 struct R2ImplicitBroadcastSpec {
480 std::array<int64, 2> output_bounds;
481 std::array<int64, 2> minor2major_layout;
482 std::array<int64, 2> input_bounds1;
483 std::array<int64, 2> input_bounds2;
484 HloOpcode op1;
485 HloOpcode op2;
486 } kR2ImplicitBroadcastTestCases[] = {
487 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
488 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
489 {{{2, 3}},
490 {{1, 0}},
491 {{2, 1}},
492 {{1, 1}},
493 HloOpcode::kAdd,
494 HloOpcode::kMinimum},
495 {{{2, 3}},
496 {{1, 0}},
497 {{1, 3}},
498 {{1, 1}},
499 HloOpcode::kAdd,
500 HloOpcode::kMinimum},
501 {{{2, 3}},
502 {{1, 0}},
503 {{1, 1}},
504 {{1, 1}},
505 HloOpcode::kAdd,
506 HloOpcode::kMinimum},
507 {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
508 {{{150, 150}},
509 {{1, 0}},
510 {{150, 1}},
511 {{150, 1}},
512 HloOpcode::kAdd,
513 HloOpcode::kAdd},
514 {{{150, 150}},
515 {{1, 0}},
516 {{150, 1}},
517 {{1, 150}},
518 HloOpcode::kAdd,
519 HloOpcode::kAdd},
520 {{{150, 150}},
521 {{1, 0}},
522 {{150, 1}},
523 {{1, 1}},
524 HloOpcode::kAdd,
525 HloOpcode::kAdd},
526 {{{50, 150}},
527 {{1, 0}},
528 {{50, 1}},
529 {{50, 1}},
530 HloOpcode::kAdd,
531 HloOpcode::kAdd},
532 {{{50, 150}},
533 {{1, 0}},
534 {{50, 1}},
535 {{1, 150}},
536 HloOpcode::kAdd,
537 HloOpcode::kAdd},
538 {{{50, 150}},
539 {{1, 0}},
540 {{50, 1}},
541 {{1, 1}},
542 HloOpcode::kAdd,
543 HloOpcode::kAdd},
544 {{{150, 50}},
545 {{1, 0}},
546 {{150, 1}},
547 {{150, 1}},
548 HloOpcode::kAdd,
549 HloOpcode::kAdd},
550 {{{150, 50}},
551 {{1, 0}},
552 {{150, 1}},
553 {{1, 50}},
554 HloOpcode::kAdd,
555 HloOpcode::kAdd},
556 {{{150, 50}},
557 {{1, 0}},
558 {{150, 1}},
559 {{1, 1}},
560 HloOpcode::kAdd,
561 HloOpcode::kAdd}};
562
563 class BroadcastR2ImplicitTest
564 : public BroadcastSimpleTest,
565 public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
566
567 // Test r2 op1 r2_implicit_1 op2 r2_implicit_2
568 // where R2 is a rank-2 operand, and r2_implicit_2 are two
569 // rank-2 operands with degenerate dimensions:
XLA_TEST_P(BroadcastR2ImplicitTest,Doit)570 XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
571 const R2ImplicitBroadcastSpec& spec = GetParam();
572
573 XlaBuilder builder(TestName());
574
575 // Operands with degenerate dimensions require implicit broadcasting:
576 Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
577 Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]);
578 Array2D<float> r2_implicit_array1(spec.input_bounds1[0],
579 spec.input_bounds1[1]);
580 Array2D<float> r2_implicit_array2(spec.input_bounds2[0],
581 spec.input_bounds2[1]);
582
583 std::unique_ptr<GlobalData> r2_global_data =
584 MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape,
585 &r2_array, 1.0, 2.5, 56789);
586 std::unique_ptr<GlobalData> r2_implicit_global_data1 =
587 MakeR2Data(spec.input_bounds1, spec.minor2major_layout,
588 &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789);
589 std::unique_ptr<GlobalData> r2_implicit_global_data2 =
590 MakeR2Data(spec.input_bounds2, spec.minor2major_layout,
591 &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
592
593 auto r2_implicit_parameter1 =
594 Parameter(&builder, 0, r2_implicit_shape1, "input0");
595 auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1");
596 auto r2_implicit_parameter2 =
597 Parameter(&builder, 2, r2_implicit_shape2, "input2");
598
599 XlaOp op1 =
600 BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
601 BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
602
603 Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
604
605 expected_array.Each([&](int64 i, int64 j, float* v) {
606 float v1 = r2_implicit_array1(i % spec.input_bounds1[0],
607 j % spec.input_bounds1[1]);
608 float v2 = r2_array(i, j);
609 float v3 = r2_implicit_array2(i % spec.input_bounds2[0],
610 j % spec.input_bounds2[1]);
611 float tmp = ApplyOpToFloats(spec.op1, v1, v2);
612 *v = ApplyOpToFloats(spec.op2, tmp, v3);
613 });
614
615 auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
616 ComputeAndCompareLiteral(
617 &builder, expected,
618 {r2_implicit_global_data1.get(), r2_global_data.get(),
619 r2_implicit_global_data2.get()},
620 ErrorSpec(1e-6, 1e-6));
621 }
622
623 INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
624 BroadcastR2ImplicitTest,
625 ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
626
// Adds a [1, 2] constant to a [2, 2] constant; dimension 0 of the smaller
// operand is degenerate, so its single row is added to every row.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR2<float>({{1, 2}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
637
// Adds a [2, 1] constant to a [2, 2] constant; dimension 1 of the smaller
// operand is degenerate, so its single column is added to every column.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
  XlaBuilder builder(TestName());
  auto degenerate_literal = LiteralUtil::CreateR2<float>({{1}, {2}});
  auto degenerate = ConstantLiteral(&builder, degenerate_literal);
  auto full_literal = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  auto full = ConstantLiteral(&builder, full_literal);
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
648
// Adds the vector {10, 20} to a [2, 2, 2] constant along dimension 0.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto cube = ConstantLiteral(&builder, cube_literal);
  Add(cube, vec, {0});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
661
// Adds the vector {10, 20} to a [2, 2, 2] constant along dimension 1, with the
// vector as the left-hand operand.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto cube = ConstantLiteral(&builder, cube_literal);
  Add(vec, cube, {1});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
674
// Adds the vector {10, 20} to a [2, 2, 2] constant along dimension 2, with the
// vector as the left-hand operand.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube_literal =
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
  auto cube = ConstantLiteral(&builder, cube_literal);
  Add(vec, cube, {2});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
687
// Adds one vector along each of the three dimensions of a [2, 2, 2] constant,
// three times over, and then multiplies the accumulated value by -2.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
  XlaBuilder builder(TestName());
  auto along_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  auto along_dim1 = ConstantR1<float>(&builder, {100, 200});
  auto along_dim2 = ConstantR1<float>(&builder, {10, 20});
  auto acc = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));

  // Each pass adds all three vectors once; the vector alternates between the
  // left and right operand position.
  for (int pass = 0; pass < 3; ++pass) {
    acc = Add(along_dim0, acc, {0});
    acc = Add(acc, along_dim1, {1});
    acc = Add(along_dim2, acc, {2});
  }
  auto scale = ConstantR0<float>(&builder, -2);
  acc = Mul(acc, scale);

  // Three passes of each vector scaled by -2 give the -6 factor; the original
  // cube values are scaled by -2.
  auto want = LiteralUtil::CreateR3<float>(
      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
708
// Like Add1DTo3DInDimAll, but the rank-3 starting value is the scalar 3
// broadcast to [2, 2, 2], and the final multiplier is -1.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
  XlaBuilder builder(TestName());
  auto along_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  auto along_dim1 = ConstantR1<float>(&builder, {100, 200});
  auto along_dim2 = ConstantR1<float>(&builder, {10, 20});
  auto seed_scalar = ConstantR0<float>(&builder, 3);
  auto acc = Broadcast(seed_scalar, {2, 2, 2});

  // Each pass adds all three vectors once; the vector alternates between the
  // left and right operand position.
  for (int pass = 0; pass < 3; ++pass) {
    acc = Add(along_dim0, acc, {0});
    acc = Add(acc, along_dim1, {1});
    acc = Add(along_dim2, acc, {2});
  }
  auto scale = ConstantR0<float>(&builder, -1);
  acc = Mul(acc, scale);

  // Three passes of each vector scaled by -1 give the -3 factor; the broadcast
  // scalar contributes the trailing -3.
  auto want = LiteralUtil::CreateR3<float>(
      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
729
// In-dimension broadcasting of the smaller [2, 2] lhs up to [2, 2, 2] gives a
// shape that is incompatible with the rhs [2, 3, 1], so execution must fail
// with a dimension-mismatch error.
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(&builder, {{1.0, 5.0}, {1.0, 5.0}});
  const auto rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("dimension 0 mismatch"));
}
745
// Adding a [1, 2] operand to a [2, 3] operand without broadcast dimensions
// must fail with an incompatible-shapes error.
XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const auto rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
758
// Adding a [1, 2] operand to a [2, 3] operand must fail with an
// incompatible-shapes error: the size-2 dimension cannot be expanded to 3.
//
// NOTE(review): the operands and the expected message are the same as in
// InvalidInDimensionBroadcasting; confirm whether distinct inputs were meant.
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const auto rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
771
772 } // namespace
773 } // namespace xla
774