• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests the reduce-window XLA operation.
17 
18 #include <limits>
19 #include <memory>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/array2d.h"
26 #include "tensorflow/compiler/xla/array3d.h"
27 #include "tensorflow/compiler/xla/array4d.h"
28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
29 #include "tensorflow/compiler/xla/client/local_client.h"
30 #include "tensorflow/compiler/xla/client/padding.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/reference_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
37 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
38 #include "tensorflow/compiler/xla/tests/test_macros.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 namespace {
47 
// Values fed to the boolean `use_bfloat16` test parameter. F32 (false) is
// always tested; BF16 (true) is added only when the backend supports it.
#ifndef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 1> use_bfloat16_params = {false};
#else
static std::array<bool, 2> use_bfloat16_params = {false, true};
#endif
55 
56 class ReduceWindowTestBase : public ClientLibraryTestBase {
57  public:
DefaultErrorSpec() const58   ErrorSpec DefaultErrorSpec() const {
59     if (use_bfloat16()) {
60       return ErrorSpec(2e-1, 6e-2);
61     } else {
62       return ErrorSpec(1e-3, 1e-3);
63     }
64   }
65 };
66 
67 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
68                          public ReduceWindowTestBase {
69  public:
ReduceWindowTest()70   ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
71 
ReduceWindowAdd(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)72   void ReduceWindowAdd(const XlaOp& input,
73                        absl::Span<const int64> window_dimensions,
74                        absl::Span<const int64> window_strides,
75                        Padding padding) {
76     auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
77                                           &builder_);
78     ReduceWindow(input, init,
79                  CreateScalarAddComputation(FloatType(), &builder_),
80                  window_dimensions, window_strides, padding);
81   }
82 
ReduceWindowMax(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)83   void ReduceWindowMax(const XlaOp& input,
84                        absl::Span<const int64> window_dimensions,
85                        absl::Span<const int64> window_strides,
86                        Padding padding) {
87     auto init =
88         CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
89     ReduceWindow(input, init,
90                  CreateScalarMaxComputation(FloatType(), &builder_),
91                  window_dimensions, window_strides, padding);
92   }
93 
ReduceWindowMin(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)94   void ReduceWindowMin(const XlaOp& input,
95                        absl::Span<const int64> window_dimensions,
96                        absl::Span<const int64> window_strides,
97                        Padding padding) {
98     auto init =
99         CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
100     ReduceWindow(input, init,
101                  CreateScalarMinComputation(FloatType(), &builder_),
102                  window_dimensions, window_strides, padding);
103   }
104 
105   XlaBuilder builder_;
106 };
107 
// A rank-1 operand combined with two window dimensions and one stride must
// leave an INVALID_ARGUMENT error on the builder.
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
  const auto zero =
      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
  // The builder has to be error-free before the malformed op is added.
  TF_ASSERT_OK(builder_.first_error());

  ReduceWindow(operand, zero,
               CreateScalarAddComputation(FloatType(), &builder_),
               /*window_dimensions=*/{1, 2},
               /*window_strides=*/{1}, Padding::kValid);

  const auto status = builder_.first_error();
  ASSERT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT) << status;
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Want input dimensions size"));
}
123 
124 // Regression test for b/68964348.
TEST_P(ReduceWindowTest,R0ReduceWindow)125 TEST_P(ReduceWindowTest, R0ReduceWindow) {
126   const auto input =
127       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
128   const auto init =
129       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
130   ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
131                /*window_dimensions=*/{},
132                /*window_strides=*/{}, Padding::kSame);
133   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
134                            ErrorSpec(0.00001));
135 }
136 
// Min over windows of 3 with stride 2 on a 5-element vector, no padding.
TEST_P(ReduceWindowTest, Min3In5Stride2) {
  const Literal values =
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1});
  const auto input = CreateConstantFromLiteral(values, &builder_);

  ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{2},
                  Padding::kValid);

  // Windows are [10000, 1000, 100] and [100, 10, 1].
  const Literal expected = LiteralUtil::CreateR1<float>({100, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
144 
// Min over windows of 3 with stride 1 on a 5-element vector, kSame padding.
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
  const Literal values =
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1});
  const auto input = CreateConstantFromLiteral(values, &builder_);

  ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
                  Padding::kSame);

  const Literal expected =
      LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
