1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that multi-dimensional arrays can be reduced among various
17 // user-provided dimensions.
18 //
19 // Note that comments for these tests are white-box in that they talk about the
20 // default data layout.
21 //
22 // The test space for reductions is the cartesian product of:
23 //
24 // <possible ranks> x
25 // <possible layouts for chosen rank> x
26 // <possible subsets of dimensions in chosen rank>
27
28 #include <stdlib.h>
29 #include <algorithm>
30 #include <cmath>
31 #include <memory>
32 #include <string>
33 #include <utility>
34 #include <vector>
35
36 #include "absl/algorithm/container.h"
37 #include "absl/strings/str_format.h"
38 #include "absl/strings/str_join.h"
39 #include "absl/types/span.h"
40 #include "tensorflow/compiler/xla/array2d.h"
41 #include "tensorflow/compiler/xla/array4d.h"
42 #include "tensorflow/compiler/xla/client/global_data.h"
43 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
44 #include "tensorflow/compiler/xla/client/local_client.h"
45 #include "tensorflow/compiler/xla/client/xla_builder.h"
46 #include "tensorflow/compiler/xla/client/xla_computation.h"
47 #include "tensorflow/compiler/xla/layout_util.h"
48 #include "tensorflow/compiler/xla/literal_util.h"
49 #include "tensorflow/compiler/xla/reference_util.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/statusor.h"
53 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
54 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
55 #include "tensorflow/compiler/xla/tests/test_macros.h"
56 #include "tensorflow/compiler/xla/util.h"
57 #include "tensorflow/compiler/xla/xla_data.pb.h"
58 #include "tensorflow/core/lib/core/status_test_util.h"
59 #include "tensorflow/core/platform/test.h"
60 #include "tensorflow/core/platform/types.h"
61
62 namespace xla {
63 namespace {
64
// Pointer to a function that takes a PrimitiveType and an XlaBuilder* and
// returns an XlaComputation. The CreateScalar*Computation(type, &builder)
// helpers called in this file have this call shape.
using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*);

// Pointer to a function that takes only an XlaBuilder* and returns an
// XlaComputation (i.e. the element type is already fixed by the callee).
using FuncGenerator = XlaComputation (*)(XlaBuilder*);
68
69 class ReduceTest : public ClientLibraryTestBase {
70 protected:
ReduceTest()71 ReduceTest() {
72 // Implementation note: laid out z >> y >> x by default.
73 // clang-format off
74 literal_2d_ = LiteralUtil::CreateR2<float>({
75 // x0 x1 x2
76 { 1.f, 2.f, 3.f}, // y0
77 { 4.f, 5.f, 6.f}, // y1
78 });
79 literal_3d_ = LiteralUtil::CreateR3Projected<float>({
80 // x0 x1 x2
81 { 1.f, 2.f, 3.f}, // y0
82 { 4.f, 5.f, 6.f}, // y1
83 }, 4);
84 // clang-format on
85 CHECK(ShapeUtil::Equal(
86 literal_3d_.shape(),
87 ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
88 << literal_3d_.shape().ShortDebugString();
89 }
90
91 // Runs an R1 => R0 reduction test with the given number of elements.
RunR1ToR0Test(int64 element_count)92 void RunR1ToR0Test(int64 element_count) {
93 XlaBuilder builder(TestName());
94 XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
95 const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
96 auto input = Parameter(&builder, 0, input_shape, "input");
97 auto zero = ConstantR0<float>(&builder, 0.0);
98 Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
99
100 std::vector<float> input_data(element_count);
101 for (int64 i = 0; i < element_count; ++i) {
102 input_data[i] = rand_r(&seed_) % 3;
103 if (rand_r(&seed_) % 2 == 0) {
104 input_data[i] *= -1;
105 }
106 }
107 Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
108 std::unique_ptr<GlobalData> input_global_data =
109 client_->TransferToServer(input_literal).ConsumeValueOrDie();
110
111 float expected = 0.0;
112 for (float item : input_data) {
113 expected += item;
114 }
115 ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
116 ErrorSpec(0.001));
117 }
118
RunR1ToR0PredTest(bool and_reduce,absl::Span<const int> input_data)119 void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
120 const int element_count = input_data.size();
121 XlaBuilder builder(TestName());
122 const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
123 auto input_par = Parameter(&builder, 0, input_shape, "input");
124 auto pred_values =
125 Eq(input_par, ConstantR1<int>(&builder, element_count, 1));
126 XlaOp init_value;
127 XlaComputation reduce;
128 if (and_reduce) {
129 init_value = ConstantR0<bool>(&builder, true);
130 reduce = CreateScalarAndComputation(PRED, &builder);
131 } else {
132 init_value = ConstantR0<bool>(&builder, false);
133 reduce = CreateScalarOrComputation(PRED, &builder);
134 }
135 Reduce(pred_values, init_value, reduce,
136 /*dimensions_to_reduce=*/{0});
137
138 Literal input_literal = LiteralUtil::CreateR1(input_data);
139 std::unique_ptr<GlobalData> input_global_data =
140 client_->TransferToServer(input_literal).ConsumeValueOrDie();
141
142 bool expected = and_reduce;
143 for (bool item : input_data) {
144 if (and_reduce) {
145 expected = expected && item;
146 } else {
147 expected = expected || item;
148 }
149 }
150 ComputeAndCompareR0<bool>(&builder, expected, {input_global_data.get()});
151 }
152
153 // Reduce predicate tensor with dimension rows * cols to dimension cols, to
154 // test the implementation of atomic operations on misaligned small data
155 // types.
156 template <int64 cols>
RunR2ToR1PredTest(bool and_reduce,int64 rows,int64 minor=1,int64 major=0)157 void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1,
158 int64 major = 0) {
159 XlaBuilder builder(TestName());
160 const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
161 auto input = Parameter(&builder, 0, input_shape, "input");
162 auto input_pred = Eq(input, ConstantR0<uint8>(&builder, 1));
163
164 XlaOp init_value;
165 XlaComputation reduce_op;
166 if (and_reduce) {
167 init_value = ConstantR0<bool>(&builder, true);
168 reduce_op = CreateScalarAndComputation(PRED, &builder);
169 } else {
170 init_value = ConstantR0<bool>(&builder, false);
171 reduce_op = CreateScalarOrComputation(PRED, &builder);
172 }
173
174 Reduce(input_pred, init_value, reduce_op,
175 /*dimensions_to_reduce=*/{0});
176
177 Array2D<uint8> input_data(rows, cols);
178 input_data.FillRandom(0, 1);
179 Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
180 input_literal =
181 input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
182 std::unique_ptr<GlobalData> input_global_data =
183 client_->TransferToServer(input_literal).ConsumeValueOrDie();
184
185 std::array<bool, cols> expected;
186 for (int64 colno = 0; colno < cols; ++colno) {
187 bool column_sum = and_reduce ? true : false;
188 for (int64 rowno = 0; rowno < rows; ++rowno) {
189 if (and_reduce) {
190 column_sum = column_sum && input_data(rowno, colno);
191 } else {
192 column_sum = column_sum || input_data(rowno, colno);
193 }
194 }
195 expected[colno] = column_sum;
196 }
197
198 ComputeAndCompareR1<bool>(&builder, expected, {input_global_data.get()});
199 }
200
201 // Runs an R2 => R0 reduction test with the given number of (rows, cols).
RunR2ToR0Test(int64 rows,int64 cols,int64 minor=1,int64 major=0)202 void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
203 XlaBuilder builder(TestName());
204 XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
205 const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
206 auto input = Parameter(&builder, 0, input_shape, "input");
207 auto zero = ConstantR0<float>(&builder, 0.0);
208 Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
209
210 Array2D<float> input_data(rows, cols);
211 input_data.FillRandom(3.14f, 0.04);
212 Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
213 input_literal =
214 input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
215 std::unique_ptr<GlobalData> input_global_data =
216 client_->TransferToServer(input_literal).ConsumeValueOrDie();
217
218 float expected = 0.0;
219 for (int64 rowno = 0; rowno < rows; ++rowno) {
220 for (int64 colno = 0; colno < cols; ++colno) {
221 expected += input_data(rowno, colno);
222 }
223 }
224 ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
225 ErrorSpec(0.01, 1e-4));
226 }
227
228 // Runs an R2 => R1 reduction test with the given number of (rows, cols).
RunR2ToR1Test(int64 rows,int64 cols,int64 minor=1,int64 major=0)229 void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
230 XlaBuilder builder(TestName());
231 XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
232 const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
233 auto input = Parameter(&builder, 0, input_shape, "input");
234 auto zero = ConstantR0<float>(&builder, 0.0);
235 Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
236
237 Array2D<float> input_data(rows, cols);
238 input_data.FillRandom(3.14f, 0.04);
239 Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
240 input_literal =
241 input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
242 std::unique_ptr<GlobalData> input_global_data =
243 client_->TransferToServer(input_literal).ConsumeValueOrDie();
244
245 std::vector<float> expected;
246 for (int64 colno = 0; colno < cols; ++colno) {
247 float column_sum = 0;
248 for (int64 rowno = 0; rowno < rows; ++rowno) {
249 column_sum += input_data(rowno, colno);
250 }
251 expected.push_back(column_sum);
252 }
253 ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
254 ErrorSpec(0.01, 1e-4));
255 }
256
257 template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_floating_point<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)258 void ComputeAndCompareGeneric(
259 typename std::enable_if<std::is_floating_point<NativeT>::value,
260 XlaBuilder>::type* builder,
261 absl::Span<const NativeT> expected,
262 absl::Span<GlobalData* const> arguments) {
263 ComputeAndCompareR1<NativeT>(builder, expected, arguments,
264 ErrorSpec(0.01, 1e-4));
265 }
266
267 template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_integral<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)268 void ComputeAndCompareGeneric(
269 typename std::enable_if<std::is_integral<NativeT>::value,
270 XlaBuilder>::type* builder,
271 absl::Span<const NativeT> expected,
272 absl::Span<GlobalData* const> arguments) {
273 ComputeAndCompareR1<NativeT>(builder, expected, arguments);
274 }
275
276 template <typename NativeT>
RunVectorizedReduceTestForType(const std::function<XlaComputation (XlaBuilder *)> & reduction_function_generator,const std::function<NativeT (NativeT,NativeT)> & reference_reduction_function,const NativeT & initial_value)277 void RunVectorizedReduceTestForType(
278 const std::function<XlaComputation(XlaBuilder*)>&
279 reduction_function_generator,
280 const std::function<NativeT(NativeT, NativeT)>&
281 reference_reduction_function,
282 const NativeT& initial_value) {
283 const int rows = 64, cols = 128;
284 const int minor = 1, major = 0;
285 XlaBuilder builder(TestName());
286 XlaComputation reduction_function = reduction_function_generator(&builder);
287 const Shape input_shape = ShapeUtil::MakeShape(
288 xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
289 auto input = Parameter(&builder, 0, input_shape, "input");
290 auto zero = ConstantR0<NativeT>(&builder, initial_value);
291 Reduce(input, zero, reduction_function,
292 /*dimensions_to_reduce=*/{0});
293
294 Array2D<NativeT> input_data(rows, cols);
295 input_data.FillUnique(initial_value);
296 Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
297 input_literal =
298 input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
299 std::unique_ptr<GlobalData> input_global_data =
300 client_->TransferToServer(input_literal).ConsumeValueOrDie();
301
302 // NativeT can be bool, and std::vector<bool> does not convert to
303 // Span.
304 std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
305 for (int64 colno = 0; colno < cols; ++colno) {
306 NativeT column_result = initial_value;
307 for (int64 rowno = 0; rowno < rows; ++rowno) {
308 column_result = reference_reduction_function(column_result,
309 input_data(rowno, colno));
310 }
311 expected[colno] = column_result;
312 }
313
314 ComputeAndCompareGeneric<NativeT>(
315 &builder, absl::Span<const NativeT>(expected.get(), cols),
316 {input_global_data.get()});
317 }
318
RunVectorizedReduceTest(const std::function<XlaComputation (PrimitiveType,XlaBuilder *)> & reduction_function_generator_for_type,const std::function<float (float,float)> & reference_reduction_function_for_floats,const std::function<int32 (int32,int32)> & reference_reduction_function_for_ints,const std::function<uint32 (uint32,uint32)> & reference_reduction_function_for_uints,float floating_point_identity,int32 signed_int_identity,uint32 unsigned_int_identity)319 void RunVectorizedReduceTest(
320 const std::function<XlaComputation(PrimitiveType, XlaBuilder*)>&
321 reduction_function_generator_for_type,
322 const std::function<float(float, float)>&
323 reference_reduction_function_for_floats,
324 const std::function<int32(int32, int32)>&
325 reference_reduction_function_for_ints,
326 const std::function<uint32(uint32, uint32)>&
327 reference_reduction_function_for_uints,
328 float floating_point_identity, int32 signed_int_identity,
329 uint32 unsigned_int_identity) {
330 // Float version
331 RunVectorizedReduceTestForType<float>(
332 [&](XlaBuilder* builder) {
333 return reduction_function_generator_for_type(F32, builder);
334 },
335 reference_reduction_function_for_floats, floating_point_identity);
336
337 // Signed int version
338 RunVectorizedReduceTestForType<int32>(
339 [&](XlaBuilder* builder) {
340 return reduction_function_generator_for_type(S32, builder);
341 },
342 reference_reduction_function_for_ints, signed_int_identity);
343
344 // Unsigned int version
345 RunVectorizedReduceTestForType<uint32>(
346 [&](XlaBuilder* builder) {
347 return reduction_function_generator_for_type(U32, builder);
348 },
349 reference_reduction_function_for_uints, unsigned_int_identity);
350 }
351
352 Literal literal_2d_;
353 Literal literal_3d_;
354 uint32 seed_ = 0xdeadbeef;
355 };
356
// R1 => R0 add-reductions over a range of input lengths, from an empty vector
// up to 16M elements, including sizes just off powers of two.
XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16);
}
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/128);
}
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/129);
}
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/240);
}
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/256);
}
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2048);
}
XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024 + 1);
}
XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/64 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/4096 * 4096);
}
374
// R2 => R0 add-reductions over both dimensions for a range of shapes,
// including shapes with a zero-sized dimension and one non-default layout.
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) {
  RunR2ToR0Test(/*rows=*/1, /*cols=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) {
  RunR2ToR0Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) {
  RunR2ToR0Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) {
  RunR2ToR0Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50);
}
// Same shape as above, but with the layout arguments swapped.
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) {
  RunR2ToR0Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) {
  RunR2ToR0Test(/*rows=*/1000, /*cols=*/1500);
}
389
// R2 => R1 add-reductions over dimension 0 for a range of shapes. Two of the
// zero-sized shapes are disabled (see the notes next to them).

// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) {
  RunR2ToR1Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) {
  RunR2ToR1Test(/*rows=*/1, /*cols=*/1);
}
// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) {
  RunR2ToR1Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) {
  RunR2ToR1Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) {
  RunR2ToR1Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) {
  RunR2ToR1Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50);
}
// Same shape as above, but with the layout arguments swapped.
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) {
  RunR2ToR1Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) {
  RunR2ToR1Test(/*rows=*/1000, /*cols=*/1500);
}
406
// AND-reduces ten elements that are all 1.
XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) {
  constexpr int kNumElements = 10;
  const std::vector<int> all_ones(kNumElements, 1);
  RunR1ToR0PredTest(/*and_reduce=*/true, all_ones);
}
412
// AND-reduces ten elements alternating 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) {
  constexpr int kNumElements = 10;
  std::vector<int> alternating(kNumElements);
  int position = 0;
  for (int& value : alternating) {
    value = position % 2;
    ++position;
  }
  RunR1ToR0PredTest(/*and_reduce=*/true, alternating);
}
421
// OR-reduces ten elements that are all 1.
XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) {
  constexpr int kNumElements = 10;
  const std::vector<int> all_ones(kNumElements, 1);
  RunR1ToR0PredTest(/*and_reduce=*/false, all_ones);
}
427
// OR-reduces ten elements alternating 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) {
  constexpr int kNumElements = 10;
  std::vector<int> alternating(kNumElements);
  int position = 0;
  for (int& value : alternating) {
    value = position % 2;
    ++position;
  }
  RunR1ToR0PredTest(/*and_reduce=*/false, alternating);
}
436
// Add-reduces Log(input) over dimension 0 of a 111x50 input, i.e. an
// elementwise op feeding a reduction. The input literal is relaid out with
// MakeLayout({0, 1}) before transfer.
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto log_of_param = Log(param);
  Reduce(log_of_param, init, add_f32, /*dimensions_to_reduce=*/{0});

  Array2D<float> host_values(kRows, kCols);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  host_literal = host_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).ConsumeValueOrDie();

  // Reference: per-column sum of log(x), accumulated row by row.
  std::vector<float> expected_sums;
  for (int64 c = 0; c < kCols; ++c) {
    float sum_of_logs = 0;
    for (int64 r = 0; r < kRows; ++r) {
      sum_of_logs += std::log(host_values(r, c));
    }
    expected_sums.push_back(sum_of_logs);
  }
  ComputeAndCompareR1<float>(&builder, expected_sums, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
466
// Add-reduces Transpose(Log(input)) over dimension 1 of the transposed
// operand. Because of the {1, 0} transpose this is still a per-column sum of
// the original 111x50 input.
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto log_of_param = Log(param);
  auto transposed = Transpose(log_of_param, {1, 0});
  Reduce(transposed, init, add_f32, /*dimensions_to_reduce=*/{1});

  Array2D<float> host_values(kRows, kCols);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  host_literal = host_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).ConsumeValueOrDie();

  // Reference: per-column sum of log(x), accumulated row by row.
  std::vector<float> expected_sums;
  for (int64 c = 0; c < kCols; ++c) {
    float sum_of_logs = 0;
    for (int64 r = 0; r < kRows; ++r) {
      sum_of_logs += std::log(host_values(r, c));
    }
    expected_sums.push_back(sum_of_logs);
  }
  ComputeAndCompareR1<float>(&builder, expected_sums, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
497
498 // Test that algebraic simplifier does not incorrectly fold a transpose into a
499 // reduction operation.
XLA_TEST_F(ReduceTest,TransposeAndReduceR3_12x111x50_To_R2)500 XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
501 XlaBuilder builder(TestName());
502 XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
503 const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
504 XlaOp input = Parameter(&builder, 0, input_shape, "input");
505 XlaOp zero = ConstantR0<float>(&builder, 0.0);
506 XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
507 Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
508
509 TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
510
511 ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
512 }
513
// Applies Tanh to a 111x2x25 input, reshapes it to 111x50, and add-reduces
// over dimension 0. The reference iterates (major, column) in the same order
// the reshape flattens the two trailing dimensions.
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, 2, kCols / 2});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto tanh_of_param = Tanh(param);
  auto flattened = Reshape(tanh_of_param, {kRows, kCols});
  Reduce(flattened, init, add_f32, /*dimensions_to_reduce=*/{0});

  Array3D<float> host_values(kRows, 2, kCols / 2);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR3FromArray3D(host_values);
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).ConsumeValueOrDie();

  std::vector<float> expected_sums;
  for (int64 half = 0; half < 2; ++half) {
    for (int64 c = 0; c < kCols / 2; ++c) {
      float sum_of_tanh = 0;
      for (int64 r = 0; r < kRows; ++r) {
        sum_of_tanh += std::tanh(host_values(r, half, c));
      }
      expected_sums.push_back(sum_of_tanh);
    }
  }
  ComputeAndCompareR1<float>(&builder, expected_sums, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
545
// Describes one reduction case by three integer lists. PrintTo below renders
// it as a test-name-like string.
struct BoundsLayout {
  // One entry per input dimension; its size is printed as the input rank.
  std::vector<int64> bounds;
  // Presumably the layout of the input (e.g. minor-to-major order) — confirm
  // against the code that consumes this struct.
  std::vector<int64> layout;
  // Dimensions to reduce; the printed output rank is
  // bounds.size() - reduce_dims.size().
  std::vector<int64> reduce_dims;
};
551
PrintTo(const BoundsLayout & spec,std::ostream * os)552 void PrintTo(const BoundsLayout& spec, std::ostream* os) {
553 *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(),
554 spec.bounds.size() - spec.reduce_dims.size(),
555 absl::StrJoin(spec.bounds, "x"),
556 absl::StrJoin(spec.layout, ""),
557 absl::StrJoin(spec.reduce_dims, ""));
558 }
559
560 // Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,AddReduce2DScalarToR0)561 XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
562 XlaBuilder builder(TestName());
563 auto add = CreateScalarAddComputation(F32, &builder);
564 auto scalar = ConstantR0<float>(&builder, 42.0);
565 auto broadcasted = Broadcast(scalar, {500, 500});
566 Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
567
568 float expected = 42.0f * static_cast<float>(500 * 500);
569 ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
570 }
571
572 // Max-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DScalarToR0)573 XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
574 XlaBuilder builder(TestName());
575 auto max = CreateScalarMaxComputation(F32, &builder);
576 auto scalar = ConstantR0<float>(&builder, 42.0);
577 auto broadcasted = Broadcast(scalar, {500, 500});
578 Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), max, {0, 1});
579
580 float expected = 42.0f;
581 ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
582 }
583
584 // Max-reduces a matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DToR0)585 XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
586 XlaBuilder builder(TestName());
587 auto max = CreateScalarMaxComputation(F32, &builder);
588 Array2D<float> input(300, 250);
589 input.FillRandom(214.0f);
590 auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
591 Reduce(ConstantLiteral(&builder, input_literal),
592 ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
593 auto input_max = FLT_MIN;
594 input.Each(
595 [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
596 ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
597 }
598
599 // Min-reduces matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MinReduce2DToR0)600 XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
601 XlaBuilder builder(TestName());
602 auto min = CreateScalarMinComputation(F32, &builder);
603 Array2D<float> input(150, 130);
604 input.FillRandom(214.0f);
605 auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
606 Reduce(ConstantLiteral(&builder, input_literal),
607 ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
608
609 auto input_min = FLT_MAX;
610 input.Each(
611 [&](int64, int64, float* v) { input_min = std::min(input_min, *v); });
612 ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
613 }
614
// Min-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the largest uint32; the expected result is 1.
XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32> values({{1}, {2}});
  auto minimum = CreateScalarMinComputation(U32, &builder);
  auto values_literal = LiteralUtil::CreateR2FromArray2D(values);
  auto init =
      ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
  auto operand = ConstantLiteral(&builder, values_literal);

  Reduce(operand, init, minimum, {0, 1});
  ComputeAndCompareR0<uint32>(&builder, 1, {});
}
626
// Max-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the smallest uint32; the expected result is 2.
XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32> values({{1}, {2}});
  auto maximum = CreateScalarMaxComputation(U32, &builder);
  auto values_literal = LiteralUtil::CreateR2FromArray2D(values);
  auto init =
      ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
  auto operand = ConstantLiteral(&builder, values_literal);

  Reduce(operand, init, maximum, {0, 1});
  ComputeAndCompareR0<uint32>(&builder, 2, {});
}
638
639 // Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest,Reduce2DAmong1)640 XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
641 XlaBuilder builder(TestName());
642 auto m = ConstantLiteral(&builder, literal_2d_);
643 auto add = CreateScalarAddComputation(F32, &builder);
644 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
645
646 std::vector<float> expected = {6.f, 15.f};
647 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
648 }
649
// Add-reduces the fixture's 2x3 matrix over dimensions 0 and 1, summing all
// six elements into a scalar.
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantLiteral(&builder, literal_2d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  auto init = ConstantR0<float>(&builder, 0.0f);
  Reduce(matrix, init, sum, {0, 1});

  const float total = 21.0f;
  ComputeAndCompareR0<float>(&builder, total, {}, ErrorSpec(0.0001, 1e-4));
}
659
660 // Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest,Reduce2DAmongY)661 XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
662 XlaBuilder builder("reduce_among_y");
663 auto m = ConstantLiteral(&builder, literal_2d_);
664 auto add = CreateScalarAddComputation(F32, &builder);
665 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
666
667 std::vector<float> expected = {5.f, 7.f, 9.f};
668 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
669 }
670
// Add-reduces the fixture's {4, 2, 3} literal over dimensions 1 and 2. Each of
// the four z-slices is the same 2x3 matrix, whose elements sum to 21.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  auto init = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, init, sum, {1, 2});

  const std::vector<float> slice_sums = {21.f, 21.f, 21.f, 21.f};
  ComputeAndCompareR1<float>(&builder, slice_sums, {}, ErrorSpec(0.0001));
}
680
// Add-reduces the fixture's {4, 2, 3} literal over dimensions 0 and 1, leaving
// one value per x: four copies of each column sum (5, 7, 9).
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  auto init = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, init, sum, {0, 1});

  const std::vector<float> x_sums = {20.f, 28.f, 36.f};
  ComputeAndCompareR1<float>(&builder, x_sums, {}, ErrorSpec(0.0001));
}
690
// Sum-reduces the 3D fixture literal over all three dimensions, yielding a
// scalar.
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_computation = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_computation, {0, 1, 2});

  // The expected total is spelled as the product 21 * 4.
  const float expected_total = 21.0f * 4.0;
  ComputeAndCompareR0<float>(&builder, expected_total, {}, ErrorSpec(0.0001));
}
700
// Sum-reduces the 3D fixture literal over dimension 0 only; the result is a
// 2x3 matrix indexed by the remaining dimensions 1 and 2.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_computation = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_computation, {0});

  // clang-format off
  Array2D<float> expected_sums({
      { 4.f,  8.f, 12.f},
      {16.f, 20.f, 24.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
715
// Sum-reduces the 3D fixture literal over dimension 1 only; the result is a
// 4x3 matrix indexed by the remaining dimensions 0 and 2.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_computation = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_computation, {1});

  // Every row of the expected output is identical.
  // clang-format off
  Array2D<float> expected_sums({
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
732
// Sum-reduces the 3D fixture literal over dimension 2 only; the result is a
// 4x2 matrix indexed by the remaining dimensions 0 and 1.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_computation = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_computation, {2});

  // Every row of the expected output is identical.
  // clang-format off
  Array2D<float> expected_sums({
      {6.f, 15.f},
      {6.f, 15.f},
      {6.f, 15.f},
      {6.f, 15.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
749
// Runs the shared vectorized-reduce helper with addition as the reducer,
// supplying one host-side reference function per element type and zero as the
// identity for each.
XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
  const auto make_reducer =
      static_cast<FuncGeneratorForType>(CreateScalarAddComputation);
  const auto add_f32 = [](float lhs, float rhs) { return lhs + rhs; };
  // The signed reference adds in uint32 and casts back, so the sum wraps
  // rather than hitting signed-overflow undefined behavior.
  const auto add_s32 = [](int32 lhs, int32 rhs) {
    const uint32 wrapped = static_cast<uint32>(lhs) + static_cast<uint32>(rhs);
    return static_cast<int32>(wrapped);
  };
  const auto add_u32 = [](uint32 lhs, uint32 rhs) { return lhs + rhs; };

  RunVectorizedReduceTest(make_reducer, add_f32, add_s32, add_u32, 0.0, 0, 0);
}
760
// Runs the shared vectorized-reduce helper with multiplication as the reducer,
// supplying one host-side reference function per element type and one as the
// identity for each.
XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
  const auto make_reducer =
      static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation);
  const auto mul_f32 = [](float lhs, float rhs) { return lhs * rhs; };
  // The signed reference multiplies in uint32 and casts back, so the product
  // wraps rather than hitting signed-overflow undefined behavior.
  const auto mul_s32 = [](int32 lhs, int32 rhs) {
    const uint32 wrapped = static_cast<uint32>(lhs) * static_cast<uint32>(rhs);
    return static_cast<int32>(wrapped);
  };
  const auto mul_u32 = [](uint32 lhs, uint32 rhs) { return lhs * rhs; };

  RunVectorizedReduceTest(make_reducer, mul_f32, mul_s32, mul_u32, 1.0, 1, 1);
}
771
// Runs the shared vectorized-reduce helper with max as the reducer, using
// std::max as the host-side reference for every element type.
XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
  const auto make_reducer =
      static_cast<FuncGeneratorForType>(CreateScalarMaxComputation);
  const auto max_f32 = [](float lhs, float rhs) { return std::max(lhs, rhs); };
  const auto max_s32 = [](int32 lhs, int32 rhs) { return std::max(lhs, rhs); };
  const auto max_u32 = [](uint32 lhs, uint32 rhs) {
    return std::max(lhs, rhs);
  };

  // NOTE(review): numeric_limits<float>::min() is the smallest positive
  // normalized float, not the most negative value (that is lowest()). Whether
  // this matters depends on how RunVectorizedReduceTest uses the value --
  // confirm before changing it.
  RunVectorizedReduceTest(make_reducer, max_f32, max_s32, max_u32,
                          std::numeric_limits<float>::min(),
                          std::numeric_limits<int32>::min(),
                          std::numeric_limits<uint32>::min());
}
781
// Runs the shared vectorized-reduce helper with min as the reducer, using
// std::min as the host-side reference and each type's max() as its identity.
XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
  const auto make_reducer =
      static_cast<FuncGeneratorForType>(CreateScalarMinComputation);
  const auto min_f32 = [](float lhs, float rhs) { return std::min(lhs, rhs); };
  const auto min_s32 = [](int32 lhs, int32 rhs) { return std::min(lhs, rhs); };
  const auto min_u32 = [](uint32 lhs, uint32 rhs) {
    return std::min(lhs, rhs);
  };

  RunVectorizedReduceTest(make_reducer, min_f32, min_s32, min_u32,
                          std::numeric_limits<float>::max(),
                          std::numeric_limits<int32>::max(),
                          std::numeric_limits<uint32>::max());
}
791
// Runs the single-type vectorized-reduce helper on bool with logical AND as
// the reducer and `true` as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
  const auto make_reducer = static_cast<FuncGenerator>(
      [](XlaBuilder* b) { return CreateScalarAndComputation(PRED, b); });
  const auto logical_and = [](bool lhs, bool rhs) { return lhs && rhs; };

  RunVectorizedReduceTestForType<bool>(make_reducer, logical_and, true);
}
799
// Runs the single-type vectorized-reduce helper on bool with logical OR as the
// reducer and `false` as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
  const auto make_reducer = static_cast<FuncGenerator>(
      [](XlaBuilder* b) { return CreateScalarOrComputation(PRED, b); });
  const auto logical_or = [](bool lhs, bool rhs) { return lhs || rhs; };

  RunVectorizedReduceTestForType<bool>(make_reducer, logical_or, false);
}
807
// Value-parameterized variant of ReduceTest. Each instantiation receives a
// BoundsLayout, readable through GetParam(); the class adds no members of its
// own.
class ReduceR3ToR2Test : public ReduceTest,
                         public ::testing::WithParamInterface<BoundsLayout> {};
810
// For each BoundsLayout parameter: builds an all-ones rank-3 input with the
// requested bounds and layout, sum-reduces it over the requested dimensions,
// and compares the result with ReferenceUtil::Reduce3DTo2D.
XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
  XlaBuilder builder(TestName());
  const BoundsLayout& config = GetParam();
  const auto& extents = config.bounds;

  Array3D<float> host_input(extents[0], extents[1], extents[2]);
  // A random fill (FillRandom(3.14f, 0.05)) is left disabled here; the input
  // is filled with a constant instead.
  host_input.Fill(1.0f);

  // Re-lay out the literal before upload so the device sees the layout under
  // test.
  auto literal = LiteralUtil::CreateR3FromArray3D(host_input);
  literal = literal.Relayout(LayoutUtil::MakeLayout(config.layout));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(literal).ConsumeValueOrDie();

  auto input_param = Parameter(&builder, 0, literal.shape(), "input");
  XlaComputation sum_computation = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(input_param, zero, sum_computation, config.reduce_dims);

  // Host-side reference result for the same reduction.
  auto reference = ReferenceUtil::Reduce3DTo2D(
      host_input, 0.0f, config.reduce_dims,
      [](float lhs, float rhs) { return lhs + rhs; });

  ComputeAndCompareR2<float>(&builder, *reference, {device_input.get()},
                             ErrorSpec(1e-3, 1e-3));
}
837
// Parameter set for ReduceR3ToR2Test. Each BoundsLayout below lists, in order,
// the input bounds, the layout handed to LayoutUtil::MakeLayout, and the
// dimensions to reduce. The order of the entries determines the generated test
// indices, so append rather than reorder.
INSTANTIATE_TEST_CASE_P(
    ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
    // Specifies (shape, layout, reduction dimensions).
    ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
                      // These should be simplified into a reshape.
                      BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
                      // Same reductions with the layout order reversed
                      // ({0, 1, 2} instead of {2, 1, 0}).
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
                      // Bounds that are not powers of two.
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
                      // Largest inputs in this list.
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
861
// Checks that the init value of a reduce may itself be the result of an
// operation on a constant (here Abs(2.0f)) rather than a bare constant.
XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
  XlaBuilder builder(TestName());
  XlaComputation max_computation = CreateScalarMaxComputation(F32, &builder);

  auto init_value = Abs(ConstantR0<float>(&builder, 2.0f));

  Literal operand_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
  std::unique_ptr<GlobalData> operand_data =
      client_->TransferToServer(operand_literal).ConsumeValueOrDie();
  auto operand = Parameter(&builder, 0, operand_literal.shape(), "b");
  Reduce(operand, init_value, max_computation, {0});

  // max(2, 1, 4) == 4.
  ComputeAndCompareR0<float>(&builder, 4.0f, {operand_data.get()});
}
877
// AND-reduction of a 128x64 predicate matrix to rank 1, via the shared
// RunR2ToR1PredTest helper.
XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) {
  constexpr int kCols = 64;
  constexpr int kRows = 128;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/true, kRows);
}
// OR-reduction of a 64x32 predicate matrix to rank 1, via the shared
// RunR2ToR1PredTest helper.
XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) {
  constexpr int kCols = 32;
  constexpr int kRows = 64;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/false, kRows);
}
884
885 // Tests reductions with different initial values. There's no test macro that
886 // combines TYPED_TEST and TYPED_P, so we have to do it manually.
887 class ReduceInitializerTest : public ReduceTest {
888 protected:
889 template <typename T>
DoTest(T initializer,int num_elems)890 void DoTest(T initializer, int num_elems) {
891 XlaBuilder builder(TestName());
892 XlaComputation max_fn = CreateScalarMaxComputation(
893 primitive_util::NativeToPrimitiveType<T>(), &builder);
894
895 auto init = ConstantR0<T>(&builder, initializer);
896 std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
897 auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
898 auto input_data =
899 client_->TransferToServer(input_literal).ConsumeValueOrDie();
900 Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
901 {0});
902
903 ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
904 }
905 };
906
// uint8 max-reduce over a two-element input with init value 42.
XLA_TEST_F(ReduceInitializerTest, U8Small) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 2;
  DoTest<uint8>(kInitializer, kNumElems);
}
908
// uint8 max-reduce over a power-of-two sized input (4096) with init value 42.
XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 4096;
  DoTest<uint8>(kInitializer, kNumElems);
}
910
// uint8 max-reduce over an input whose size (4095) is not a power of two,
// with init value 42.
XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 4095;
  DoTest<uint8>(kInitializer, kNumElems);
}
914
// uint64 max-reduce over 1024 elements with init value 0.
XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) {
  constexpr uint64 kInitializer = 0;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
918
// uint64 max-reduce over 1024 elements with init value 1.
XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) {
  constexpr uint64 kInitializer = 1;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
922
// uint64 max-reduce over 1024 elements with an init value too large for
// 32 bits.
XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
  constexpr uint64 kInitializer = 1234556789123;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
926
927 // Test the operational semantic that the init value is passed on the lhs for
928 // reduces. Can be tested by performing an "identity" reduce (that simply
929 // returns one of the parameters). In this case, we return the rhs, which for
930 // a 1D array with one element, should not be the init value.
XLA_TEST_F(ReduceTest,ReduceIdentity)931 XLA_TEST_F(ReduceTest, ReduceIdentity) {
932 XlaBuilder builder(TestName());
933 Shape single_float = ShapeUtil::MakeShape(F32, {});
934 Parameter(&builder, 0, single_float, "lhs-unused");
935 Parameter(&builder, 1, single_float, "rhs-used");
936 auto computation_status = builder.Build();
937 TF_ASSERT_OK(computation_status.status());
938
939 Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
940 Reduce(Parameter(&builder, 0, operand_shape, "operand"),
941 Parameter(&builder, 1, single_float, "init"),
942 computation_status.ValueOrDie(), {0});
943
944 float operand[] = {42.0f};
945 float init = 58.5f;
946 float expected = 42.0f;
947 Literal input_literal = LiteralUtil::CreateR1<float>(operand);
948 std::unique_ptr<GlobalData> input_global_data =
949 client_->TransferToServer(input_literal).ConsumeValueOrDie();
950 Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
951 std::unique_ptr<GlobalData> input_global_data2 =
952 client_->TransferToServer(input_literal2).ConsumeValueOrDie();
953 ComputeAndCompareR0<float>(
954 &builder, expected, {input_global_data.get(), input_global_data2.get()},
955 ErrorSpec(0.0001));
956 }
957
// Bitwise-AND reduces each row of a 3x2 uint64 matrix (dimension 1), starting
// from an all-ones init value.
XLA_TEST_F(ReduceTest, AndReduceU64) {
  XlaBuilder builder(TestName());
  Array2D<uint64> input_values = {
      {0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL},
      {0XFFFFFFFFFFFFFFD6LL, 101},
      {1, 0XFFFFFFFFFFFFFFFFLL}};
  auto and_computation = CreateScalarAndComputation(U64, &builder);
  auto matrix = ConstantR2FromArray2D(&builder, input_values);
  // All bits set is the identity for AND.
  auto all_ones = ConstantR0<uint64>(&builder, 0xFFFFFFFFFFFFFFFFLL);
  Reduce(matrix, all_ones, and_computation, {1});

  // One AND result per row; e.g. 0x...D6 & 101 == 68.
  const std::vector<uint64> expected_rows = {0x1204461080145890LL, 68, 1};
  ComputeAndCompareR1<uint64>(&builder, expected_rows, {});
}
970
// Bitwise-OR reduces each row of a 3x2 uint64 matrix (dimension 1), starting
// from a zero init value.
XLA_TEST_F(ReduceTest, OrReduceU64) {
  XlaBuilder builder(TestName());
  Array2D<uint64> input_values = {
      {0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL},
      {0xFFFFFFFFFFFFFFD6LL, 101},
      {1, 0xCAFEBEEFABABABABLL}};
  auto or_computation = CreateScalarOrComputation(U64, &builder);
  auto matrix = ConstantR2FromArray2D(&builder, input_values);
  // Zero is the identity for OR.
  auto no_bits = ConstantR0<uint64>(&builder, 0);
  Reduce(matrix, no_bits, or_computation, {1});

  // One OR result per row.
  const std::vector<uint64> expected_rows = {
      0X3BFDFF7ABEFEFEF0LL, 0XFFFFFFFFFFFFFFF7LL, 0xCAFEBEEFABABABABLL};
  ComputeAndCompareR1<uint64>(&builder, expected_rows, {});
}
984
// Sum-reduces a [127, 1] input over dimension 0. The output has shape [1], so
// although it is compared as rank 1, the work amounts to a reduction to a
// single value.
XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
  XlaBuilder builder(TestName());
  XlaComputation sum_computation = CreateScalarAddComputation(F32, &builder);
  constexpr int kElementCount = 127;
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {kElementCount, 1});
  auto operand = Parameter(&builder, 0, operand_shape, "input");
  auto zero = ConstantR0<float>(&builder, 0.0);
  Reduce(operand, zero, sum_computation, /*dimensions_to_reduce=*/{0});

  Array2D<float> host_values(kElementCount, 1);
  host_values.FillRandom(3.0f);
  Literal operand_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  std::unique_ptr<GlobalData> operand_data =
      client_->TransferToServer(operand_literal).ConsumeValueOrDie();

  // Host-side reference: plain accumulation over every element.
  const float expected_sum = absl::c_accumulate(host_values, 0.0f);
  ComputeAndCompareR1<float>(&builder, {expected_sum}, {operand_data.get()},
                             ErrorSpec(0.001));
}
1004
1005 } // namespace
1006 } // namespace xla
1007