• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
18 
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
26 #include "tensorflow/compiler/xla/tests/test_macros.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace xla {
31 namespace {
32 
33 class BroadcastTest : public HloTestBase {};
34 
XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
  // Degenerate case: broadcasting an F32 scalar into a scalar shape with an
  // empty dimension mapping must hand the value back unchanged.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(scalar_shape, scalar, {}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
51 
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
  // A scalar broadcast into a 2x2 shape fills every element with its value.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2}), scalar,
      /*broadcast_dimensions=*/{}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
68 
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
  HloComputation::Builder builder(TestName());
  HloInstruction* vec = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));

  // The same 3-element vector is broadcast twice: once onto output
  // dimension 0 (a 3x2 result) and once onto output dimension 1 (a 2x3
  // result). Both are returned in a tuple so one execution checks both.
  HloInstruction* along_dim0 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2}), vec, {0}));
  HloInstruction* along_dim1 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {2, 3}), vec, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({along_dim0, along_dim1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal tuple = ExecuteAndTransfer(std::move(module), {});

  // Tuple element 0: each row repeats one vector element.
  const Literal expected_dim0 =
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}});
  // Tuple element 1: each row is a full copy of the vector.
  const Literal expected_dim1 =
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim0, LiteralSlice(tuple, {0}),
                                    error_spec_));
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim1, LiteralSlice(tuple, {1}),
                                    error_spec_));
}
95 
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
  // Identity broadcast: a 2x2 operand mapped onto output dimensions {0, 1}
  // of a 2x2 shape must come back unchanged.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2}), matrix,
      /*broadcast_dimensions=*/{0, 1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
112 
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
  // Broadcasting into a shape of the same rank with a permuted dimension
  // mapping reorders the dimensions, i.e. it acts as a transpose: operand
  // dimension 0 lands in output dimension 1 and vice versa.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2}), matrix,
      /*broadcast_dimensions=*/{1, 0}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal transposed =
      LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(transposed, actual, error_spec_));
}
131 
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
  // The 2x2 operand is mapped onto output dimensions 0 and 2 of a 2x3x2
  // shape, so it is replicated along the new middle dimension.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3, 2}), matrix,
      /*broadcast_dimensions=*/{0, 2}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
149 
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
  HloComputation::Builder builder(TestName());
  HloInstruction* vec = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));

  // Map the 2-element vector onto output dimension 1 of a 2x2x3x3 shape.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), vec, {1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  // The 2x2 {p, z} pattern depends only on z (the broadcast dimension);
  // FillWithPZ copies it into every {y, x} position of the expected array.
  Array2D<float> pz_pattern({{1, 2}, {1, 2}});
  Array4D<float> expected(2, 2, 3, 3);
  expected.FillWithPZ(pz_pattern);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), actual, error_spec_));
}
171 
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
  auto builder = HloComputation::Builder(TestName());
  // Operand is [0, 1, ..., 1024], so every output element identifies the
  // operand element it was copied from.
  std::vector<float> input_data(1025);
  const int64_t r1_size = static_cast<int64_t>(input_data.size());
  std::iota(input_data.begin(), input_data.end(), 0.0f);
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));

  // Map the vector onto output dimension 3 (the minor dimension).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Expected: output(p, z, y, x) == input_data[x]. The expected array is sized
  // from r1_size (not a repeated literal) so it cannot drift from the operand.
  Array4D<float> expected(3, 3, 3, r1_size);
  Array2D<float> yx(3, r1_size);
  for (int64_t y = 0; y < 3; ++y) {
    for (int64_t x = 0; x < r1_size; ++x) {
      yx(y, x) = input_data[x];
    }
  }
  expected.FillWithYX(yx);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
201 
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
  auto builder = HloComputation::Builder(TestName());
  // Use distinct operand values [0, 1, ..., 63] rather than one repeated
  // constant: with uniform data, a broadcast that mapped the operand onto the
  // wrong output dimension would still produce the expected result.
  std::vector<float> r1_array(64);
  std::iota(r1_array.begin(), r1_array.end(), 0.0f);

  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));

  // Map the vector onto output dimension 1.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Expected: output(p, z, y, x) == r1_array[z]. Build the {p, z} plane once
  // and let FillWithPZ replicate it over every {y, x} position.
  Array2D<float> pz(32, 64);
  for (int64_t p = 0; p < 32; ++p) {
    for (int64_t z = 0; z < 64; ++z) {
      pz(p, z) = r1_array[z];
    }
  }
  Array4D<float> r4_array(32, 64, 7, 7);
  r4_array.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
                                    result, error_spec_));
}
223 
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
  // A scalar broadcast into a 64x64x3x3 shape fills every element with the
  // scalar's value.
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  Array4D<float> expected(64, 64, 3, 3);
  expected.Fill(1.0f);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
243 
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
  HloComputation::Builder builder(TestName());
  const Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));

  // Map the 2x2 operand onto output dimensions 2 and 3; output dimensions
  // 0 and 1 are new, so the operand is replicated over a 3x3 grid.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), matrix, {2, 3}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  Literal actual = ExecuteAndTransfer(std::move(module), {});

  // Expected: every {y, x} plane is a copy of the operand.
  Array4D<float> expected(3, 3, 2, 2);
  expected.FillWithYX(to_broadcast);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), actual, error_spec_));
}
265 
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
  auto builder = HloComputation::Builder(TestName());
  Array3D<float> input_vals(2, 3, 4);
  input_vals.FillRandom(1.0);

  // Expected: output(i, j, k, m) == input_vals(i, j, k) for every m, i.e. the
  // operand is replicated along the new minor dimension. Loop bounds come
  // from the expected array itself so they cannot disagree with its shape.
  Array4D<float> expected(2, 3, 4, 5);
  for (int64_t i = 0; i < expected.n1(); ++i) {
    for (int64_t j = 0; j < expected.n2(); ++j) {
      for (int64_t k = 0; k < expected.n3(); ++k) {
        for (int64_t m = 0; m < expected.n4(); ++m) {
          expected(i, j, k, m) = input_vals(i, j, k);
        }
      }
    }
  }
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_vals)));

  // Map the rank-3 operand onto output dimensions 0, 1 and 2; output
  // dimension 3 (size 5) is the new, broadcast dimension.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
296 
297 }  // namespace
298 }  // namespace xla
299