154 
// Sum reduce-window over an operand that has a zero-sized dimension, checked
// against the reference implementation.
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
  const std::vector<int64> window = {1, 1, 2, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(1, 0, 2, 1);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
167 
// Sum reduce-window with a {1, 1, 2, 1} window on a small random 1x2x2x1
// operand, checked against the reference implementation.
TEST_P(ReduceWindowTest, NonSquareSmall) {
  const std::vector<int64> window = {1, 1, 2, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(1, 2, 2, 1);
  operand.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
182 
// Unit window with stride 2 in the two middle dimensions of a random 1x3x3x1
// operand, checked against the reference implementation.
TEST_P(ReduceWindowTest, MiddleDimsSmall) {
  const std::vector<int64> window = {1, 1, 1, 1};
  const std::vector<int64> strides = {1, 2, 2, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(1, 3, 3, 1);
  operand.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
196 
// Sum reduce-window along dimension 2 of a random 3x6x7x32 operand, checked
// against the reference implementation.
TEST_P(ReduceWindowTest, Along2ndMinorDim) {
  // The parameters of this reduction mimic feature norm (e.g. LRN).
  const int lrn_diameter = 7;  // diameter = 2*radius + 1 --> must be odd
  const std::vector<int64> window = {1, 1, lrn_diameter, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(3, 6, 7, 32);
  operand.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
213 
// Sum reduce-window over dimensions 0 and 1 of a 4x4x6x8 operand, checked
// against the reference implementation.
TEST_P(ReduceWindowTest, AmongMajor2Dims) {
  const int win_len = 3;
  const int win_stride = 1;
  // Reduce only along the x and y dimensions, according to the win_len.
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(4, 4, 6, 8);
  operand.FillWithMinorDimNum();
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
235 
// Sum reduce-window over dimensions 0 and 1 of a random 9x12x4x89 operand
// with stride 2, checked against the reference implementation.
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
  const int win_len = 3;
  const int win_stride = 2;
  // Reduce only along the x and y dimensions, according to the win_len.
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(9, 12, 4, 89);
  operand.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
258 
259 // Tests the super windowing logic w.r.t handling prime number of windows in a
260 // major dimension with reduction.
TEST_P(ReduceWindowTest,PrimeWindowsInReductionDimension)261 TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
262   Array4D<float> input_array(15, 15, 4, 128);
263   input_array.FillRandom(2.f, 4.f);
264 
265   int win_len = 3;
266   int win_stride = 2;
267 
268   const auto input_data_handle =
269       CreateConstantFromArray(input_array, &builder_);
270 
271   Padding padding = Padding::kSame;
272   // Reduce only along the x and y dimensions, according to the win_len.
273   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
274                   {win_stride, win_stride, 1, 1}, padding);
275 
276   auto result = ReferenceUtil::ReduceWindow4DAdd(
277       input_array, 0.0f, {win_len, win_len, 1, 1},
278       {win_stride, win_stride, 1, 1}, padding);
279 
280   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
281                            DefaultErrorSpec());
282 }
283 
// Sum reduce-window with a window of 11 along the last dimension of a
// 19x17x8x256 operand, checked against the reference implementation.
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
  const std::vector<int64> window = {1, 1, 1, 11};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(19, 17, 8, 256);
  operand.FillWithMinorDimNum();
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
300 
301 // Tests a reduction function that is not a simple add/min/max/etc.
XLA_TEST_P(ReduceWindowTest,NonstandardReduceFunction)302 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
303   Array4D<float> input_array(1, 2, 2, 1);
304   input_array(0, 0, 0, 0) = 1;
305   input_array(0, 0, 1, 0) = 2;
306   input_array(0, 1, 0, 0) = 3;
307   input_array(0, 1, 1, 0) = 4;
308   const auto input = CreateConstantFromArray(input_array, &builder_);
309 
310   Padding padding = Padding::kValid;
311   const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
312   auto b = builder_.CreateSubBuilder("unusual");
313   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
314   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
315   Min(Add(lhs, rhs),
316       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
317   XlaComputation reduce_fn = b->BuildAndNoteError();
318 
319   ReduceWindow(
320       input,
321       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
322       reduce_fn,
323       /*window_dimensions=*/{1, 1, 2, 1},
324       /*window_strides=*/{1, 1, 1, 1}, padding);
325 
326   const auto reduce_func = [](float arg1, float arg2) {
327     return std::min<float>(arg1 + arg2, 8.0f);
328   };
329 
330   auto expected =
331       ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
332                                            /*window=*/{1, 1, 2, 1},
333                                            /*stride=*/{1, 1, 1, 1}, padding);
334 
335   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
336                            {}, DefaultErrorSpec());
337 }
338 
// Sum reduce-window on a random 13x12x8x15 parameter stored with layout
// {0, 3, 2, 1}, checked against the reference implementation.
TEST_P(ReduceWindowTest, R4UnitWindow) {
  const std::vector<int64> window = {1, 1, 7, 1};
  const std::vector<int64> strides = {1, 4, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(13, 12, 8, 15);
  operand.FillRandom(2.f, 2.f);
  const Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand, LayoutUtil::MakeLayout({0, 3, 2, 1}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, operand_literal, "parameter", &builder_, &input);

  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
357 
// Sum reduce-window on a rank-6 operand of ones with a 3x1x3x3x1x1 window;
// the expected literal uses a non-default layout.
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
  // Rank-6 operand, 8 elements per dimension, every element set to 1.
  const std::vector<int64> operand_dims(6, 8);
  const Shape operand_shape = ShapeUtil::MakeShape(F32, operand_dims);
  Literal ones(operand_shape);
  ones.PopulateWithValue(1.0f);
  const auto input = CreateConstantFromLiteral(ones, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{3, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, Padding::kValid);

  // Each window covers 3 * 3 * 3 = 27 ones.
  const std::vector<int64> result_dims = {6, 8, 6, 6, 8, 8};
  const std::vector<int64> result_layout = {1, 5, 3, 2, 0, 4};
  const Shape result_shape =
      ShapeUtil::MakeShapeWithLayout(F32, result_dims, result_layout);
  Literal expected(result_shape);
  expected.PopulateWithValue(27.0f);

  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
377 
// Sum reduce-window on a rank-6 operand of ones with a 1x1x3x3x1x1 window and
// no padding.
XLA_TEST_P(ReduceWindowTest, R6Add) {
  // Rank-6 operand, 8 elements per dimension, every element set to 1.
  const std::vector<int64> input_dims(6, 8);
  const Literal arg_literal =
      LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{1, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, Padding::kValid);

  // Each window covers 3 * 3 = 9 ones; the two windowed dimensions shrink
  // from 8 to 6.
  const std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
  const Literal expected =
      LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);

  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
396 
// Unit window with stride 8 along dimension 2 of a random 2x1x27x119
// parameter, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
  const int win_len = 1;
  const int stride = 8;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(2, 1, 27, 119);
  operand.FillRandom(2.0f);
  const Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, operand_literal, "parameter", &builder_, &input);

  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
417 
// Window of 3 with unit stride along dimension 2 of a random 3x2x4x64
// parameter, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
  const int win_len = 3;
  const int stride = 1;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(3, 2, 4, 64);
  operand.FillRandom(2.0f);
  const Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, operand_literal, "parameter", &builder_, &input);

  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
438 
// Window of 8 with stride 5 along dimension 2 of a random 1x3x12x200
// parameter, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
  const int win_len = 8;
  const int stride = 5;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(1, 3, 12, 200);
  operand.FillRandom(2.0f);
  const Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, operand_literal, "parameter", &builder_, &input);

  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
459 
// Sum reduce-window over dimensions 0 and 1 of a random 6x4x10x130 operand
// with stride 2, checked against the reference implementation.
TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
  const int win_len = 3;
  const int win_stride = 2;
  // Reduce only along the x and y dimensions, according to the win_len.
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand(6, 4, 10, 130);
  operand.FillRandom(2.0f);
  const auto input = CreateConstantFromArray(operand, &builder_);
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected =
      ReferenceUtil::ReduceWindow4DAdd(operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
480 
// Sums windows of 32 taken every 128 elements from a vector of 1152 ones, so
// the windows never overlap.
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
  const std::vector<float> ones(128 * 9, 1);
  const Literal input_literal = LiteralUtil::CreateR1<float>(ones);
  const auto input = CreateConstantFromLiteral(input_literal, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{32}, /*window_strides=*/{128},
                  Padding::kValid);

  // Nine windows, each covering 32 ones.
  const Literal expected =
      LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32});
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
491 
// Sums a single window of 128 (stride 128) over a 128-element vector.
XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
  // Eight repetitions of the sequence 1, 2, ..., 16.
  std::vector<float> input_vector;
  input_vector.reserve(8 * 16);
  for (int repetition = 0; repetition < 8; ++repetition) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const Literal input_literal = LiteralUtil::CreateR1<float>(input_vector);
  const auto input = CreateConstantFromLiteral(input_literal, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{128},
                  /*window_strides=*/{128}, Padding::kValid);

  // 8 * (1 + 2 + ... + 16) = 1088.
  const Literal expected = LiteralUtil::CreateR1<float>({1088});
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
508 
// Sums a single window of 128 (stride 1) over a 128-element vector.
XLA_TEST_P(ReduceWindowTest, Add128In128) {
  // Eight repetitions of the sequence 1, 2, ..., 16.
  std::vector<float> input_vector;
  input_vector.reserve(8 * 16);
  for (int repetition = 0; repetition < 8; ++repetition) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const Literal input_literal = LiteralUtil::CreateR1<float>(input_vector);
  const auto input = CreateConstantFromLiteral(input_literal, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{128}, /*window_strides=*/{1},
                  Padding::kValid);

  // 8 * (1 + 2 + ... + 16) = 1088.
  const Literal expected = LiteralUtil::CreateR1<float>({1088});
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
525 
526 // Regression test for a bug that appeared in Inception (b/34784899).
TEST_P(ReduceWindowTest,R2ReduceWindowInceptionFromBroadcast)527 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
528   Array2D<float> input_array(14, 14, 1.0f);
529   const auto input = CreateConstantFromArray(input_array, &builder_);
530 
531   int win_len = 3;
532   int stride = 1;
533   Padding padding = Padding::kSame;
534   ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
535 
536   auto res = ReferenceUtil::ReduceWindow2DAdd(
537       input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
538 
539   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
540                            {}, DefaultErrorSpec());
541 }
542 
// Sum reduce-window over a 6x4 operand produced by broadcasting the scalar 1;
// the reference result is computed from an equivalent 6x4 array of ones.
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
  const std::vector<int64> window = {4, 2};
  const std::vector<int64> strides = {3, 3};
  const Padding padding = Padding::kSame;

  Array2D<float> reference_operand(6, 4, 1.0f);
  const auto one = CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_);
  const XlaOp input = Broadcast(one, {6, 4});
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow2DAdd(
      reference_operand, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_,
                           LiteralUtil::CreateFromArray<float>(*expected), {},
                           DefaultErrorSpec());
}
557 
558 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
559                         ::testing::ValuesIn(use_bfloat16_params));
560 
// Reduction selected by the parameterized R4 reduce-window test data.
enum Reducer {
  kAdd,
  kMax,
};
562 
563 struct R4ReduceWindowTestData {
564   int64 base_bounds[4];
565   int64 window_bounds[4];
566   int64 strides[4];
567   int64 pad_low[4];
568   int64 pad_high[4];
569   int64 layout[4];
570 
571   Reducer reducer;
572 };
573 
R4ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R4ReduceWindowTestData,bool>> & data)574 string R4ReduceWindowTestDataToString(
575     const ::testing::TestParamInfo<
576         ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
577   const auto& param = ::testing::get<0>(data.param);
578   string str = absl::StrCat(
579       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
580       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
581       "__strides_", absl::StrJoin(param.strides, "x"),              //
582       "__pad_low_", absl::StrJoin(param.pad_low, "x"),              //
583       "__pad_high_", absl::StrJoin(param.pad_high, "x"),            //
584       "__layout_", absl::StrJoin(param.layout, "_"),                //
585       (param.reducer == kAdd) ? "_add" : "_max");
586   CHECK(param.reducer == kAdd || param.reducer == kMax);
587 
588   // Test names are not allowed to contain the '-' character.
589   std::replace(str.begin(), str.end(), '-', 'n');
590   if (::testing::get<1>(data.param)) {
591     absl::StrAppend(&str, "_bfloat16");
592   }
593   return str;
594 }
595 
596 class R4ReduceWindowTest : public ReduceWindowTestBase,
597                            public ::testing::WithParamInterface<
598                                ::testing::tuple<R4ReduceWindowTestData, bool>> {
599  protected:
R4ReduceWindowTest()600   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
601 
DoIt()602   void DoIt() {
603     XlaBuilder b(TestName());
604     const auto& param = ::testing::get<0>(GetParam());
605 
606     const float kInitValue = 0.0f;
607 
608     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
609                          param.base_bounds[2], param.base_bounds[3]);
610     // Choose a prime iota length so that each window sees a unique set of
611     // values. (Technically, the requirement is that the iota length is
612     // relatively prime to all of the dimensions involved in the reduce-window.)
613     input.FillRepeatedIota(0, 137);
614     // Floating point sum reduction requires higher localized precision. We need
615     // the following normalization in order to enable testing of kAdd on large
616     // windows.
617     input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
618       *value = *value / 10000000000.f;
619     });
620     Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
621         input, LayoutUtil::MakeLayout(param.layout));
622     XlaOp parameter;
623     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
624                                                        &b, &parameter);
625 
626     std::vector<std::pair<int64, int64>> padding(4);
627     for (int i = 0; i < 4; ++i) {
628       padding[i] = {param.pad_low[i], param.pad_high[i]};
629     }
630 
631     auto init_value =
632         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
633     CHECK(param.reducer == kAdd || param.reducer == kMax);
634     auto reducer = param.reducer;
635     auto computation = reducer == kAdd
636                            ? CreateScalarAddComputation(FloatType(), &b)
637                            : CreateScalarMaxComputation(FloatType(), &b);
638     ReduceWindowWithGeneralPadding(
639         /*operand=*/parameter,
640         /*init_value=*/init_value,
641         /*computation=*/computation,
642         /*window_dimensions=*/param.window_bounds,
643         /*window_strides=*/param.strides,
644         /*base_dilations=*/{},
645         /*window_dilations=*/{},
646         /*padding=*/padding);
647 
648     CHECK(reducer == kAdd || reducer == kMax);
649     auto reduce_func = reducer == kAdd
650                            ? +[](float a, float b) { return a + b; }
651                            : +[](float a, float b) { return std::max(a, b); };
652     std::unique_ptr<Array4D<float>> expected =
653         ReferenceUtil::ReduceWindow4DGeneric(
654             /*operand=*/input,
655             /*init=*/kInitValue,
656             /*reduce_func=*/reduce_func,
657             /*window=*/param.window_bounds,
658             /*stride=*/param.strides,
659             /*padding=*/padding);
660     Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
661     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
662         input_literal.shape().element_type(),
663         AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
664     ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
665                              DefaultErrorSpec(), &expected_shape_with_layout);
666   }
667 };
668 
// Runs the fixture's DoIt() once for every instantiated parameter set.
TEST_P(R4ReduceWindowTest, DoIt) {
  DoIt();
}
670 
671 // base_bounds, window_bounds, strides, pad_low, pad_high
672 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
673     // Minimal edge case.
674     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
675                            /*window_bounds=*/{1, 1, 1, 1},
676                            /*strides=*/{1, 1, 1, 1},
677                            /*pad_low=*/{0, 0, 0, 0},
678                            /*pad_high=*/{0, 0, 0, 0},
679                            /*layout=*/{3, 2, 1, 0},
680                            /*reducer=*/kAdd},
681 
682     // Arbitrary padding (not kSame or kValid).
683     R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
684                            /*window_bounds=*/{3, 3, 1, 1},
685                            /*strides=*/{2, 2, 1, 1},
686                            /*pad_low=*/{4, 4, 0, 0},
687                            /*pad_high=*/{4, 4, 0, 0},
688                            /*layout=*/{3, 2, 1, 0},
689                            /*reducer=*/kAdd},
690 
691     // Zero base bound edge case.
692     R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
693                            /*window_bounds=*/{1, 1, 1, 1},
694                            /*strides=*/{1, 1, 1, 1},
695                            /*pad_low=*/{0, 0, 0, 0},
696                            /*pad_high=*/{0, 0, 0, 0},
697                            /*layout=*/{3, 2, 1, 0},
698                            /*reducer=*/kAdd},
699 
700     // With max instead of add.
701     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
702                            /*window_bounds=*/{2, 3, 1, 1},
703                            /*strides=*/{1, 1, 1, 1},
704                            /*pad_low=*/{0, 0, 0, 0},
705                            /*pad_high=*/{0, 0, 0, 0},
706                            /*layout=*/{3, 2, 1, 0},
707                            /*reducer=*/kMax},
708 
709     // With stride.
710     R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
711                            /*window_bounds=*/{3, 2, 1, 1},
712                            /*strides=*/{2, 4, 1, 1},
713                            /*pad_low=*/{0, 0, 0, 0},
714                            /*pad_high=*/{0, 0, 0, 0},
715                            /*layout=*/{3, 2, 1, 0},
716                            /*reducer=*/kAdd},
717 
718     // With low padding.
719     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
720                            /*window_bounds=*/{3, 2, 1, 1},
721                            /*strides=*/{2, 2, 1, 1},
722                            /*pad_low=*/{3, 2, 0, 0},
723                            /*pad_high=*/{0, 0, 0, 0},
724                            /*layout=*/{3, 2, 1, 0},
725                            /*reducer=*/kAdd},
726 
727     // With high padding.
728     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
729                            /*window_bounds=*/{3, 2, 1, 1},
730                            /*strides=*/{2, 2, 1, 1},
731                            /*pad_low=*/{0, 0, 0, 0},
732                            /*pad_high=*/{2, 3, 0, 0},
733                            /*layout=*/{3, 2, 1, 0},
734                            /*reducer=*/kAdd},
735 
736     // Window touches both sides of the padding simultaneously.
737     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
738                            /*window_bounds=*/{3, 3, 1, 1},
739                            /*strides=*/{1, 1, 1, 1},
740                            /*pad_low=*/{1, 1, 0, 0},
741                            /*pad_high=*/{1, 1, 0, 0},
742                            /*layout=*/{3, 2, 1, 0},
743                            /*reducer=*/kAdd},
744 
745     // Window is entirely in the padding for some positions.
746     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
747                            /*window_bounds=*/{3, 3, 1, 1},
748                            /*strides=*/{1, 1, 1, 1},
749                            /*pad_low=*/{4, 4, 0, 0},
750                            /*pad_high=*/{4, 4, 0, 0},
751                            /*layout=*/{3, 2, 1, 0},
752                            /*reducer=*/kAdd},
753 
754     // Zero base bound with padding edge case.
755     R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
756                            /*window_bounds=*/{1, 1, 1, 1},
757                            /*strides=*/{1, 1, 1, 1},
758                            /*pad_low=*/{0, 1, 0, 0},
759                            /*pad_high=*/{0, 0, 0, 0},
760                            /*layout=*/{3, 2, 1, 0},
761                            /*reducer=*/kAdd},
762 
763     // With stride, low padding and high padding.
764     R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
765                            /*window_bounds=*/{3, 4, 1, 1},
766                            /*strides=*/{3, 1, 1, 1},
767                            /*pad_low=*/{10, 1, 0, 0},
768                            /*pad_high=*/{2, 3, 0, 0},
769                            /*layout=*/{3, 2, 1, 0},
770                            /*reducer=*/kAdd},
771 
772     // With minor dimension == 129.
773     R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
774                            /*window_bounds=*/{1, 1, 1, 1},
775                            /*strides=*/{1, 1, 1, 1},
776                            /*pad_low=*/{0, 0, 0, 0},
777                            /*pad_high=*/{0, 0, 0, 0},
778                            /*layout=*/{3, 2, 1, 0},
779                            /*reducer=*/kAdd},
780 
781     // With minor dims reduction and non-overlapped stride.
782     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
783                            /*window_bounds=*/{1, 1, 2, 2},
784                            /*strides=*/{1, 1, 2, 2},
785                            /*pad_low=*/{0, 0, 0, 0},
786                            /*pad_high=*/{0, 0, 0, 0},
787                            /*layout=*/{3, 2, 1, 0},
788                            /*reducer=*/kAdd},
789 
790     // With minor dims reduction and overlapped stride.
791     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
792                            /*window_bounds=*/{1, 1, 4, 4},
793                            /*strides=*/{1, 1, 2, 2},
794                            /*pad_low=*/{0, 0, 0, 0},
795                            /*pad_high=*/{1, 0, 0, 0},
796                            /*layout=*/{3, 2, 1, 0},
797                            /*reducer=*/kAdd},
798 
799     R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
800                            /*window_bounds=*/{1, 64, 64, 1},
801                            /*strides=*/{1, 64, 64, 1},
802                            /*pad_low=*/{0, 0, 0, 0},
803                            /*pad_high=*/{0, 0, 0, 0},
804                            /*layout=*/{3, 0, 2, 1},
805                            /*reducer=*/kAdd},
806 
807     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
808                            /*window_bounds=*/{112, 112, 1, 8},
809                            /*strides=*/{112, 112, 1, 8},
810                            /*pad_low=*/{0, 0, 0, 0},
811                            /*pad_high=*/{0, 0, 0, 0},
812                            /*layout=*/{3, 2, 1, 0},
813                            /*reducer=*/kMax},
814 
815     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
816                            /*window_bounds=*/{2, 3, 4, 5},
817                            /*strides=*/{1, 1, 1, 1},
818                            /*pad_low=*/{0, 0, 0, 0},
819                            /*pad_high=*/{0, 0, 0, 0},
820                            /*layout=*/{3, 2, 1, 0},
821                            /*reducer=*/kAdd},
822 
823     // With 0321 layout.
824     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
825                            /*window_bounds=*/{2, 3, 4, 5},
826                            /*strides=*/{1, 2, 3, 4},
827                            /*pad_low=*/{0, 0, 0, 0},
828                            /*pad_high=*/{0, 0, 0, 0},
829                            /*layout=*/{0, 3, 2, 1},
830                            /*reducer=*/kAdd},
831 
832     // With 0123 layout.
833     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
834                            /*window_bounds=*/{2, 3, 7, 9},
835                            /*strides=*/{1, 2, 5, 8},
836                            /*pad_low=*/{0, 0, 0, 0},
837                            /*pad_high=*/{0, 0, 0, 0},
838                            /*layout=*/{0, 1, 2, 3},
839                            /*reducer=*/kAdd},
840 };
841 
842 INSTANTIATE_TEST_CASE_P(
843     R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
844     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
845                        ::testing::ValuesIn(use_bfloat16_params)),
846     R4ReduceWindowTestDataToString);
847 
848 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
849 
// Runs the inherited DoIt() for each large case. The test name is wrapped in
// DISABLED_ON_INTERPRETER.
XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) {
  DoIt();
}
851 
852 // Test cases that are large/slow/failed.
853 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
854     R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
855                            /*window_bounds=*/{3, 3, 1, 5},
856                            /*strides=*/{1, 1, 1, 5},
857                            /*pad_low=*/{1, 1, 0, 0},
858                            /*pad_high=*/{1, 1, 0, 0},
859                            /*layout=*/{3, 2, 1, 0},
860                            /*reducer=*/kMax},
861 
862     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
863                            /*window_bounds=*/{3, 3, 1, 1},
864                            /*strides=*/{2, 2, 1, 1},
865                            /*pad_low=*/{0, 0, 0, 0},
866                            /*pad_high=*/{1, 1, 0, 0},
867                            /*layout=*/{3, 2, 1, 0},
868                            /*reducer=*/kAdd},
869 
870     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
871                            /*window_bounds=*/{1, 1, 4, 1},
872                            /*strides=*/{1, 1, 4, 1},
873                            /*pad_low=*/{0, 0, 1, 0},
874                            /*pad_high=*/{0, 0, 2, 0},
875                            /*layout=*/{3, 2, 1, 0},
876                            /*reducer=*/kMax},
877 
878     // Patterns generated by cumsum/cumprod.
879     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
880                            /*window_bounds=*/{1021, 1, 1, 1},
881                            /*strides=*/{1, 1, 1, 1},
882                            /*pad_low=*/{1020, 0, 0, 0},
883                            /*pad_high=*/{0, 0, 0, 0},
884                            /*layout=*/{3, 2, 1, 0},
885                            /*reducer=*/kAdd},
886 
887     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
888                            /*window_bounds=*/{1, 1, 1021, 1},
889                            /*strides=*/{1, 1, 1, 1},
890                            /*pad_low=*/{0, 0, 1020, 0},
891                            /*pad_high=*/{0, 0, 0, 0},
892                            /*layout=*/{3, 2, 1, 0},
893                            /*reducer=*/kAdd},
894 
895     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
896                            /*window_bounds=*/{1, 1, 1, 1021},
897                            /*strides=*/{1, 1, 1, 1},
898                            /*pad_low=*/{0, 0, 0, 1020},
899                            /*pad_high=*/{0, 0, 0, 0},
900                            /*layout=*/{3, 2, 1, 0},
901                            /*reducer=*/kAdd},
902 
903     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
904                            /*window_bounds=*/{1021, 1, 1, 1},
905                            /*strides=*/{1, 1, 1, 1},
906                            /*pad_low=*/{1021, 0, 0, 0},
907                            /*pad_high=*/{0, 0, 0, 0},
908                            /*layout=*/{3, 2, 1, 0},
909                            /*reducer=*/kAdd},
910 
911     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
912                            /*window_bounds=*/{1, 1, 1021, 1},
913                            /*strides=*/{1, 1, 1, 1},
914                            /*pad_low=*/{0, 0, 1021, 0},
915                            /*pad_high=*/{0, 0, 0, 0},
916                            /*layout=*/{3, 2, 1, 0},
917                            /*reducer=*/kAdd},
918 
919     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
920                            /*window_bounds=*/{1, 1, 1, 1021},
921                            /*strides=*/{1, 1, 1, 1},
922                            /*pad_low=*/{0, 0, 0, 1021},
923                            /*pad_high=*/{0, 0, 0, 0},
924                            /*layout=*/{3, 2, 1, 0},
925                            /*reducer=*/kAdd},
926 };
927 
928 INSTANTIATE_TEST_CASE_P(
929     R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
930     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
931                        ::testing::ValuesIn(use_bfloat16_params)),
932     R4ReduceWindowTestDataToString);
933 
934 struct R3ReduceWindowTestData {
935   int64 base_bounds[3];
936   int64 window_bounds[3];
937   int64 strides[3];
938   int64 layout[3];
939   Padding padding;
940   Reducer reducer;
941 } kR3TestCases[] = {
942     {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
943      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
944      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
945     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
946      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
947      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
948     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
949      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
950      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
951     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
952      /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
953      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
954     {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
955      /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
956      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
957     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
958      /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
959      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
960     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
961      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
962      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
963     {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
964      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
965      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
966     {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
967      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
968      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
969     {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
970      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
971      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
972     {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
973      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
974      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
975     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
976      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
977      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
978     {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
979      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
980      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
981     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
982      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
983      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
984 };
985 
R3ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R3ReduceWindowTestData,bool>> & data)986 string R3ReduceWindowTestDataToString(
987     const ::testing::TestParamInfo<
988         ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
989   const auto& param = ::testing::get<0>(data.param);
990   string str = absl::StrCat(
991       "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
992       absl::StrJoin(param.window_bounds, "x"), "__strides_",
993       absl::StrJoin(param.strides, "x"), "__padding_",
994       param.padding == Padding::kSame ? "same" : "valid", "__layout_",
995       param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
996       param.reducer == kAdd ? "add" : "max");
997   if (::testing::get<1>(data.param)) {
998     absl::StrAppend(&str, "_bfloat16");
999   }
1000   return str;
1001 }
1002 
1003 class R3ReduceWindowTest : public ReduceWindowTestBase,
1004                            public ::testing::WithParamInterface<
1005                                ::testing::tuple<R3ReduceWindowTestData, bool>> {
1006  protected:
R3ReduceWindowTest()1007   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1008 };
1009 
// Builds a ReduceWindow over a rank-3 input described by the test parameter
// and hands the computation plus the input literal to ComputeAndCompare.
TEST_P(R3ReduceWindowTest, DoIt) {
  XlaBuilder builder(TestName());
  const R3ReduceWindowTestData& test_case = ::testing::get<0>(GetParam());

  const float kInitValue = 0.0f;
  Array3D<float> input(test_case.base_bounds[0], test_case.base_bounds[1],
                       test_case.base_bounds[2]);
  // The iota length is prime so that each window sees a unique set of values.
  // (Strictly, it only needs to be relatively prime to every dimension
  // involved in the reduce-window.)
  input.FillRepeatedIota(0, 137);
  Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
      input, LayoutUtil::MakeLayout(test_case.layout));

  Reducer reducer = test_case.reducer;
  if (use_bfloat16()) {
    input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);

    // bf16 inputs always use kMax, to avoid numerical issues.
    reducer = kMax;
  }

  XlaOp operand = Parameter(&builder, 0, input_literal.shape(), "input");
  auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue),
                                        &builder);

  const bool use_add = reducer == kAdd;
  auto reduce_computation =
      use_add ? CreateScalarAddComputation(FloatType(), &builder)
              : CreateScalarMaxComputation(FloatType(), &builder);

  ReduceWindow(/*operand=*/operand,
               /*init_value=*/init,
               /*computation=*/reduce_computation,
               /*window_dimensions=*/test_case.window_bounds,
               /*window_strides=*/test_case.strides,
               /*padding=*/test_case.padding);

  ComputeAndCompare(&builder, {std::move(input_literal)}, DefaultErrorSpec());
}
1048 
1049 INSTANTIATE_TEST_CASE_P(
1050     R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
1051     ::testing::Combine(::testing::ValuesIn(kR3TestCases),
1052                        ::testing::ValuesIn(use_bfloat16_params)),
1053     R3ReduceWindowTestDataToString);
1054 
1055 struct R2ReduceWindowTestData {
1056   int64 base_bounds[2];
1057   int64 window_bounds[2];
1058   int64 strides[2];
1059   int64 pad_low[2];
1060   int64 pad_high[2];
1061   int64 layout[2];
1062   Reducer reducer;
1063 } kR2TestCases[] = {
1064     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1065      /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1066      /*layout=*/{0, 1},
1067      /*reducer=*/Reducer::kAdd},
1068     {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
1069      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
1070      /*layout=*/{0, 1},
1071      /*reducer=*/Reducer::kAdd},
1072     {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
1073      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1074      /*layout=*/{0, 1},
1075      /*reducer=*/Reducer::kAdd},
1076     {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
1077      /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
1078      /*layout=*/{0, 1},
1079      /*reducer=*/Reducer::kAdd},
1080 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
1081 // ptxas bug.
1082 #ifndef XLA_TEST_BACKEND_GPU
1083     {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
1084      /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
1085      /*layout=*/{0, 1},
1086      /*reducer=*/Reducer::kAdd},
1087 #endif
1088     {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
1089      /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
1090      /*layout=*/{0, 1},
1091      /*reducer=*/Reducer::kAdd},
1092     {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
1093      /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
1094      /*layout=*/{1, 0},
1095      /*reducer=*/Reducer::kAdd},
1096     {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
1097      /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
1098      /*layout=*/{1, 0},
1099      /*reducer=*/Reducer::kAdd},
1100     // Regression test for a bug that appeared in Inception (b/34784899).
1101     {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
1102      /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
1103      /*layout=*/{1, 0},
1104      /*reducer=*/Reducer::kAdd},
1105     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1106      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1107      /*layout=*/{1, 0},
1108      /*reducer=*/Reducer::kAdd},
1109     // Regression test for a bug that appeared in Inception (b/34784899).
1110     {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
1111      /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1112      /*layout=*/{1, 0},
1113      /*reducer=*/Reducer::kAdd},
1114     // Regression test for b/73903312: bf16 lacks precision to store result of
1115     // very large windows. Testing with a reasonable window larger than 128.
1116     {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
1117      /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
1118      /*layout=*/{1, 0},
1119      /*reducer=*/Reducer::kAdd},
1120     {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
1121      /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1122      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1123     {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
1124      /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1125      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1126     // Regression test for b/72234705: bf16 lacks precision to store incremental
1127     // results on very large windows. Using smaller window with minor dim 128.
1128     {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
1129      /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1130      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1131 };
1132 
R2ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R2ReduceWindowTestData,bool>> & data)1133 string R2ReduceWindowTestDataToString(
1134     const ::testing::TestParamInfo<
1135         ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
1136   const auto& param = ::testing::get<0>(data.param);
1137   string str = absl::StrCat(
1138       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
1139       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
1140       "__strides_", absl::StrJoin(param.strides, "x"),              //
1141       "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
1142       absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
1143       param.layout[1],  //
1144       "__reducer_", param.reducer == kAdd ? "add" : "max");
1145   if (::testing::get<1>(data.param)) {
1146     absl::StrAppend(&str, "_bfloat16");
1147   }
1148   return str;
1149 }
1150 
1151 class R2ReduceWindowTest : public ReduceWindowTestBase,
1152                            public ::testing::WithParamInterface<
1153                                ::testing::tuple<R2ReduceWindowTestData, bool>> {
1154  protected:
R2ReduceWindowTest()1155   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1156 
DoIt()1157   void DoIt() {
1158     XlaBuilder b(TestName());
1159     const auto& param = ::testing::get<0>(GetParam());
1160 
1161     const float kInitValue = 0.0f;
1162     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
1163     Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
1164         input, LayoutUtil::MakeLayout(param.layout));
1165 
1166     XlaOp parameter;
1167     auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
1168                                                        &b, &parameter);
1169     std::vector<std::pair<int64, int64>> padding(2);
1170     for (int i = 0; i < 2; ++i) {
1171       padding[i] = {param.pad_low[i], param.pad_high[i]};
1172     }
1173     auto computation = param.reducer == kAdd
1174                            ? CreateScalarAddComputation(FloatType(), &b)
1175                            : CreateScalarMaxComputation(FloatType(), &b);
1176     auto init_value =
1177         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1178     ReduceWindowWithGeneralPadding(
1179         /*operand=*/parameter,
1180         /*init_value=*/init_value,
1181         /*computation=*/computation,
1182         /*window_dimensions=*/param.window_bounds,
1183         /*window_strides=*/param.strides,
1184         /*base_dilations=*/{},
1185         /*window_dilations=*/{},
1186         /*padding=*/padding);
1187 
1188     auto reduce_func = param.reducer == kAdd
1189                            ? +[](float a, float b) { return a + b; }
1190                            : +[](float a, float b) { return std::max(a, b); };
1191     auto expected = ReferenceUtil::ReduceWindow2DGeneric(
1192         /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func,
1193         /*window=*/param.window_bounds,
1194         /*stride=*/param.strides, /*padding=*/padding);
1195 
1196     ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
1197                              {input_arg.get()}, DefaultErrorSpec());
1198   }
1199 };
1200 
// Parameterized entry point: forwards to the fixture's DoIt() for every
// instantiated parameter.
TEST_P(R2ReduceWindowTest, DoIt) {
  DoIt();
}
1202 
1203 INSTANTIATE_TEST_CASE_P(
1204     R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
1205     ::testing::Combine(::testing::ValuesIn(kR2TestCases),
1206                        ::testing::ValuesIn(use_bfloat16_params)),
1207     R2ReduceWindowTestDataToString);
1208 
1209 struct R1ReduceWindowTestData {
1210   int64 base_bounds[1];
1211   int64 window_bounds[1];
1212   int64 strides[1];
1213   int64 pad_low[1];
1214   int64 pad_high[1];
1215   Reducer reducer;
1216 } kR1TestCases[] = {
1217     {/*base_bounds=*/{1}, /*window_bounds=*/{1},
1218      /*strides=*/{1},
1219      /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
1220      /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
1221      /*reducer=*/Reducer::kAdd},
1222 
1223     {/*base_bounds=*/{3}, /*window_bounds=*/{3},
1224      /*strides=*/{1},
1225      /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
1226      /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
1227      /*reducer=*/Reducer::kAdd},
1228 
1229     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1230      /*strides=*/{1},
1231      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
1232      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
1233      /*reducer=*/Reducer::kAdd},
1234 
1235     {/*base_bounds=*/{5}, /*window_bounds=*/{1},
1236      /*strides=*/{1},
1237      /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
1238      /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
1239      /*reducer=*/Reducer::kMax},
1240 
1241     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1242      /*strides=*/{4},
1243      /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
1244      /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
1245      /*reducer=*/Reducer::kMax},
1246 
1247     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1248      /*strides=*/{3},
1249      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
1250      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
1251      /*reducer=*/Reducer::kAdd},
1252 
1253     {/*base_bounds=*/{128 * 2},
1254      /*window_bounds=*/{30},
1255      /*strides=*/{27},
1256      /*pad_low=*/
1257      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
1258      /*pad_high=*/
1259      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
1260      /*reducer=*/Reducer::kAdd},
1261 
1262     {/*base_bounds=*/{128 * 17},
1263      /*window_bounds=*/{7},
1264      /*strides=*/{64},
1265      /*pad_low=*/
1266      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
1267      /*pad_high=*/
1268      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
1269      /*reducer=*/Reducer::kAdd},
1270 
1271     {/*base_bounds=*/{128 * 2},
1272      /*window_bounds=*/{32},
1273      /*strides=*/{56},
1274      /*pad_low=*/
1275      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
1276      /*pad_high=*/
1277      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
1278      /*reducer=*/Reducer::kAdd},
1279 
1280     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1281      /*strides=*/{1},
1282      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
1283      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
1284      /*reducer=*/Reducer::kAdd},
1285 
1286     {/*base_bounds=*/{5}, /*window_bounds=*/{3},
1287      /*strides=*/{2},
1288      /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
1289      /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
1290      /*reducer=*/Reducer::kAdd},
1291 
1292     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1293      /*strides=*/{3},
1294      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
1295      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
1296      /*reducer=*/Reducer::kAdd},
1297 
1298     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1299      /*strides=*/{1},
1300      /*pad_low=*/{0},
1301      /*pad_high=*/{5},
1302      /*reducer=*/Reducer::kAdd},
1303 
1304     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1305      /*strides=*/{1},
1306      /*pad_low=*/{5},
1307      /*pad_high=*/{0},
1308      /*reducer=*/Reducer::kAdd},
1309 
1310     // The pattern generated by inclusive scan (cumsum/cumprod).
1311     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
1312      /*strides=*/{1},
1313      /*pad_low=*/{4095},
1314      /*pad_high=*/{0},
1315      /*reducer=*/Reducer::kMax},
1316 
1317     // The pattern generated by exclusive scan (cumsum/cumprod).
1318     {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
1319      /*strides=*/{1},
1320      /*pad_low=*/{4095},
1321      /*pad_high=*/{0},
1322      /*reducer=*/Reducer::kMax},
1323 };
1324 
R1ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R1ReduceWindowTestData,bool>> & data)1325 string R1ReduceWindowTestDataToString(
1326     const ::testing::TestParamInfo<
1327         ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
1328   const auto& param = ::testing::get<0>(data.param);
1329   string str =
1330       absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
1331                    "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
1332                    "__strides_", absl::StrJoin(param.strides, "x"),
1333                    "__pad_low_", absl::StrJoin(param.pad_low, "x"),
1334                    "__pad_high_", absl::StrJoin(param.pad_high, "x"),
1335                    "__reducer_", param.reducer == kAdd ? "add" : "max");
1336   if (::testing::get<1>(data.param)) {
1337     absl::StrAppend(&str, "_bfloat16");
1338   }
1339   return str;
1340 }
1341 
1342 class R1ReduceWindowTest : public ReduceWindowTestBase,
1343                            public ::testing::WithParamInterface<
1344                                ::testing::tuple<R1ReduceWindowTestData, bool>> {
1345  protected:
R1ReduceWindowTest()1346   R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1347 };
1348 
// Builds a rank-1 reduce-window over the ramp 0, 1, ..., base_bounds[0] - 1
// using the parameterized window, stride, padding and reducer (add or max),
// then compares the result with ReferenceUtil::ReduceWindow1DGeneric computed
// on the same input.
TEST_P(R1ReduceWindowTest, DoIt) {
  XlaBuilder b(TestName());
  const auto& test_case = ::testing::get<0>(GetParam());
  const bool is_add = test_case.reducer == kAdd;
  CHECK(is_add || test_case.reducer == kMax);

  // Both reducers start from zero.
  const float kInitValue = 0.0f;

  // Input values: 0, 1, 2, ...
  std::vector<float> ramp(test_case.base_bounds[0]);
  std::iota(ramp.begin(), ramp.end(), 0);
  const absl::Span<const float> ramp_span(ramp);

  Literal ramp_literal = LiteralUtil::CreateR1(ramp_span);
  XlaOp operand;
  auto operand_data =
      CreateParameterAndTransferLiteral(0, ramp_literal, "p0", &b, &operand);

  // Single (low, high) padding pair for the only dimension.
  std::vector<std::pair<int64, int64>> padding;
  padding.emplace_back(test_case.pad_low[0], test_case.pad_high[0]);

  auto xla_reducer = is_add ? CreateScalarAddComputation(FloatType(), &b)
                            : CreateScalarMaxComputation(FloatType(), &b);
  auto init_op =
      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
  ReduceWindowWithGeneralPadding(
      /*operand=*/operand,
      /*init_value=*/init_op,
      /*computation=*/xla_reducer,
      /*window_dimensions=*/test_case.window_bounds,
      /*window_strides=*/test_case.strides,
      /*base_dilations=*/{},
      /*window_dilations=*/{},
      /*padding=*/padding);

  // Host-side reducer matching the XLA computation chosen above. The unary
  // plus turns each lambda into a plain function pointer so both branches of
  // the conditional share one type.
  float (*host_reducer)(float, float) =
      is_add ? +[](float lhs, float rhs) { return lhs + rhs; }
             : +[](float lhs, float rhs) { return std::max(lhs, rhs); };
  auto reference = ReferenceUtil::ReduceWindow1DGeneric(
      /*operand=*/ramp_span,
      /*init=*/kInitValue,
      /*reduce_func=*/host_reducer,
      /*window=*/test_case.window_bounds,
      /*stride=*/test_case.strides,
      /*padding=*/padding);

  ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*reference),
                           {operand_data.get()}, DefaultErrorSpec());
}
1395 
1396 INSTANTIATE_TEST_CASE_P(
1397     R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
1398     ::testing::Combine(::testing::ValuesIn(kR1TestCases),
1399                        ::testing::ValuesIn(use_bfloat16_params)),
1400     R1ReduceWindowTestDataToString);
1401 
1402 // Test class for text-based test cases. Note that this compares with the
1403 // results on the interpreter backend.
1404 class ReduceWindowTextTest : public HloTestBase {};
1405 
// Multiply-reduces a 2x3 window (padding 0_1 x 1_1) over an f32[256,384]
// operand with layout {1,0}; results are compared with a 0.001 error spec.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
  const string module_text = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[256,384]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1422 
