• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <numeric>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array4d.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 
32 namespace xla {
33 namespace {
34 
35 class BroadcastSimpleTest : public ClientLibraryTestBase {
36  public:
BuildBinOp(HloOpcode op,const XlaOp & lhs,const XlaOp & rhs,XlaBuilder * builder)37   XlaOp BuildBinOp(HloOpcode op, const XlaOp& lhs, const XlaOp& rhs,
38                    XlaBuilder* builder) {
39     switch (op) {
40       case HloOpcode::kMinimum: {
41         return Min(lhs, rhs);
42       }
43       case HloOpcode::kMaximum: {
44         return Max(lhs, rhs);
45       }
46       case HloOpcode::kMultiply: {
47         return Mul(lhs, rhs);
48       }
49       default: {
50         // Default to Add
51         return Add(lhs, rhs);
52       }
53     }
54   }
55 
MakeR3Data(absl::Span<const int64> bounds,absl::Span<const int64> minor_to_major,Shape * r3_shape,Array3D<float> * r3_array,float start,float end,int seed)56   std::unique_ptr<GlobalData> MakeR3Data(absl::Span<const int64> bounds,
57                                          absl::Span<const int64> minor_to_major,
58                                          Shape* r3_shape,
59                                          Array3D<float>* r3_array, float start,
60                                          float end, int seed) {
61     *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
62     r3_array->FillRandom(start, end, seed);
63     auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
64         LayoutUtil::MakeLayout(minor_to_major));
65     std::unique_ptr<GlobalData> r3_global_data =
66         client_->TransferToServer(r3_data).ConsumeValueOrDie();
67     return r3_global_data;
68   }
69 
MakeR2Data(absl::Span<const int64> bounds,absl::Span<const int64> minor_to_major,Shape * r2_shape,Array2D<float> * r2_array,float start,float end,int seed)70   std::unique_ptr<GlobalData> MakeR2Data(absl::Span<const int64> bounds,
71                                          absl::Span<const int64> minor_to_major,
72                                          Shape* r2_shape,
73                                          Array2D<float>* r2_array, float start,
74                                          float end, int seed) {
75     *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
76     r2_array->FillRandom(start, end, seed);
77     auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
78         LayoutUtil::MakeLayout(minor_to_major));
79     std::unique_ptr<GlobalData> r2_global_data =
80         client_->TransferToServer(r2_data).ConsumeValueOrDie();
81     return r2_global_data;
82   }
83 
ApplyOpToFloats(HloOpcode op,float lhs,float rhs)84   float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
85     switch (op) {
86       case HloOpcode::kMinimum: {
87         return std::min(lhs, rhs);
88       }
89       case HloOpcode::kMaximum: {
90         return std::max(lhs, rhs);
91       }
92       case HloOpcode::kMultiply: {
93         return lhs * rhs;
94       }
95       case HloOpcode::kAdd: {
96         return lhs + rhs;
97       }
98       default: {
99         // Default to Add
100         LOG(FATAL);
101       }
102     }
103   }
104 };
105 
106 using ::testing::HasSubstr;
107 
// Broadcasting a scalar with an empty list of sizes yields the same scalar.
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
  XlaBuilder builder(TestName());
  const XlaOp scalar = ConstantR0<float>(&builder, 1.5);
  Broadcast(scalar, /*broadcast_sizes=*/{});
  ComputeAndCompareR0<float>(&builder, 1.5, {}, ErrorSpec(0.0001));
}
113 
// A constant scalar broadcast to 2x3 fills every cell with that scalar.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
  XlaBuilder builder(TestName());
  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 3});

  const Array2D<float> want(2, 3, 2.25);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
