• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that multi-dimensional arrays can be reduced among various
17 // user-provided dimensions.
18 //
19 // Note that comments for these tests are white-box in that they talk about the
20 // default data layout.
21 //
22 // The test space for reductions is the cartesian product of:
23 //
24 //    <possible ranks> x
25 //    <possible layouts for chosen rank> x
26 //    <possible subsets of dimensions in chosen rank>
27 
#include <stdlib.h>
#include <algorithm>
#include <array>
#include <cfloat>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
35 
36 #include "absl/algorithm/container.h"
37 #include "absl/strings/str_format.h"
38 #include "absl/strings/str_join.h"
39 #include "absl/types/span.h"
40 #include "tensorflow/compiler/xla/array2d.h"
41 #include "tensorflow/compiler/xla/array4d.h"
42 #include "tensorflow/compiler/xla/client/global_data.h"
43 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
44 #include "tensorflow/compiler/xla/client/local_client.h"
45 #include "tensorflow/compiler/xla/client/xla_builder.h"
46 #include "tensorflow/compiler/xla/client/xla_computation.h"
47 #include "tensorflow/compiler/xla/layout_util.h"
48 #include "tensorflow/compiler/xla/literal_util.h"
49 #include "tensorflow/compiler/xla/reference_util.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/statusor.h"
53 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
54 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
55 #include "tensorflow/compiler/xla/tests/test_macros.h"
56 #include "tensorflow/compiler/xla/util.h"
57 #include "tensorflow/compiler/xla/xla_data.pb.h"
58 #include "tensorflow/core/lib/core/status_test_util.h"
59 #include "tensorflow/core/platform/test.h"
60 #include "tensorflow/core/platform/types.h"
61 
62 namespace xla {
63 namespace {
64 
// Pointer to a function that takes a PrimitiveType and an XlaBuilder* and
// returns an XlaComputation.
// NOTE(review): not referenced in the part of the file visible here;
// presumably consumed by tests further down -- confirm.
using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*);

// Pointer to a function that takes only an XlaBuilder* and returns an
// XlaComputation.
using FuncGenerator = XlaComputation (*)(XlaBuilder*);
68 
69 class ReduceTest : public ClientLibraryTestBase {
70  protected:
ReduceTest()71   ReduceTest() {
72     // Implementation note: laid out z >> y >> x by default.
73     // clang-format off
74     literal_2d_ = LiteralUtil::CreateR2<float>({
75       // x0   x1   x2
76       { 1.f, 2.f, 3.f},  // y0
77       { 4.f, 5.f, 6.f},  // y1
78     });
79     literal_3d_ = LiteralUtil::CreateR3Projected<float>({
80       // x0   x1   x2
81       { 1.f, 2.f, 3.f},  // y0
82       { 4.f, 5.f, 6.f},  // y1
83     }, 4);
84     // clang-format on
85     CHECK(ShapeUtil::Equal(
86         literal_3d_.shape(),
87         ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
88         << literal_3d_.shape().ShortDebugString();
89   }
90 
91   // Runs an R1 => R0 reduction test with the given number of elements.
RunR1ToR0Test(int64 element_count)92   void RunR1ToR0Test(int64 element_count) {
93     XlaBuilder builder(TestName());
94     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
95     const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
96     auto input = Parameter(&builder, 0, input_shape, "input");
97     auto zero = ConstantR0<float>(&builder, 0.0);
98     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
99 
100     std::vector<float> input_data(element_count);
101     for (int64 i = 0; i < element_count; ++i) {
102       input_data[i] = rand_r(&seed_) % 3;
103       if (rand_r(&seed_) % 2 == 0) {
104         input_data[i] *= -1;
105       }
106     }
107     Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
108     std::unique_ptr<GlobalData> input_global_data =
109         client_->TransferToServer(input_literal).ConsumeValueOrDie();
110 
111     float expected = 0.0;
112     for (float item : input_data) {
113       expected += item;
114     }
115     ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
116                                ErrorSpec(0.001));
117   }
118 
RunR1ToR0PredTest(bool and_reduce,absl::Span<const int> input_data)119   void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
120     const int element_count = input_data.size();
121     XlaBuilder builder(TestName());
122     const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
123     auto input_par = Parameter(&builder, 0, input_shape, "input");
124     auto pred_values =
125         Eq(input_par, ConstantR1<int>(&builder, element_count, 1));
126     XlaOp init_value;
127     XlaComputation reduce;
128     if (and_reduce) {
129       init_value = ConstantR0<bool>(&builder, true);
130       reduce = CreateScalarAndComputation(PRED, &builder);
131     } else {
132       init_value = ConstantR0<bool>(&builder, false);
133       reduce = CreateScalarOrComputation(PRED, &builder);
134     }
135     Reduce(pred_values, init_value, reduce,
136            /*dimensions_to_reduce=*/{0});
137 
138     Literal input_literal = LiteralUtil::CreateR1(input_data);
139     std::unique_ptr<GlobalData> input_global_data =
140         client_->TransferToServer(input_literal).ConsumeValueOrDie();
141 
142     bool expected = and_reduce;
143     for (bool item : input_data) {
144       if (and_reduce) {
145         expected = expected && item;
146       } else {
147         expected = expected || item;
148       }
149     }
150     ComputeAndCompareR0<bool>(&builder, expected, {input_global_data.get()});
151   }
152 
153   // Reduce predicate tensor with dimension rows * cols to dimension cols, to
154   // test the implementation of atomic operations on misaligned small data
155   // types.
156   template <int64 cols>
RunR2ToR1PredTest(bool and_reduce,int64 rows,int64 minor=1,int64 major=0)157   void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1,
158                          int64 major = 0) {
159     XlaBuilder builder(TestName());
160     const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
161     auto input = Parameter(&builder, 0, input_shape, "input");
162     auto input_pred = Eq(input, ConstantR0<uint8>(&builder, 1));
163 
164     XlaOp init_value;
165     XlaComputation reduce_op;
166     if (and_reduce) {
167       init_value = ConstantR0<bool>(&builder, true);
168       reduce_op = CreateScalarAndComputation(PRED, &builder);
169     } else {
170       init_value = ConstantR0<bool>(&builder, false);
171       reduce_op = CreateScalarOrComputation(PRED, &builder);
172     }
173 
174     Reduce(input_pred, init_value, reduce_op,
175            /*dimensions_to_reduce=*/{0});
176 
177     Array2D<uint8> input_data(rows, cols);
178     input_data.FillRandom(0, 1);
179     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
180     input_literal =
181         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
182     std::unique_ptr<GlobalData> input_global_data =
183         client_->TransferToServer(input_literal).ConsumeValueOrDie();
184 
185     std::array<bool, cols> expected;
186     for (int64 colno = 0; colno < cols; ++colno) {
187       bool column_sum = and_reduce ? true : false;
188       for (int64 rowno = 0; rowno < rows; ++rowno) {
189         if (and_reduce) {
190           column_sum = column_sum && input_data(rowno, colno);
191         } else {
192           column_sum = column_sum || input_data(rowno, colno);
193         }
194       }
195       expected[colno] = column_sum;
196     }
197 
198     ComputeAndCompareR1<bool>(&builder, expected, {input_global_data.get()});
199   }
200 
201   // Runs an R2 => R0 reduction test with the given number of (rows, cols).
RunR2ToR0Test(int64 rows,int64 cols,int64 minor=1,int64 major=0)202   void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
203     XlaBuilder builder(TestName());
204     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
205     const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
206     auto input = Parameter(&builder, 0, input_shape, "input");
207     auto zero = ConstantR0<float>(&builder, 0.0);
208     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
209 
210     Array2D<float> input_data(rows, cols);
211     input_data.FillRandom(3.14f, 0.04);
212     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
213     input_literal =
214         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
215     std::unique_ptr<GlobalData> input_global_data =
216         client_->TransferToServer(input_literal).ConsumeValueOrDie();
217 
218     float expected = 0.0;
219     for (int64 rowno = 0; rowno < rows; ++rowno) {
220       for (int64 colno = 0; colno < cols; ++colno) {
221         expected += input_data(rowno, colno);
222       }
223     }
224     ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
225                                ErrorSpec(0.01, 1e-4));
226   }
227 
228   // Runs an R2 => R1 reduction test with the given number of (rows, cols).
RunR2ToR1Test(int64 rows,int64 cols,int64 minor=1,int64 major=0)229   void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
230     XlaBuilder builder(TestName());
231     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
232     const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
233     auto input = Parameter(&builder, 0, input_shape, "input");
234     auto zero = ConstantR0<float>(&builder, 0.0);
235     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
236 
237     Array2D<float> input_data(rows, cols);
238     input_data.FillRandom(3.14f, 0.04);
239     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
240     input_literal =
241         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
242     std::unique_ptr<GlobalData> input_global_data =
243         client_->TransferToServer(input_literal).ConsumeValueOrDie();
244 
245     std::vector<float> expected;
246     for (int64 colno = 0; colno < cols; ++colno) {
247       float column_sum = 0;
248       for (int64 rowno = 0; rowno < rows; ++rowno) {
249         column_sum += input_data(rowno, colno);
250       }
251       expected.push_back(column_sum);
252     }
253     ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
254                                ErrorSpec(0.01, 1e-4));
255   }
256 
257   template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_floating_point<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)258   void ComputeAndCompareGeneric(
259       typename std::enable_if<std::is_floating_point<NativeT>::value,
260                               XlaBuilder>::type* builder,
261       absl::Span<const NativeT> expected,
262       absl::Span<GlobalData* const> arguments) {
263     ComputeAndCompareR1<NativeT>(builder, expected, arguments,
264                                  ErrorSpec(0.01, 1e-4));
265   }
266 
267   template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_integral<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)268   void ComputeAndCompareGeneric(
269       typename std::enable_if<std::is_integral<NativeT>::value,
270                               XlaBuilder>::type* builder,
271       absl::Span<const NativeT> expected,
272       absl::Span<GlobalData* const> arguments) {
273     ComputeAndCompareR1<NativeT>(builder, expected, arguments);
274   }
275 
276   template <typename NativeT>
RunVectorizedReduceTestForType(const std::function<XlaComputation (XlaBuilder *)> & reduction_function_generator,const std::function<NativeT (NativeT,NativeT)> & reference_reduction_function,const NativeT & initial_value)277   void RunVectorizedReduceTestForType(
278       const std::function<XlaComputation(XlaBuilder*)>&
279           reduction_function_generator,
280       const std::function<NativeT(NativeT, NativeT)>&
281           reference_reduction_function,
282       const NativeT& initial_value) {
283     const int rows = 64, cols = 128;
284     const int minor = 1, major = 0;
285     XlaBuilder builder(TestName());
286     XlaComputation reduction_function = reduction_function_generator(&builder);
287     const Shape input_shape = ShapeUtil::MakeShape(
288         xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
289     auto input = Parameter(&builder, 0, input_shape, "input");
290     auto zero = ConstantR0<NativeT>(&builder, initial_value);
291     Reduce(input, zero, reduction_function,
292            /*dimensions_to_reduce=*/{0});
293 
294     Array2D<NativeT> input_data(rows, cols);
295     input_data.FillUnique(initial_value);
296     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
297     input_literal =
298         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
299     std::unique_ptr<GlobalData> input_global_data =
300         client_->TransferToServer(input_literal).ConsumeValueOrDie();
301 
302     // NativeT can be bool, and std::vector<bool> does not convert to
303     // Span.
304     std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
305     for (int64 colno = 0; colno < cols; ++colno) {
306       NativeT column_result = initial_value;
307       for (int64 rowno = 0; rowno < rows; ++rowno) {
308         column_result = reference_reduction_function(column_result,
309                                                      input_data(rowno, colno));
310       }
311       expected[colno] = column_result;
312     }
313 
314     ComputeAndCompareGeneric<NativeT>(
315         &builder, absl::Span<const NativeT>(expected.get(), cols),
316         {input_global_data.get()});
317   }
318 
RunVectorizedReduceTest(const std::function<XlaComputation (PrimitiveType,XlaBuilder *)> & reduction_function_generator_for_type,const std::function<float (float,float)> & reference_reduction_function_for_floats,const std::function<int32 (int32,int32)> & reference_reduction_function_for_ints,const std::function<uint32 (uint32,uint32)> & reference_reduction_function_for_uints,float floating_point_identity,int32 signed_int_identity,uint32 unsigned_int_identity)319   void RunVectorizedReduceTest(
320       const std::function<XlaComputation(PrimitiveType, XlaBuilder*)>&
321           reduction_function_generator_for_type,
322       const std::function<float(float, float)>&
323           reference_reduction_function_for_floats,
324       const std::function<int32(int32, int32)>&
325           reference_reduction_function_for_ints,
326       const std::function<uint32(uint32, uint32)>&
327           reference_reduction_function_for_uints,
328       float floating_point_identity, int32 signed_int_identity,
329       uint32 unsigned_int_identity) {
330     // Float version
331     RunVectorizedReduceTestForType<float>(
332         [&](XlaBuilder* builder) {
333           return reduction_function_generator_for_type(F32, builder);
334         },
335         reference_reduction_function_for_floats, floating_point_identity);
336 
337     // Signed int version
338     RunVectorizedReduceTestForType<int32>(
339         [&](XlaBuilder* builder) {
340           return reduction_function_generator_for_type(S32, builder);
341         },
342         reference_reduction_function_for_ints, signed_int_identity);
343 
344     // Unsigned int version
345     RunVectorizedReduceTestForType<uint32>(
346         [&](XlaBuilder* builder) {
347           return reduction_function_generator_for_type(U32, builder);
348         },
349         reference_reduction_function_for_uints, unsigned_int_identity);
350   }
351 
352   Literal literal_2d_;
353   Literal literal_3d_;
354   uint32 seed_ = 0xdeadbeef;
355 };
356 
// R1 => R0 add-reductions over F32 vectors of increasing length. Each test
// forwards its element count to ReduceTest::RunR1ToR0Test.
XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16);
}
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/128);
}
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/129);
}
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/240);
}
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/256);
}
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2048);
}
XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024 + 1);
}
XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/64 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/4096 * 4096);
}
374 
// R2 => R0 add-reductions over F32 matrices of various sizes. Each test
// forwards (rows, cols) -- and, for the "_01" variant, an explicit
// (minor, major) pair -- to ReduceTest::RunR2ToR0Test.
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) {
  RunR2ToR0Test(/*rows=*/1, /*cols=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) {
  RunR2ToR0Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) {
  RunR2ToR0Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) {
  RunR2ToR0Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) {
  RunR2ToR0Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) {
  RunR2ToR0Test(/*rows=*/1000, /*cols=*/1500);
}
389 
// R2 => R1 add-reductions over F32 matrices of various sizes. Each test
// forwards (rows, cols) -- and, for the "_01" variant, an explicit
// (minor, major) pair -- to ReduceTest::RunR2ToR1Test.

// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) {
  RunR2ToR1Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) {
  RunR2ToR1Test(/*rows=*/1, /*cols=*/1);
}
// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) {
  RunR2ToR1Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) {
  RunR2ToR1Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) {
  RunR2ToR1Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) {
  RunR2ToR1Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) {
  RunR2ToR1Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) {
  RunR2ToR1Test(/*rows=*/1000, /*cols=*/1500);
}
406 
// AND-reduces the predicate `input == 1` over an S32 vector of ten 1s.
XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) {
  constexpr int kElementCount = 10;
  const std::vector<int> all_ones(kElementCount, 1);
  RunR1ToR0PredTest(/*and_reduce=*/true, all_ones);
}
412 
// AND-reduces the predicate `input == 1` over the ten-element S32 vector
// 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) {
  constexpr int kElementCount = 10;
  // The vector starts zero-filled; setting the odd indices to 1 yields the
  // alternating pattern.
  std::vector<int> alternating(kElementCount);
  for (int odd = 1; odd < kElementCount; odd += 2) {
    alternating[odd] = 1;
  }
  RunR1ToR0PredTest(/*and_reduce=*/true, alternating);
}
421 
// OR-reduces the predicate `input == 1` over an S32 vector of ten 1s.
XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) {
  constexpr int kElementCount = 10;
  const std::vector<int> all_ones(kElementCount, 1);
  RunR1ToR0PredTest(/*and_reduce=*/false, all_ones);
}
427 
// OR-reduces the predicate `input == 1` over the ten-element S32 vector
// 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) {
  constexpr int kElementCount = 10;
  // The vector starts zero-filled; setting the odd indices to 1 yields the
  // alternating pattern.
  std::vector<int> alternating(kElementCount);
  for (int odd = 1; odd < kElementCount; odd += 2) {
    alternating[odd] = 1;
  }
  RunR1ToR0PredTest(/*and_reduce=*/false, alternating);
}
436 
// Add-reduces Log(input) over dimension 0 of a 111x50 F32 parameter whose
// literal is relaid out with MakeLayout({0, 1}), and compares each column sum
// with a host-side reference built from std::log.
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation sum = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  Reduce(Log(param), init, sum, /*dimensions_to_reduce=*/{0});

  // NOTE(review): if FillRandom's first argument is a standard deviation, some
  // inputs are negative and both sides produce NaN -- confirm this is intended.
  Array2D<float> host_data(kRows, kCols);
  host_data.FillRandom(3.14f, 0.04);
  Literal param_literal = LiteralUtil::CreateR2FromArray2D(host_data);
  param_literal = param_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> param_handle =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  // One accumulator per column, summed over the rows in ascending order.
  std::vector<float> expected(kCols, 0.0f);
  for (int64 c = 0; c < kCols; ++c) {
    for (int64 r = 0; r < kRows; ++r) {
      expected[c] += std::log(host_data(r, c));
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {param_handle.get()},
                             ErrorSpec(0.01, 1e-4));
}
466 
// Applies Log, transposes the 111x50 result with permutation {1, 0}, then
// add-reduces dimension 1 of the transpose. The expected output is the same
// per-column sum of std::log(input) that an untransposed reduction over
// dimension 0 would give.
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation sum = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto transposed_log = Transpose(Log(param), {1, 0});
  Reduce(transposed_log, init, sum, /*dimensions_to_reduce=*/{1});

  Array2D<float> host_data(kRows, kCols);
  host_data.FillRandom(3.14f, 0.04);
  Literal param_literal = LiteralUtil::CreateR2FromArray2D(host_data);
  param_literal = param_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> param_handle =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  // One accumulator per column, summed over the rows in ascending order.
  std::vector<float> expected(kCols, 0.0f);
  for (int64 c = 0; c < kCols; ++c) {
    for (int64 r = 0; r < kRows; ++r) {
      expected[c] += std::log(host_data(r, c));
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {param_handle.get()},
                             ErrorSpec(0.01, 1e-4));
}
497 
498 // Test that algebraic simplifier does not incorrectly fold a transpose into a
499 // reduction operation.
XLA_TEST_F(ReduceTest,TransposeAndReduceR3_12x111x50_To_R2)500 XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
501   XlaBuilder builder(TestName());
502   XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
503   const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
504   XlaOp input = Parameter(&builder, 0, input_shape, "input");
505   XlaOp zero = ConstantR0<float>(&builder, 0.0);
506   XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
507   Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
508 
509   TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
510 
511   ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
512 }
513 
// Applies Tanh to a F32[111, 2, 25] parameter, reshapes the result to
// [111, 50], and add-reduces dimension 0. Output element (major * 25 + col)
// is expected to equal the sum over rows of tanh(input(row, major, col)).
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
  constexpr int64 kRows = 111;
  constexpr int64 kCols = 50;
  constexpr int64 kHalfCols = kCols / 2;

  XlaBuilder builder(TestName());
  XlaComputation sum = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, 2, kHalfCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto tanh_of_param = Tanh(param);
  auto flattened = Reshape(tanh_of_param, {kRows, kCols});
  Reduce(flattened, init, sum, /*dimensions_to_reduce=*/{0});

  Array3D<float> host_data(kRows, 2, kHalfCols);
  host_data.FillRandom(3.14f, 0.04);
  Literal param_literal = LiteralUtil::CreateR3FromArray3D(host_data);
  std::unique_ptr<GlobalData> param_handle =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  // Index the output by (major, col) flattened in that order, and sum over the
  // rows in ascending order.
  std::vector<float> expected(kCols, 0.0f);
  for (int64 major = 0; major < 2; ++major) {
    for (int64 c = 0; c < kHalfCols; ++c) {
      float& slot = expected[major * kHalfCols + c];
      for (int64 r = 0; r < kRows; ++r) {
        slot += std::tanh(host_data(r, major, c));
      }
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {param_handle.get()},
                             ErrorSpec(0.01, 1e-4));
}
545 
// Describes one reduction test case. PrintTo() below renders it as a test
// name fragment.
struct BoundsLayout {
  // Printed joined by "x"; its size is printed as the input rank.
  // Presumably the dimension sizes of the input array -- confirm at use sites.
  std::vector<int64> bounds;
  // Printed as concatenated digits.
  // Presumably the minor-to-major layout of the input -- confirm at use sites.
  std::vector<int64> layout;
  // Printed as concatenated digits; its size is subtracted from bounds.size()
  // to print the output rank.
  std::vector<int64> reduce_dims;
};
551 
PrintTo(const BoundsLayout & spec,std::ostream * os)552 void PrintTo(const BoundsLayout& spec, std::ostream* os) {
553   *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(),
554                          spec.bounds.size() - spec.reduce_dims.size(),
555                          absl::StrJoin(spec.bounds, "x"),
556                          absl::StrJoin(spec.layout, ""),
557                          absl::StrJoin(spec.reduce_dims, ""));
558 }
559 
560 // Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,AddReduce2DScalarToR0)561 XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
562   XlaBuilder builder(TestName());
563   auto add = CreateScalarAddComputation(F32, &builder);
564   auto scalar = ConstantR0<float>(&builder, 42.0);
565   auto broadcasted = Broadcast(scalar, {500, 500});
566   Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
567 
568   float expected = 42.0f * static_cast<float>(500 * 500);
569   ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
570 }
571 
572 // Max-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DScalarToR0)573 XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
574   XlaBuilder builder(TestName());
575   auto max = CreateScalarMaxComputation(F32, &builder);
576   auto scalar = ConstantR0<float>(&builder, 42.0);
577   auto broadcasted = Broadcast(scalar, {500, 500});
578   Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), max, {0, 1});
579 
580   float expected = 42.0f;
581   ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
582 }
583 
584 // Max-reduces a matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DToR0)585 XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
586   XlaBuilder builder(TestName());
587   auto max = CreateScalarMaxComputation(F32, &builder);
588   Array2D<float> input(300, 250);
589   input.FillRandom(214.0f);
590   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
591   Reduce(ConstantLiteral(&builder, input_literal),
592          ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
593   auto input_max = FLT_MIN;
594   input.Each(
595       [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
596   ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
597 }
598 
599 // Min-reduces matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MinReduce2DToR0)600 XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
601   XlaBuilder builder(TestName());
602   auto min = CreateScalarMinComputation(F32, &builder);
603   Array2D<float> input(150, 130);
604   input.FillRandom(214.0f);
605   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
606   Reduce(ConstantLiteral(&builder, input_literal),
607          ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
608 
609   auto input_min = FLT_MAX;
610   input.Each(
611       [&](int64, int64, float* v) { input_min = std::min(input_min, *v); });
612   ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
613 }
614 
// Min-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the largest uint32 value; the expected result is 1.
XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32> column({{1}, {2}});
  auto min = CreateScalarMinComputation(U32, &builder);
  auto column_literal = LiteralUtil::CreateR2FromArray2D(column);
  auto init =
      ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
  auto operand = ConstantLiteral(&builder, column_literal);

  Reduce(operand, init, min, {0, 1});
  ComputeAndCompareR0<uint32>(&builder, /*expected=*/1, {});
}
626 
// Max-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the smallest uint32 value; the expected result is 2.
XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32> column({{1}, {2}});
  auto max = CreateScalarMaxComputation(U32, &builder);
  auto column_literal = LiteralUtil::CreateR2FromArray2D(column);
  auto init =
      ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
  auto operand = ConstantLiteral(&builder, column_literal);

  Reduce(operand, init, max, {0, 1});
  ComputeAndCompareR0<uint32>(&builder, /*expected=*/2, {});
}
638 
639 // Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest,Reduce2DAmong1)640 XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
641   XlaBuilder builder(TestName());
642   auto m = ConstantLiteral(&builder, literal_2d_);
643   auto add = CreateScalarAddComputation(F32, &builder);
644   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
645 
646   std::vector<float> expected = {6.f, 15.f};
647   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
648 }
649 
// Sum-reduces the fixture's 2D literal over dimensions 0 and 1, i.e. adds the
// whole matrix up into a single scalar.
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
  XlaBuilder builder(TestName());
  const auto matrix = ConstantLiteral(&builder, literal_2d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(matrix, zero, add_computation, {0, 1});

  const float total = 21.0f;
  ComputeAndCompareR0<float>(&builder, total, {}, ErrorSpec(0.0001, 1e-4));
}
659 
660 // Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest,Reduce2DAmongY)661 XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
662   XlaBuilder builder("reduce_among_y");
663   auto m = ConstantLiteral(&builder, literal_2d_);
664   auto add = CreateScalarAddComputation(F32, &builder);
665   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
666 
667   std::vector<float> expected = {5.f, 7.f, 9.f};
668   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
669 }
670 
// Sum-reduces the fixture's 3D literal over dimensions 1 and 2, leaving one
// sum per index of dimension 0.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {1, 2});

  const std::vector<float> sums = {21.f, 21.f, 21.f, 21.f};
  ComputeAndCompareR1<float>(&builder, sums, {}, ErrorSpec(0.0001));
}
680 
// Sum-reduces the fixture's 3D literal over dimensions 0 and 1, leaving one
// sum per index of dimension 2.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {0, 1});

  const std::vector<float> sums = {20.f, 28.f, 36.f};
  ComputeAndCompareR1<float>(&builder, sums, {}, ErrorSpec(0.0001));
}
690 
// Sum-reduces the fixture's 3D literal over all three dimensions down to a
// scalar.
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {0, 1, 2});

  // Four times 21, i.e. 84.
  const float total = 4.0f * 21.0f;
  ComputeAndCompareR0<float>(&builder, total, {}, ErrorSpec(0.0001));
}
700 
// Sum-reduces the fixture's 3D literal along dimension 0 only, leaving a 2x3
// matrix.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {0});

  const Array2D<float> sums({{4.f, 8.f, 12.f}, {16.f, 20.f, 24.f}});
  ComputeAndCompareR2<float>(&builder, sums, {}, ErrorSpec(0.0001));
}
715 
// Sum-reduces the fixture's 3D literal along dimension 1 only, leaving a 4x3
// matrix.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {1});

  // Every row of the expected result is identical.
  const Array2D<float> sums({{5.f, 7.f, 9.f},
                             {5.f, 7.f, 9.f},
                             {5.f, 7.f, 9.f},
                             {5.f, 7.f, 9.f}});
  ComputeAndCompareR2<float>(&builder, sums, {}, ErrorSpec(0.0001));
}
732 
// Sum-reduces the fixture's 3D literal along dimension 2 only, leaving a 4x2
// matrix.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
  XlaBuilder builder(TestName());
  const auto cube = ConstantLiteral(&builder, literal_3d_);
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, add_computation, {2});

  // Every row of the expected result is identical.
  const Array2D<float> sums({{6.f, 15.f},
                             {6.f, 15.f},
                             {6.f, 15.f},
                             {6.f, 15.f}});
  ComputeAndCompareR2<float>(&builder, sums, {}, ErrorSpec(0.0001));
}
749 
// Runs the vectorized-reduce test with addition as the reduction, passing a
// host reference function for float, int32 and uint32 and 0 as the initial
// value for each type.
XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
  const auto make_add =
      static_cast<FuncGeneratorForType>(CreateScalarAddComputation);
  const auto add_f32 = [](float a, float b) { return a + b; };
  // The signed sum is computed on unsigned operands so that overflow wraps
  // around instead of being undefined behavior.
  const auto add_s32 = [](int32 a, int32 b) {
    const uint32 wrapped = static_cast<uint32>(a) + static_cast<uint32>(b);
    return static_cast<int32>(wrapped);
  };
  const auto add_u32 = [](uint32 a, uint32 b) { return a + b; };

  RunVectorizedReduceTest(make_add, add_f32, add_s32, add_u32, 0.0, 0, 0);
}
760 
// Runs the vectorized-reduce test with multiplication as the reduction,
// passing a host reference function for float, int32 and uint32 and 1 as the
// initial value for each type.
XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
  const auto make_multiply =
      static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation);
  const auto mul_f32 = [](float a, float b) { return a * b; };
  // The signed product is computed on unsigned operands so that overflow wraps
  // around instead of being undefined behavior.
  const auto mul_s32 = [](int32 a, int32 b) {
    const uint32 wrapped = static_cast<uint32>(a) * static_cast<uint32>(b);
    return static_cast<int32>(wrapped);
  };
  const auto mul_u32 = [](uint32 a, uint32 b) { return a * b; };

  RunVectorizedReduceTest(make_multiply, mul_f32, mul_s32, mul_u32, 1.0, 1, 1);
}
771 
// Runs the vectorized-reduce test with max as the reduction, passing a host
// reference function for float, int32 and uint32 and the smallest
// representable value of each type as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMaxComputation),
      [](float a, float b) { return std::max(a, b); },
      [](int32 a, int32 b) { return std::max(a, b); },
      [](uint32 a, uint32 b) { return std::max(a, b); },
      // For floating-point types numeric_limits<T>::min() is the smallest
      // *positive* normalized value, not the most negative one, so it would
      // mask negative inputs. lowest() is the value every finite float
      // compares greater than or equal to.
      std::numeric_limits<float>::lowest(), std::numeric_limits<int32>::min(),
      std::numeric_limits<uint32>::min());
}
781 
// Runs the vectorized-reduce test with min as the reduction, passing a host
// reference function for float, int32 and uint32 and the largest
// representable value of each type as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
  const auto make_min =
      static_cast<FuncGeneratorForType>(CreateScalarMinComputation);
  const auto min_f32 = [](float a, float b) { return std::min(a, b); };
  const auto min_s32 = [](int32 a, int32 b) { return std::min(a, b); };
  const auto min_u32 = [](uint32 a, uint32 b) { return std::min(a, b); };

  RunVectorizedReduceTest(make_min, min_f32, min_s32, min_u32,
                          std::numeric_limits<float>::max(),
                          std::numeric_limits<int32>::max(),
                          std::numeric_limits<uint32>::max());
}
791 
// Runs the bool instantiation of the vectorized-reduce test with logical AND
// over PRED as the reduction and `true` as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
  const auto make_and = static_cast<FuncGenerator>(
      [](XlaBuilder* b) { return CreateScalarAndComputation(PRED, b); });
  const auto reference_and = [](bool lhs, bool rhs) { return lhs && rhs; };

  RunVectorizedReduceTestForType<bool>(make_and, reference_and, true);
}
799 
// Runs the bool instantiation of the vectorized-reduce test with logical OR
// over PRED as the reduction and `false` as the initial value.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
  const auto make_or = static_cast<FuncGenerator>(
      [](XlaBuilder* b) { return CreateScalarOrComputation(PRED, b); });
  const auto reference_or = [](bool lhs, bool rhs) { return lhs || rhs; };

  RunVectorizedReduceTestForType<bool>(make_or, reference_or, false);
}
807 
808 class ReduceR3ToR2Test : public ReduceTest,
809                          public ::testing::WithParamInterface<BoundsLayout> {};
810 
// Sum-reduces a rank-3 array of ones, shaped and laid out as specified by the
// test parameter, over the parameter's reduce dimensions, and compares the
// rank-2 result with ReferenceUtil::Reduce3DTo2D.
XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
  XlaBuilder builder(TestName());
  const auto& spec = GetParam();

  // The input is filled with a constant 1.0f.
  Array3D<float> ones(spec.bounds[0], spec.bounds[1], spec.bounds[2]);
  ones.Fill(1.0f);

  // Re-lay the literal out as requested before handing it to the server.
  const Literal relaid_literal =
      LiteralUtil::CreateR3FromArray3D(ones).Relayout(
          LayoutUtil::MakeLayout(spec.layout));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(relaid_literal).ConsumeValueOrDie();

  const auto activations =
      Parameter(&builder, 0, relaid_literal.shape(), "input");
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);
  const auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(activations, zero, add_computation, spec.reduce_dims);

  const auto reference = ReferenceUtil::Reduce3DTo2D(
      ones, 0.0f, spec.reduce_dims,
      [](float lhs, float rhs) { return lhs + rhs; });

  ComputeAndCompareR2<float>(&builder, *reference, {device_input.get()},
                             ErrorSpec(1e-3, 1e-3));
}
837 
838 INSTANTIATE_TEST_CASE_P(
839     ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
840     // Specifies (shape, layout, reduction dimensions).
841     ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
842                       BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
843                       BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
844                       // These should be simplified into a reshape.
845                       BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
846                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
847                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
848                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
849                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
850                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
851                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
852                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
853                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
854                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
855                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
856                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
857                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
858                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
859                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
860                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
861 
// Uses the result of an operation on a constant (Abs of 2.0f), rather than a
// bare constant, as the init value of a max reduction over {1.0f, 4.0f}; the
// expected result is 4.0f.
XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
  XlaBuilder builder(TestName());
  const XlaComputation max_computation =
      CreateScalarMaxComputation(F32, &builder);

  const auto two = ConstantR0<float>(&builder, 2.0f);
  const auto init_value = Abs(two);

  const Literal operand_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
  std::unique_ptr<GlobalData> operand_data =
      client_->TransferToServer(operand_literal).ConsumeValueOrDie();
  const auto operand = Parameter(&builder, 0, operand_literal.shape(), "b");
  Reduce(operand, init_value, max_computation, {0});

  ComputeAndCompareR0<float>(&builder, 4.0f, {operand_data.get()});
}
877 
// AND-reduction flavour of the R2-to-R1 predicate test, with 128 rows and 64
// columns.
XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) {
  constexpr int kRows = 128;
  constexpr int kCols = 64;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/true, /*rows=*/kRows);
}
// OR-reduction flavour of the R2-to-R1 predicate test, with 64 rows and 32
// columns.
XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) {
  constexpr int kRows = 64;
  constexpr int kCols = 32;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/false, /*rows=*/kRows);
}
884 
885 // Tests reductions with different initial values.  There's no test macro that
886 // combines TYPED_TEST and TYPED_P, so we have to do it manually.
887 class ReduceInitializerTest : public ReduceTest {
888  protected:
889   template <typename T>
DoTest(T initializer,int num_elems)890   void DoTest(T initializer, int num_elems) {
891     XlaBuilder builder(TestName());
892     XlaComputation max_fn = CreateScalarMaxComputation(
893         primitive_util::NativeToPrimitiveType<T>(), &builder);
894 
895     auto init = ConstantR0<T>(&builder, initializer);
896     std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
897     auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
898     auto input_data =
899         client_->TransferToServer(input_literal).ConsumeValueOrDie();
900     Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
901            {0});
902 
903     ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
904   }
905 };
906 
// uint8 reduction over 2 elements with 42 as the initial value.
XLA_TEST_F(ReduceInitializerTest, U8Small) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 2;
  DoTest<uint8>(kInitializer, kNumElems);
}
908 
// uint8 reduction over 4096 elements (a power of two) with 42 as the initial
// value.
XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 4096;
  DoTest<uint8>(kInitializer, kNumElems);
}
910 
// uint8 reduction over 4095 elements (not a power of two) with 42 as the
// initial value.
XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) {
  constexpr uint8 kInitializer = 42;
  constexpr int kNumElems = 4095;
  DoTest<uint8>(kInitializer, kNumElems);
}
914 
// uint64 reduction over 1024 elements with 0 as the initial value.
XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) {
  constexpr uint64 kInitializer = 0;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
918 
// uint64 reduction over 1024 elements with 1 as the initial value.
XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) {
  constexpr uint64 kInitializer = 1;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
922 
// uint64 reduction over 1024 elements with an initial value that does not fit
// in 32 bits.
XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
  constexpr uint64 kInitializer = 1234556789123;
  constexpr int kNumElems = 1024;
  DoTest<uint64>(kInitializer, kNumElems);
}
926 
927 // Test the operational semantic that the init value is passed on the lhs for
928 // reduces. Can be tested by performing an "identity" reduce (that simply
929 // returns one of the parameters). In this case, we return the rhs, which for
930 // a 1D array with one element, should not be the init value.
XLA_TEST_F(ReduceTest,ReduceIdentity)931 XLA_TEST_F(ReduceTest, ReduceIdentity) {
932   XlaBuilder builder(TestName());
933   Shape single_float = ShapeUtil::MakeShape(F32, {});
934   Parameter(&builder, 0, single_float, "lhs-unused");
935   Parameter(&builder, 1, single_float, "rhs-used");
936   auto computation_status = builder.Build();
937   TF_ASSERT_OK(computation_status.status());
938 
939   Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
940   Reduce(Parameter(&builder, 0, operand_shape, "operand"),
941          Parameter(&builder, 1, single_float, "init"),
942          computation_status.ValueOrDie(), {0});
943 
944   float operand[] = {42.0f};
945   float init = 58.5f;
946   float expected = 42.0f;
947   Literal input_literal = LiteralUtil::CreateR1<float>(operand);
948   std::unique_ptr<GlobalData> input_global_data =
949       client_->TransferToServer(input_literal).ConsumeValueOrDie();
950   Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
951   std::unique_ptr<GlobalData> input_global_data2 =
952       client_->TransferToServer(input_literal2).ConsumeValueOrDie();
953   ComputeAndCompareR0<float>(
954       &builder, expected, {input_global_data.get(), input_global_data2.get()},
955       ErrorSpec(0.0001));
956 }
957 
// Bitwise-AND-reduces each row of a 3x2 U64 matrix (dimension 1), seeded with
// all bits set.
XLA_TEST_F(ReduceTest, AndReduceU64) {
  XlaBuilder builder(TestName());
  const Array2D<uint64> rows = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, 0xFFFFFFFFFFFFFFFFULL},
  };
  const XlaComputation and_computation =
      CreateScalarAndComputation(U64, &builder);
  const auto matrix = ConstantR2FromArray2D(&builder, rows);
  const auto all_ones = ConstantR0<uint64>(&builder, 0xFFFFFFFFFFFFFFFFULL);
  Reduce(matrix, all_ones, and_computation, {1});

  // Row-wise AND of the two columns above.
  const std::vector<uint64> row_ands = {0x1204461080145890ULL, 68, 1};
  ComputeAndCompareR1<uint64>(&builder, row_ands, {});
}
970 
// Bitwise-OR-reduces each row of a 3x2 U64 matrix (dimension 1), seeded with
// zero.
XLA_TEST_F(ReduceTest, OrReduceU64) {
  XlaBuilder builder(TestName());
  const Array2D<uint64> rows = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, 0xCAFEBEEFABABABABULL},
  };
  const XlaComputation or_computation =
      CreateScalarOrComputation(U64, &builder);
  const auto matrix = ConstantR2FromArray2D(&builder, rows);
  const auto no_bits = ConstantR0<uint64>(&builder, 0);
  Reduce(matrix, no_bits, or_computation, {1});

  // Row-wise OR of the two columns above.
  const std::vector<uint64> row_ors = {
      0x3BFDFF7ABEFEFEF0ULL, 0xFFFFFFFFFFFFFFF7ULL, 0xCAFEBEEFABABABABULL};
  ComputeAndCompareR1<uint64>(&builder, row_ors, {});
}
984 
// Sum-reduces a 127x1 input along dimension 0. The output is rank 1 with a
// single element, so the reduce is effectively a reduction to a scalar; the
// expected value is the sum of all input elements.
XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
  XlaBuilder builder(TestName());
  const XlaComputation add_computation =
      CreateScalarAddComputation(F32, &builder);

  constexpr int kElementCount = 127;
  const Shape column_shape = ShapeUtil::MakeShape(F32, {kElementCount, 1});
  const auto column = Parameter(&builder, 0, column_shape, "input");
  const auto zero = ConstantR0<float>(&builder, 0.0);
  Reduce(column, zero, add_computation, /*dimensions_to_reduce=*/{0});

  Array2D<float> column_values(kElementCount, 1);
  column_values.FillRandom(3.0f);
  const Literal column_literal =
      LiteralUtil::CreateR2FromArray2D(column_values);
  std::unique_ptr<GlobalData> column_data =
      client_->TransferToServer(column_literal).ConsumeValueOrDie();

  const float expected_sum = absl::c_accumulate(column_values, 0.0f);
  ComputeAndCompareR1<float>(&builder, {expected_sum}, {column_data.get()},
                             ErrorSpec(0.001));
}
1004 
1005 }  // namespace
1006 }  // namespace xla
1007