1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstdint>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
29
30 namespace xla {
31 namespace {
32
33 class BroadcastTest : public HloTestBase {};
34
XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
  // Degenerate case: broadcasting a scalar into a scalar (empty dimension
  // mapping) must return the scalar unchanged.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape result_shape = ShapeUtil::MakeShape(F32, {});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, scalar, {}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
51
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
  // A scalar broadcast with an empty dimension mapping fills every element of
  // the 2x2 result with the scalar's value.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, scalar, {}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
68
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
  HloComputation::Builder builder(TestName());
  HloInstruction* vec = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));

  // The same 3-element vector is broadcast twice: once onto result dimension 0
  // (shape [3, 2]) and once onto result dimension 1 (shape [2, 3]). Both
  // results are returned in a tuple so each can be checked separately.
  HloInstruction* along_dim0 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2}), vec, {0}));
  HloInstruction* along_dim1 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {2, 3}), vec, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({along_dim0, along_dim1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  // Tuple element 0: each row i repeats vec[i].
  const Literal expected_dim0 =
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim0, LiteralSlice(actual, {0}),
                                    error_spec_));

  // Tuple element 1: each row is a full copy of vec.
  const Literal expected_dim1 =
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim1, LiteralSlice(actual, {1}),
                                    error_spec_));
}
95
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
  // Broadcasting a 2x2 operand to the same shape with the identity mapping
  // {0, 1} must reproduce the operand exactly.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, matrix, {0, 1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
112
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
  // A degenerate broadcast into a shape of the same rank can still reorder
  // dimensions. With the mapping {1, 0}, operand dimension 0 becomes result
  // dimension 1 and vice versa, so result(i, j) == operand(j, i), i.e. a
  // transpose.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, matrix, {1, 0}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
131
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
  // Operand dimensions {0, 1} map onto result dimensions {0, 2}; result
  // dimension 1 (size 3) is new, so result(i, k, j) == operand(i, j).
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, matrix, {0, 2}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  const Literal expected =
      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
149
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
  HloComputation::Builder builder(TestName());
  HloInstruction* vec = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));

  // The 2-element vector lands on result dimension 1; result dimensions 0, 2
  // and 3 replicate it.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), vec, {1}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  // Result element (p, z, y, x) is vec[z]. Fill each (p, z) plane from a 2x2
  // table whose rows are both {1, 2}.
  Array2D<float> plane_values({{1, 2}, {1, 2}});
  Array4D<float> expected(2, 2, 3, 3);
  expected.FillWithPZ(plane_values);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), actual, error_spec_));
}
171
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
  auto builder = HloComputation::Builder(TestName());
  // Distinct values 0, 1, ..., 1024, so a misplaced element is detectable.
  std::vector<float> input_data(1025);
  const int64_t r1_size = static_cast<int64_t>(input_data.size());
  std::iota(input_data.begin(), input_data.end(), 0.0f);
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));

  // Broadcast the vector into result dimension 3 (the last dimension).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Every result element (p, z, y, x) should equal input_data[x]. The sizes
  // come from r1_size instead of repeating the literal 1025, so the
  // expectation always matches the broadcast shape built above.
  Array4D<float> expected(3, 3, 3, r1_size);
  Array2D<float> yx(3, r1_size);
  for (int64_t y = 0; y < yx.height(); ++y) {
    for (int64_t x = 0; x < yx.width(); ++x) {
      yx(y, x) = input_data[x];
    }
  }
  expected.FillWithYX(yx);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
201
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
  auto builder = HloComputation::Builder(TestName());
  // Use a distinct value per element (0, 1, ..., 63) instead of a constant.
  // With uniform data the comparison would still pass if operand elements
  // landed at the wrong positions along result dimension 1.
  std::vector<float> r1_array(64);
  std::iota(r1_array.begin(), r1_array.end(), 0.0f);

  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));

  // Broadcast vector in dimension 1.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Result element (p, z, y, x) should equal r1_array[z].
  Array2D<float> pz(32, 64);
  for (int64_t p = 0; p < pz.height(); ++p) {
    for (int64_t z = 0; z < pz.width(); ++z) {
      pz(p, z) = r1_array[z];
    }
  }
  Array4D<float> r4_array(32, 64, 7, 7);
  r4_array.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
                                    result, error_spec_));
}
223
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  // Empty dimension mapping: the scalar is replicated into every element.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  // Dump the module only when verbose logging is enabled, rather than
  // unconditionally at INFO level on every test run.
  VLOG(1) << hlo_module->ToString();
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  Array4D<float> expected(64, 64, 3, 3);
  expected.Fill(1.0f);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
243
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
  HloComputation::Builder builder(TestName());
  Array2D<float> matrix({{1.0f, 2.0f}, {3.0f, 4.0f}});
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(matrix)));

  // Map the 2x2 matrix onto result dimensions 2 and 3; result dimensions 0
  // and 1 (each of size 3) are new and replicate it.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), operand, {2, 3}));

  // Wrap the computation in a verified module, then compile and run it.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  const Literal actual = ExecuteAndTransfer(std::move(module), {});

  // Every [p, z, :, :] slice of the result is a copy of the matrix.
  Array4D<float> expected(3, 3, 2, 2);
  expected.FillWithYX(matrix);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), actual, error_spec_));
}
265
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
  auto builder = HloComputation::Builder(TestName());
  Array3D<float> input_vals(2, 3, 4);
  input_vals.FillRandom(1.0);

  // The result keeps the operand's three dimensions in place and adds a new
  // last dimension of size 5, so expected(i, j, k, m) == input_vals(i, j, k)
  // for every m. Loop bounds come from `expected` so they cannot drift from
  // its declared extents.
  Array4D<float> expected(2, 3, 4, 5);
  for (int64_t i = 0; i < expected.planes(); ++i) {
    for (int64_t j = 0; j < expected.depth(); ++j) {
      for (int64_t k = 0; k < expected.height(); ++k) {
        for (int64_t m = 0; m < expected.width(); ++m) {
          expected(i, j, k, m) = input_vals(i, j, k);
        }
      }
    }
  }
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_vals)));

  // Map operand dimensions {0, 1, 2} onto result dimensions {0, 1, 2}; result
  // dimension 3 is the newly added, replicated dimension.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
296
297 } // namespace
298 } // namespace xla
299