120 
// Same as ScalarTo2D_2x3, but the scalar is supplied as a parameter rather
// than a constant.
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
  XlaBuilder builder(TestName());

  XlaOp scalar_param;
  std::unique_ptr<GlobalData> scalar_data = CreateR0Parameter<float>(
      2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&builder,
      /*data_handle=*/&scalar_param);
  Broadcast(scalar_param, {2, 3});

  const Array2D<float> want(2, 3, 2.25);
  ComputeAndCompareR2<float>(&builder, want, {scalar_data.get()},
                             ErrorSpec(0.0001));
}
133 
// Broadcasting a scalar to a 2x0 result produces an array with no elements.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
  XlaBuilder builder(TestName());
  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 0});

  const Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
140 
// Broadcasting a scalar to a 0x2 result produces an array with no elements.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
  XlaBuilder builder(TestName());
  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {0, 2});

  const Array2D<float> want(0, 2);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
147 
148 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
149   XlaBuilder b(TestName());
150   Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {2});
151 
152   Array2D<float> expected(2, 3);
153   expected(0, 0) = 1;
154   expected(0, 1) = 2;
155   expected(0, 2) = 3;
156   expected(1, 0) = 1;
157   expected(1, 1) = 2;
158   expected(1, 2) = 3;
159   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
160 }
161 
162 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
163   XlaBuilder b(TestName());
164   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {1});
165 
166   Array2D<float> expected(2, 2);
167   expected(0, 0) = 1;
168   expected(0, 1) = 2;
169   expected(1, 0) = 1;
170   expected(1, 1) = 2;
171 
172   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
173 }
174 
175 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
176   XlaBuilder b(TestName());
177   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {0});
178 
179   Array2D<float> expected(2, 2);
180   expected(0, 0) = 1;
181   expected(0, 1) = 1;
182   expected(1, 0) = 2;
183   expected(1, 1) = 2;
184 
185   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
186 }
187 
188 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
189   XlaBuilder b(TestName());
190   BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
191                  {0, 1});
192 
193   Array3D<float> expected(2, 2, 2);
194   expected(0, 0, 0) = 1.0;
195   expected(1, 0, 0) = 2.0;
196   expected(0, 0, 1) = 1.0;
197   expected(1, 0, 1) = 2.0;
198   expected(0, 1, 0) = 5.0;
199   expected(1, 1, 0) = 6.0;
200   expected(1, 1, 1) = 6.0;
201   expected(0, 1, 1) = 5.0;
202 
203   ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
204 }
205 
206 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
207   XlaBuilder b(TestName());
208   BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
209                  {0, 2});
210 
211   Array3D<float> expected(2, 2, 2);
212   expected(0, 0, 0) = 1.0;
213   expected(1, 0, 0) = 2.0;
214   expected(0, 0, 1) = 5.0;
215   expected(1, 0, 1) = 6.0;
216   expected(0, 1, 0) = 1.0;
217   expected(1, 1, 0) = 2.0;
218   expected(1, 1, 1) = 6.0;
219   expected(0, 1, 1) = 5.0;
220 
221   ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
222 }
223 
224 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
225   XlaBuilder b(TestName());
226   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {3, 2}, {1});
227 
228   Array2D<float> expected(3, 2);
229   expected(0, 0) = 1;
230   expected(0, 1) = 2;
231   expected(1, 0) = 1;
232   expected(1, 1) = 2;
233   expected(2, 0) = 1;
234   expected(2, 1) = 2;
235 
236   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
237 }
238 
239 // Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
  XlaBuilder builder(TestName());

  // The rank-2 operand is mapped onto dimensions {1, 2} of the rank-3 operand.
  Array2D<bool> lhs_values(2, 1);
  lhs_values(0, 0) = true;
  lhs_values(1, 0) = false;

  Array3D<bool> rhs_values(2, 2, 1);
  rhs_values(0, 0, 0) = false;
  rhs_values(0, 1, 0) = false;
  rhs_values(1, 0, 0) = true;
  rhs_values(1, 1, 0) = true;

  XlaOp lhs;
  XlaOp rhs;
  auto lhs_data = CreateR2Parameter<bool>(lhs_values, 0, "x", &builder, &lhs);
  auto rhs_data = CreateR3Parameter<bool>(rhs_values, 1, "y", &builder, &rhs);
  And(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  // want(i, j, 0) == lhs_values(j, 0) && rhs_values(i, j, 0).
  Array3D<bool> want(2, 2, 1);
  want(0, 0, 0) = false;
  want(0, 1, 0) = false;
  want(1, 0, 0) = true;
  want(1, 1, 0) = false;

  ComputeAndCompareR3<bool>(&builder, want, {lhs_data.get(), rhs_data.get()});
}
265 
// Broadcasting an empty vector with sizes {2} gives a 2x0 result.
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
  XlaBuilder builder(TestName());
  const XlaOp empty_vec = ConstantR1<float>(&builder, {});
  Broadcast(empty_vec, {2});

  const Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