// Same 2x3 multiply reduce-window over f32[256,384] as R2General256x384, but
// with operand and result laid out as {0,1}.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
  const string module_text = R"(
HloModule R2Window
mul {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
operand = f32[256,384]{0,1} parameter(0)
constant = f32[] constant(1)
ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1439 
// Multiply-reduces a 2x1 window over an f32[2,5] operand with high padding of
// 2 on dimension 0, producing f32[3,5]; compared with a 0.001 error spec.
XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
  const string module_text = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[2,5]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1456 
// Applies a 1x1 multiply reduce-window (no padding) to the negation of an
// f32[1,1] operand; compared with a 0.001 error spec.
XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
  const string module_text = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[1,1]{1,0} parameter(0)
  negate = f32[1,1]{1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1474 
// Applies a 1x1x1 multiply reduce-window (no padding) to the negation of an
// f32[1,1,1] operand; compared with a 0.001 error spec.
XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
  const string module_text = R"(
HloModule R3Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R3Window {
  operand = f32[1,1,1]{2,1,0} parameter(0)
  negate = f32[1,1,1]{2,1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1492 
// Runs a 1x1x1 reduce-window whose reducer simply returns its second
// parameter, over an f32[1,32,64] operand with low padding of 1 on
// dimension 1 (result f32[1,33,64]). No ErrorSpec is passed to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
  const string module_text = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(module_text, /*error=*/absl::nullopt));
}
1509 
// s32 variant: a 1x1 reduce-window whose reducer returns its second
// parameter, over s32[81,8] with high padding of 1 on dimension 0 (result
// s32[82,8]). The init value is a module parameter. No ErrorSpec is passed.
XLA_TEST_F(HloTestBase, ReduceWindowS32) {
  const string module_text = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
  %param0 = s32[] parameter(0)
  ROOT %param1 = s32[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
  %parameter.0 = s32[81,8]{1,0} parameter(0)
  %parameter.1 = s32[] parameter(1)
  ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(module_text, /*error=*/absl::nullopt));
}
1528 
// s64 variant: a 1x1 reduce-window whose reducer returns its second
// parameter, over s64[81,8] with high padding of 1 on dimension 0 (result
// s64[82,8]). The init value is a module parameter. No ErrorSpec is passed.
XLA_TEST_F(HloTestBase, ReduceWindowS64) {
  const string module_text = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
  %param0 = s64[] parameter(0)
  ROOT %param1 = s64[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
  %parameter.0 = s64[81,8]{1,0} parameter(0)
  %parameter.1 = s64[] parameter(1)
  ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(module_text, /*error=*/absl::nullopt));
}
1547 
// f16 variant: a 1x1 reduce-window whose reducer returns its second
// parameter, over f16[81,8] with high padding of 1 on dimension 0 (result
// f16[82,8]). The init value is a module parameter. No ErrorSpec is passed.
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
  const string module_text = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
  %param0 = f16[] parameter(0)
  ROOT %param1 = f16[] parameter(1)
}

ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
  %parameter.0 = f16[81,8]{1,0} parameter(0)
  %parameter.1 = f16[] parameter(1)
  ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(module_text, /*error=*/absl::nullopt));
}
1566 
1567 }  // namespace
1568 }  // namespace xla
1569