273 
274 XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
275   XlaBuilder b(TestName());
276   Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {0});
277 
278   Array2D<float> expected(0, 3);
279   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
280 }
281 
// Checks that in-dimension broadcasting (via broadcast_dimensions) and
// degenerate-dimension broadcasting (size-one dimensions) can both take part
// in a single binary op.
//
// The [1, 2] lhs is mapped onto dimensions {1, 2} of the [2, 3, 1] rhs; the
// remaining size-one dimensions on either side are then expanded, giving a
// [2, 3, 2] result.
XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 5.0}});
  const XlaOp rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  auto want =
      LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
                                    {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
303 
304 struct R3ImplicitBroadcastSpec {
305   std::array<int64, 3> output_bounds;
306   std::array<int64, 3> minor2major_layout;
307   std::array<int64, 3> input_bounds;
308   HloOpcode op;
309 } kR3ImplicitBroadcastTestCases[] = {
310     {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
311     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
312     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
313     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
314     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
315     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
316     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
317     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
318     {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
319     {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
320 };
321 
322 class BroadcastR3ImplicitTest
323     : public BroadcastSimpleTest,
324       public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
325 
// Computes `r3_implicit op r3`, where r3 has the full output bounds and
// r3_implicit has size-one dimensions that must be implicitly broadcast, and
// compares the device result with a reference computed on the host.
XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
  const R3ImplicitBroadcastSpec& spec = GetParam();
  XlaBuilder builder(TestName());

  Shape r3_shape, r3_implicit_shape;
  Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
                          spec.output_bounds[2]);
  Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
                                   spec.input_bounds[2]);

  std::unique_ptr<GlobalData> r3_global_data =
      MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
                 &r3_array, 1.0, 2.5, 56789);
  std::unique_ptr<GlobalData> r3_implicit_global_data =
      MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
                 &r3_implicit_array, 1.0, 0.2, 56789);

  auto r3_implicit_parameter =
      Parameter(&builder, 0, r3_implicit_shape, "input");
  auto r3_parameter = Parameter(&builder, 1, r3_shape, "input");
  BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);

  // Host-side reference. Taking each index modulo the implicit operand's
  // bound maps a broadcast (size-one) dimension to index 0 and leaves a
  // full-size dimension untouched.
  Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
                                spec.output_bounds[2]);
  // Keep the extents as int64 so they are not narrowed and match the loop
  // indices.
  const int64 n1 = expected_array.n1();
  const int64 n2 = expected_array.n2();
  const int64 n3 = expected_array.n3();
  for (int64 i = 0; i < n1; ++i) {
    for (int64 j = 0; j < n2; ++j) {
      for (int64 k = 0; k < n3; ++k) {
        const float r3_implicit = r3_implicit_array(
            i % spec.input_bounds[0], j % spec.input_bounds[1],
            k % spec.input_bounds[2]);
        const float r3 = r3_array(i, j, k);
        expected_array(i, j, k) = ApplyOpToFloats(spec.op, r3_implicit, r3);
      }
    }
  }

  auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
  ComputeAndCompareLiteral(
      &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
      ErrorSpec(1e-7, 1e-7));
}
373 
374 INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
375                         BroadcastR3ImplicitTest,
376                         ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
377 
378 // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
  XlaBuilder builder(TestName());

  // The [2, 1, 1] operand is passed as parameter 1 and the full [2, 2, 2]
  // operand as parameter 0.
  Array3D<float> degenerate_values = {{{1}}, {{2}}};
  Array3D<float> full_values = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};

  XlaOp degenerate_op;
  XlaOp full_op;
  auto degenerate_data =
      CreateR3Parameter(degenerate_values, 1, "r1", &builder, &degenerate_op);
  auto full_data = CreateR3Parameter(full_values, 0, "r3", &builder, &full_op);

  Add(full_op, degenerate_op);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
  ComputeAndCompareLiteral(&builder, want,
                           {full_data.get(), degenerate_data.get()},
                           ErrorSpec(0.0001));
}
397 
// Adds a [1, 1, 2] operand (dimensions 0 and 1 degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1, 2}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
410 
// Adds a [1, 2, 1] operand (dimensions 0 and 2 degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
423 
// Adds a [1, 2, 2] operand (dimension 0 degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
437 
// Adds a [2, 1, 2] operand (dimension 1 degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
451 
// Adds a [2, 2, 1] operand (dimension 2 degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
465 
// Adds a [1, 1, 1] operand (every dimension degenerate) to a [2, 2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
478 
479 struct R2ImplicitBroadcastSpec {
480   std::array<int64, 2> output_bounds;
481   std::array<int64, 2> minor2major_layout;
482   std::array<int64, 2> input_bounds1;
483   std::array<int64, 2> input_bounds2;
484   HloOpcode op1;
485   HloOpcode op2;
486 } kR2ImplicitBroadcastTestCases[] = {
487     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
488     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
489     {{{2, 3}},
490      {{1, 0}},
491      {{2, 1}},
492      {{1, 1}},
493      HloOpcode::kAdd,
494      HloOpcode::kMinimum},
495     {{{2, 3}},
496      {{1, 0}},
497      {{1, 3}},
498      {{1, 1}},
499      HloOpcode::kAdd,
500      HloOpcode::kMinimum},
501     {{{2, 3}},
502      {{1, 0}},
503      {{1, 1}},
504      {{1, 1}},
505      HloOpcode::kAdd,
506      HloOpcode::kMinimum},
507     {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
508     {{{150, 150}},
509      {{1, 0}},
510      {{150, 1}},
511      {{150, 1}},
512      HloOpcode::kAdd,
513      HloOpcode::kAdd},
514     {{{150, 150}},
515      {{1, 0}},
516      {{150, 1}},
517      {{1, 150}},
518      HloOpcode::kAdd,
519      HloOpcode::kAdd},
520     {{{150, 150}},
521      {{1, 0}},
522      {{150, 1}},
523      {{1, 1}},
524      HloOpcode::kAdd,
525      HloOpcode::kAdd},
526     {{{50, 150}},
527      {{1, 0}},
528      {{50, 1}},
529      {{50, 1}},
530      HloOpcode::kAdd,
531      HloOpcode::kAdd},
532     {{{50, 150}},
533      {{1, 0}},
534      {{50, 1}},
535      {{1, 150}},
536      HloOpcode::kAdd,
537      HloOpcode::kAdd},
538     {{{50, 150}},
539      {{1, 0}},
540      {{50, 1}},
541      {{1, 1}},
542      HloOpcode::kAdd,
543      HloOpcode::kAdd},
544     {{{150, 50}},
545      {{1, 0}},
546      {{150, 1}},
547      {{150, 1}},
548      HloOpcode::kAdd,
549      HloOpcode::kAdd},
550     {{{150, 50}},
551      {{1, 0}},
552      {{150, 1}},
553      {{1, 50}},
554      HloOpcode::kAdd,
555      HloOpcode::kAdd},
556     {{{150, 50}},
557      {{1, 0}},
558      {{150, 1}},
559      {{1, 1}},
560      HloOpcode::kAdd,
561      HloOpcode::kAdd}};
562 
563 class BroadcastR2ImplicitTest
564     : public BroadcastSimpleTest,
565       public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
566 
// Tests (r2_implicit_1 op1 r2) op2 r2_implicit_2, where r2 is a full-size
// rank-2 operand and r2_implicit_1 and r2_implicit_2 are rank-2 operands with
// degenerate dimensions:
XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
  const R2ImplicitBroadcastSpec& spec = GetParam();
  XlaBuilder builder(TestName());

  // Host copies of the three operands. The two "implicit" operands have
  // degenerate dimensions and therefore need implicit broadcasting.
  Array2D<float> full_values(spec.output_bounds[0], spec.output_bounds[1]);
  Array2D<float> implicit_values1(spec.input_bounds1[0],
                                  spec.input_bounds1[1]);
  Array2D<float> implicit_values2(spec.input_bounds2[0],
                                  spec.input_bounds2[1]);

  // Randomize the operands and upload them, recording their device shapes.
  Shape full_shape;
  Shape implicit_shape1;
  Shape implicit_shape2;
  std::unique_ptr<GlobalData> full_data =
      MakeR2Data(spec.output_bounds, spec.minor2major_layout, &full_shape,
                 &full_values, 1.0, 2.5, 56789);
  std::unique_ptr<GlobalData> implicit_data1 =
      MakeR2Data(spec.input_bounds1, spec.minor2major_layout, &implicit_shape1,
                 &implicit_values1, 1.0, 0.2, 56789);
  std::unique_ptr<GlobalData> implicit_data2 =
      MakeR2Data(spec.input_bounds2, spec.minor2major_layout, &implicit_shape2,
                 &implicit_values2, 0.8, 0.4, 56789);

  // Build (implicit1 op1 full) op2 implicit2.
  auto implicit_param1 = Parameter(&builder, 0, implicit_shape1, "input0");
  auto full_param = Parameter(&builder, 1, full_shape, "input1");
  auto implicit_param2 = Parameter(&builder, 2, implicit_shape2, "input2");
  XlaOp first_result =
      BuildBinOp(spec.op1, implicit_param1, full_param, &builder);
  BuildBinOp(spec.op2, first_result, implicit_param2, &builder);

  // Host-side reference. Indexing modulo an implicit operand's bound maps a
  // size-one dimension to index 0.
  Array2D<float> expected_values(spec.output_bounds[0], spec.output_bounds[1]);
  for (int64 row = 0; row < expected_values.n1(); ++row) {
    for (int64 col = 0; col < expected_values.n2(); ++col) {
      const float lhs1 = implicit_values1(row % spec.input_bounds1[0],
                                          col % spec.input_bounds1[1]);
      const float rhs1 = full_values(row, col);
      const float rhs2 = implicit_values2(row % spec.input_bounds2[0],
                                          col % spec.input_bounds2[1]);
      const float partial = ApplyOpToFloats(spec.op1, lhs1, rhs1);
      expected_values(row, col) = ApplyOpToFloats(spec.op2, partial, rhs2);
    }
  }

  auto want = LiteralUtil::CreateR2FromArray2D(expected_values);
  ComputeAndCompareLiteral(
      &builder, want,
      {implicit_data1.get(), full_data.get(), implicit_data2.get()},
      ErrorSpec(1e-6, 1e-6));
}
622 
623 INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
624                         BroadcastR2ImplicitTest,
625                         ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
626 
// Adds a [1, 2] operand (dimension 0 degenerate) to a [2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1, 2}}));
  const XlaOp full = ConstantLiteral(
      &builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
637 
// Adds a [2, 1] operand (dimension 1 degenerate) to a [2, 2] one.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
  XlaBuilder builder(TestName());
  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1}, {2}}));
  const XlaOp full = ConstantLiteral(
      &builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
648 
// Adds a vector to a rank-3 array with the vector mapped onto dimension 0.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
  XlaBuilder builder(TestName());
  const XlaOp vec = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(cube, vec, /*broadcast_dimensions=*/{0});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
661 
// Adds a vector to a rank-3 array with the vector mapped onto dimension 1.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
  XlaBuilder builder(TestName());
  const XlaOp vec = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vec, cube, /*broadcast_dimensions=*/{1});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
674 
// Adds a vector to a rank-3 array with the vector mapped onto dimension 2.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
  XlaBuilder builder(TestName());
  const XlaOp vec = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vec, cube, /*broadcast_dimensions=*/{2});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
687 
// Three times over, adds one vector along each of the three dimensions of a
// rank-3 array, then scales by -2. With S the sum of the three vector entries
// selected by an element's indices and x its initial value, the result is
// -2 * (x + 3 * S) == -6 * S - 2 * x.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
  XlaBuilder builder(TestName());
  const XlaOp along_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  const XlaOp along_dim1 = ConstantR1<float>(&builder, {100, 200});
  const XlaOp along_dim2 = ConstantR1<float>(&builder, {10, 20});
  XlaOp acc = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  for (int round = 0; round < 3; ++round) {
    acc = Add(along_dim0, acc, {0});
    acc = Add(acc, along_dim1, {1});
    acc = Add(along_dim2, acc, {2});
  }
  acc = Mul(acc, ConstantR0<float>(&builder, -2));

  auto want = LiteralUtil::CreateR3<float>(
      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
708 
// Like Add1DTo3DInDimAll, but the rank-3 operand is a scalar 3 broadcast to
// [2, 2, 2] and the final scale is -1, so every element is -3 * S - 3, where
// S is the sum of the three vector entries selected by the element's indices.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
  XlaBuilder builder(TestName());
  const XlaOp along_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  const XlaOp along_dim1 = ConstantR1<float>(&builder, {100, 200});
  const XlaOp along_dim2 = ConstantR1<float>(&builder, {10, 20});
  const XlaOp scalar = ConstantR0<float>(&builder, 3);
  XlaOp acc = Broadcast(scalar, {2, 2, 2});
  for (int round = 0; round < 3; ++round) {
    acc = Add(along_dim0, acc, {0});
    acc = Add(acc, along_dim1, {1});
    acc = Add(along_dim2, acc, {2});
  }
  acc = Mul(acc, ConstantR0<float>(&builder, -1));

  auto want = LiteralUtil::CreateR3<float>(
      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
729 
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
  // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
  // results in a shape incompatible with the rhs [2, 3, 1], so executing the
  // computation must fail.
  XlaBuilder b(TestName());

  Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
                              {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
      /*broadcast_dimensions=*/{1, 2});

  auto result_status = Execute(&b, {});
  EXPECT_FALSE(result_status.ok());
  EXPECT_THAT(result_status.status().error_message(),
              HasSubstr("dimension 0 mismatch"));
}
745 
// Adding a [1, 2] operand to a [2, 3] operand without broadcast dimensions is
// expected to be rejected with an incompatible-shapes error.
XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const XlaOp rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto outcome = Execute(&builder, {});
  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
758 
// Adding a [1, 2] operand to a [2, 3] operand is expected to be rejected with
// an incompatible-shapes error.
//
// NOTE(review): the operands and expectations here are identical to those of
// InvalidInDimensionBroadcasting -- confirm whether a distinct
// degenerate-dimension case was intended.
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const XlaOp rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto outcome = Execute(&builder, {});
  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
771 
772 }  // namespace
773 }  // namespace xla
774