/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests the reduce-window XLA operation.
17
#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46 namespace {
47
// Values for the "use bfloat16" test parameter. F32 (false) is always
// exercised; BF16 (true) is added only when the backend is built with
// XLA_BACKEND_SUPPORTS_BFLOAT16.
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 2> use_bfloat16_params = {false, true};
#else
static std::array<bool, 1> use_bfloat16_params = {false};
#endif
55
56 class ReduceWindowTestBase : public ClientLibraryTestBase {
57 public:
DefaultErrorSpec() const58 ErrorSpec DefaultErrorSpec() const {
59 if (use_bfloat16()) {
60 return ErrorSpec(2e-1, 6e-2);
61 } else {
62 return ErrorSpec(1e-3, 1e-3);
63 }
64 }
65 };
66
67 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
68 public ReduceWindowTestBase {
69 public:
ReduceWindowTest()70 ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
71
ReduceWindowAdd(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)72 void ReduceWindowAdd(const XlaOp& input,
73 absl::Span<const int64> window_dimensions,
74 absl::Span<const int64> window_strides,
75 Padding padding) {
76 auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
77 &builder_);
78 ReduceWindow(input, init,
79 CreateScalarAddComputation(FloatType(), &builder_),
80 window_dimensions, window_strides, padding);
81 }
82
ReduceWindowMax(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)83 void ReduceWindowMax(const XlaOp& input,
84 absl::Span<const int64> window_dimensions,
85 absl::Span<const int64> window_strides,
86 Padding padding) {
87 auto init =
88 CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
89 ReduceWindow(input, init,
90 CreateScalarMaxComputation(FloatType(), &builder_),
91 window_dimensions, window_strides, padding);
92 }
93
ReduceWindowMin(const XlaOp & input,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,Padding padding)94 void ReduceWindowMin(const XlaOp& input,
95 absl::Span<const int64> window_dimensions,
96 absl::Span<const int64> window_strides,
97 Padding padding) {
98 auto init =
99 CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
100 ReduceWindow(input, init,
101 CreateScalarMinComputation(FloatType(), &builder_),
102 window_dimensions, window_strides, padding);
103 }
104
105 XlaBuilder builder_;
106 };
107
// A window whose rank differs from the operand's rank must leave an
// INVALID_ARGUMENT error on the builder.
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
  const XlaOp rank1_operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
  const XlaOp zero =
      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
  // The builder must still be clean before the ill-formed op is added.
  TF_ASSERT_OK(builder_.first_error());

  // Rank-1 operand, rank-2 window.
  ReduceWindow(rank1_operand, zero,
               CreateScalarAddComputation(FloatType(), &builder_),
               /*window_dimensions=*/{1, 2},
               /*window_strides=*/{1}, Padding::kValid);

  const auto error = builder_.first_error();
  ASSERT_EQ(error.code(), tensorflow::error::INVALID_ARGUMENT) << error;
  ASSERT_THAT(error.error_message(),
              ::testing::HasSubstr("Want input dimensions size"));
}
123
124 // Regression test for b/68964348.
TEST_P(ReduceWindowTest,R0ReduceWindow)125 TEST_P(ReduceWindowTest, R0ReduceWindow) {
126 const auto input =
127 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
128 const auto init =
129 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
130 ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
131 /*window_dimensions=*/{},
132 /*window_strides=*/{}, Padding::kSame);
133 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
134 ErrorSpec(0.00001));
135 }
136
// Unpadded 3-wide windows at stride 2 over five elements cover indices
// [0, 2] and [2, 4].
TEST_P(ReduceWindowTest, Min3In5Stride2) {
  const XlaOp descending = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
  ReduceWindowMin(descending, /*window_dimensions=*/{3},
                  /*window_strides=*/{2}, Padding::kValid);

  const Literal expected = LiteralUtil::CreateR1<float>({100, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
144
// With kSame padding and stride 1, the output keeps all five positions; each
// holds the minimum of its 3-wide neighbourhood.
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
  const XlaOp descending = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
  ReduceWindowMin(descending, /*window_dimensions=*/{3},
                  /*window_strides=*/{1}, Padding::kSame);

  const Literal expected = LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
154
// Reduce-window over an operand whose dimension 1 has size zero.
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
  Array4D<float> input_array(1, 0, 2, 1);
  const XlaOp input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64> window = {1, 1, 2, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
167
// Sums along dimension 2 with a 2-wide window on a tiny 1x2x2x1 operand.
TEST_P(ReduceWindowTest, NonSquareSmall) {
  Array4D<float> input_array(1, 2, 2, 1);
  input_array.FillRandom(2.f, 2.f);
  const XlaOp input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64> window = {1, 1, 2, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
182
// Unit windows with stride 2 on the two middle dimensions, i.e. a strided
// subsample.
TEST_P(ReduceWindowTest, MiddleDimsSmall) {
  Array4D<float> input_array(1, 3, 3, 1);
  input_array.FillRandom(2.f, 2.f);
  const XlaOp input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64> window = {1, 1, 1, 1};
  const std::vector<int64> strides = {1, 2, 2, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
196
// Sums along dimension 2 only, with parameters that mimic a feature
// normalization such as LRN.
TEST_P(ReduceWindowTest, Along2ndMinorDim) {
  Array4D<float> input_array(3, 6, 7, 32);
  input_array.FillRandom(2.f, 2.f);
  const XlaOp input = CreateConstantFromArray(input_array, &builder_);

  // diameter = 2 * radius + 1, so it must be odd.
  const int64 lrn_diameter = 7;
  const std::vector<int64> window = {1, 1, lrn_diameter, 1};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
213
// Sums 3x3 windows (stride 1) over the two major dimensions only.
TEST_P(ReduceWindowTest, AmongMajor2Dims) {
  Array4D<float> input_array(4, 4, 6, 8);
  input_array.FillWithMinorDimNum();
  const XlaOp input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const int64 win_len = 3;
  const int64 win_stride = 1;
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
235
// Sums 3x3 windows with stride 2 over the two major dimensions of a
// medium-sized operand.
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
  Array4D<float> input_array(9, 12, 4, 89);
  input_array.FillRandom(2.f, 2.f);
  const XlaOp input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const int64 win_len = 3;
  const int64 win_stride = 2;
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
258
259 // Tests the super windowing logic w.r.t handling prime number of windows in a
260 // major dimension with reduction.
TEST_P(ReduceWindowTest,PrimeWindowsInReductionDimension)261 TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
262 Array4D<float> input_array(15, 15, 4, 128);
263 input_array.FillRandom(2.f, 4.f);
264
265 int win_len = 3;
266 int win_stride = 2;
267
268 const auto input_data_handle =
269 CreateConstantFromArray(input_array, &builder_);
270
271 Padding padding = Padding::kSame;
272 // Reduce only along the x and y dimensions, according to the win_len.
273 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
274 {win_stride, win_stride, 1, 1}, padding);
275
276 auto result = ReferenceUtil::ReduceWindow4DAdd(
277 input_array, 0.0f, {win_len, win_len, 1, 1},
278 {win_stride, win_stride, 1, 1}, padding);
279
280 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
281 DefaultErrorSpec());
282 }
283
// Sums 11-wide windows along the minor-most dimension (size 256).
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
  Array4D<float> input_array(19, 17, 8, 256);
  input_array.FillWithMinorDimNum();
  const XlaOp input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64> window = {1, 1, 1, 11};
  const std::vector<int64> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
300
301 // Tests a reduction function that is not a simple add/min/max/etc.
XLA_TEST_P(ReduceWindowTest,NonstandardReduceFunction)302 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
303 Array4D<float> input_array(1, 2, 2, 1);
304 input_array(0, 0, 0, 0) = 1;
305 input_array(0, 0, 1, 0) = 2;
306 input_array(0, 1, 0, 0) = 3;
307 input_array(0, 1, 1, 0) = 4;
308 const auto input = CreateConstantFromArray(input_array, &builder_);
309
310 Padding padding = Padding::kValid;
311 const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
312 auto b = builder_.CreateSubBuilder("unusual");
313 auto lhs = Parameter(b.get(), 0, scalar, "lhs");
314 auto rhs = Parameter(b.get(), 1, scalar, "rhs");
315 Min(Add(lhs, rhs),
316 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
317 XlaComputation reduce_fn = b->BuildAndNoteError();
318
319 ReduceWindow(
320 input,
321 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
322 reduce_fn,
323 /*window_dimensions=*/{1, 1, 2, 1},
324 /*window_strides=*/{1, 1, 1, 1}, padding);
325
326 const auto reduce_func = [](float arg1, float arg2) {
327 return std::min<float>(arg1 + arg2, 8.0f);
328 };
329
330 auto expected =
331 ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
332 /*window=*/{1, 1, 2, 1},
333 /*stride=*/{1, 1, 1, 1}, padding);
334
335 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
336 {}, DefaultErrorSpec());
337 }
338
// Operand is passed as a parameter in layout {0, 3, 2, 1}; sums 7-wide
// windows along dimension 2 with stride 4 along dimension 1.
TEST_P(ReduceWindowTest, R4UnitWindow) {
  Array4D<float> input_array(13, 12, 8, 15);
  input_array.FillRandom(2.f, 2.f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, input_literal, "parameter", &builder_, &input);

  const std::vector<int64> window = {1, 1, 7, 1};
  const std::vector<int64> strides = {1, 4, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
357
// Rank-6 sum with a 3x1x3x3x1x1 window over an all-ones 8^6 operand. Every
// output element is 3 * 3 * 3 = 27, and each windowed dim shrinks to
// 8 - 3 + 1 = 6 under kValid padding.
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
  const std::vector<int64> input_dims(6, 8);
  Literal ones(ShapeUtil::MakeShape(F32, input_dims));
  ones.PopulateWithValue(1.0f);
  const XlaOp input = CreateConstantFromLiteral(ones, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{3, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, Padding::kValid);

  // The expected literal uses the non-default layout {1, 5, 3, 2, 0, 4}.
  const std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
  const std::vector<int64> output_layout = {1, 5, 3, 2, 0, 4};
  Literal expected(
      ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout));
  expected.PopulateWithValue(27.0f);
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
377
// Rank-6 sum with a 3x3 window on dims 2 and 3 over an all-ones 8^6 operand.
// Every output element is 3 * 3 = 9, and the windowed dims shrink to
// 8 - 3 + 1 = 6 under kValid padding.
XLA_TEST_P(ReduceWindowTest, R6Add) {
  // (The former unused `shape` local was dropped; the literal below carries
  // its own shape.)
  const std::vector<int64> input_dims(6, 8);
  Literal arg_literal =
      LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);

  const Padding padding = Padding::kValid;
  ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);

  const std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
  Literal expected =
      LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);

  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
396
// Unit window with stride 8 along dimension 2 (size 27) of a parameter in
// layout {3, 2, 1, 0}.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
  Array4D<float> input_array(2, 1, 27, 119);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, input_literal, "parameter", &builder_, &input);

  const int64 win_len = 1;
  const int64 stride = 8;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
417
// 3-wide window with unit stride along dimension 2 (size 4) of a parameter in
// layout {3, 2, 1, 0}.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
  Array4D<float> input_array(3, 2, 4, 64);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, input_literal, "parameter", &builder_, &input);

  const int64 win_len = 3;
  const int64 stride = 1;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
438
// 8-wide window with stride 5 along dimension 2 (size 12) of a parameter in
// layout {3, 2, 1, 0}.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
  Array4D<float> input_array(1, 3, 12, 200);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp input;
  const auto input_data = CreateParameterAndTransferLiteral(
      0, input_literal, "parameter", &builder_, &input);

  const int64 win_len = 8;
  const int64 stride = 5;
  const std::vector<int64> window = {1, 1, win_len, 1};
  const std::vector<int64> strides = {1, 1, stride, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
459
// Sums 3x3 windows with stride 2 over the two major dimensions while the
// minor dimension (130) is larger than 128.
TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
  Array4D<float> input_array(6, 4, 10, 130);
  input_array.FillRandom(2.0f);
  const XlaOp input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const int64 win_len = 3;
  const int64 win_stride = 2;
  const std::vector<int64> window = {win_len, win_len, 1, 1};
  const std::vector<int64> strides = {win_stride, win_stride, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
480
// 1152 = 9 * 128 ones, reduced with a 32-wide window every 128 elements. The
// windows never overlap, so each of the 9 outputs sums 32 ones.
// NOTE(review): the test name says "Add24", but the window is 32 wide.
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
  constexpr int kNumWindows = 9;
  const std::vector<float> ones(128 * kNumWindows, 1.0f);
  const XlaOp input = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(ones), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{32}, /*window_strides=*/{128},
                  Padding::kValid);

  const std::vector<float> expected(kNumWindows, 32.0f);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>(expected),
                           {}, DefaultErrorSpec());
}
491
// 128 elements holding eight repetitions of 1..16. One 128-wide window
// (stride 128) covers them all: 8 * (1 + ... + 16) = 8 * 136 = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int i = 0; i < 128; ++i) {
    input_vector.push_back(static_cast<float>(i % 16 + 1));
  }
  const XlaOp input = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{128},
                  /*window_strides=*/{128}, Padding::kValid);

  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                           DefaultErrorSpec());
}
508
// 128 elements holding eight repetitions of 1..16. A 128-wide window with
// stride 1 fits exactly once, giving 8 * 136 = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128) {
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int i = 0; i < 128; ++i) {
    input_vector.push_back(static_cast<float>(i % 16 + 1));
  }
  const XlaOp input = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{128}, /*window_strides=*/{1},
                  Padding::kValid);

  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                           DefaultErrorSpec());
}
525
526 // Regression test for a bug that appeared in Inception (b/34784899).
TEST_P(ReduceWindowTest,R2ReduceWindowInceptionFromBroadcast)527 TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
528 Array2D<float> input_array(14, 14, 1.0f);
529 const auto input = CreateConstantFromArray(input_array, &builder_);
530
531 int win_len = 3;
532 int stride = 1;
533 Padding padding = Padding::kSame;
534 ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
535
536 auto res = ReferenceUtil::ReduceWindow2DAdd(
537 input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
538
539 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
540 {}, DefaultErrorSpec());
541 }
542
// The operand is produced by broadcasting the scalar 1 to 6x4. It is reduced
// with 4x2 windows at stride 3x3 and kSame padding.
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
  const XlaOp one = CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_);
  const XlaOp input = Broadcast(one, {6, 4});

  const std::vector<int64> window = {4, 2};
  const std::vector<int64> strides = {3, 3};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  // Host-side copy of what the broadcast produces.
  Array2D<float> input_array(6, 4, 1.0f);
  const auto expected = ReferenceUtil::ReduceWindow2DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_,
                           LiteralUtil::CreateFromArray<float>(*expected), {},
                           DefaultErrorSpec());
}
557
558 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
559 ::testing::ValuesIn(use_bfloat16_params));
560
561 enum Reducer { kAdd, kMax };
562
563 struct R4ReduceWindowTestData {
564 int64 base_bounds[4];
565 int64 window_bounds[4];
566 int64 strides[4];
567 int64 pad_low[4];
568 int64 pad_high[4];
569 int64 layout[4];
570
571 Reducer reducer;
572 };
573
R4ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R4ReduceWindowTestData,bool>> & data)574 string R4ReduceWindowTestDataToString(
575 const ::testing::TestParamInfo<
576 ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
577 const auto& param = ::testing::get<0>(data.param);
578 string str = absl::StrCat(
579 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
580 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
581 "__strides_", absl::StrJoin(param.strides, "x"), //
582 "__pad_low_", absl::StrJoin(param.pad_low, "x"), //
583 "__pad_high_", absl::StrJoin(param.pad_high, "x"), //
584 "__layout_", absl::StrJoin(param.layout, "_"), //
585 (param.reducer == kAdd) ? "_add" : "_max");
586 CHECK(param.reducer == kAdd || param.reducer == kMax);
587
588 // Test names are not allowed to contain the '-' character.
589 std::replace(str.begin(), str.end(), '-', 'n');
590 if (::testing::get<1>(data.param)) {
591 absl::StrAppend(&str, "_bfloat16");
592 }
593 return str;
594 }
595
596 class R4ReduceWindowTest : public ReduceWindowTestBase,
597 public ::testing::WithParamInterface<
598 ::testing::tuple<R4ReduceWindowTestData, bool>> {
599 protected:
R4ReduceWindowTest()600 R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
601
DoIt()602 void DoIt() {
603 XlaBuilder b(TestName());
604 const auto& param = ::testing::get<0>(GetParam());
605
606 const float kInitValue = 0.0f;
607
608 Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
609 param.base_bounds[2], param.base_bounds[3]);
610 // Choose a prime iota length so that each window sees a unique set of
611 // values. (Technically, the requirement is that the iota length is
612 // relatively prime to all of the dimensions involved in the reduce-window.)
613 input.FillRepeatedIota(0, 137);
614 // Floating point sum reduction requires higher localized precision. We need
615 // the following normalization in order to enable testing of kAdd on large
616 // windows.
617 input.Each([&](absl::Span<const int64> /*indices*/, float* value) {
618 *value = *value / 10000000000.f;
619 });
620 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
621 input, LayoutUtil::MakeLayout(param.layout));
622 XlaOp parameter;
623 auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
624 &b, ¶meter);
625
626 std::vector<std::pair<int64, int64>> padding(4);
627 for (int i = 0; i < 4; ++i) {
628 padding[i] = {param.pad_low[i], param.pad_high[i]};
629 }
630
631 auto init_value =
632 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
633 CHECK(param.reducer == kAdd || param.reducer == kMax);
634 auto reducer = param.reducer;
635 auto computation = reducer == kAdd
636 ? CreateScalarAddComputation(FloatType(), &b)
637 : CreateScalarMaxComputation(FloatType(), &b);
638 ReduceWindowWithGeneralPadding(
639 /*operand=*/parameter,
640 /*init_value=*/init_value,
641 /*computation=*/computation,
642 /*window_dimensions=*/param.window_bounds,
643 /*window_strides=*/param.strides,
644 /*base_dilations=*/{},
645 /*window_dilations=*/{},
646 /*padding=*/padding);
647
648 CHECK(reducer == kAdd || reducer == kMax);
649 auto reduce_func = reducer == kAdd
650 ? +[](float a, float b) { return a + b; }
651 : +[](float a, float b) { return std::max(a, b); };
652 std::unique_ptr<Array4D<float>> expected =
653 ReferenceUtil::ReduceWindow4DGeneric(
654 /*operand=*/input,
655 /*init=*/kInitValue,
656 /*reduce_func=*/reduce_func,
657 /*window=*/param.window_bounds,
658 /*stride=*/param.strides,
659 /*padding=*/padding);
660 Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
661 const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
662 input_literal.shape().element_type(),
663 AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
664 ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
665 DefaultErrorSpec(), &expected_shape_with_layout);
666 }
667 };
668
// Every (test data, bfloat16) combination runs the shared DoIt() driver.
TEST_P(R4ReduceWindowTest, DoIt) {
  DoIt();
}
670
671 // base_bounds, window_bounds, strides, pad_low, pad_high
672 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
673 // Minimal edge case.
674 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
675 /*window_bounds=*/{1, 1, 1, 1},
676 /*strides=*/{1, 1, 1, 1},
677 /*pad_low=*/{0, 0, 0, 0},
678 /*pad_high=*/{0, 0, 0, 0},
679 /*layout=*/{3, 2, 1, 0},
680 /*reducer=*/kAdd},
681
682 // Arbitrary padding (not kSame or kValid).
683 R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
684 /*window_bounds=*/{3, 3, 1, 1},
685 /*strides=*/{2, 2, 1, 1},
686 /*pad_low=*/{4, 4, 0, 0},
687 /*pad_high=*/{4, 4, 0, 0},
688 /*layout=*/{3, 2, 1, 0},
689 /*reducer=*/kAdd},
690
691 // Zero base bound edge case.
692 R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
693 /*window_bounds=*/{1, 1, 1, 1},
694 /*strides=*/{1, 1, 1, 1},
695 /*pad_low=*/{0, 0, 0, 0},
696 /*pad_high=*/{0, 0, 0, 0},
697 /*layout=*/{3, 2, 1, 0},
698 /*reducer=*/kAdd},
699
700 // With max instead of add.
701 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
702 /*window_bounds=*/{2, 3, 1, 1},
703 /*strides=*/{1, 1, 1, 1},
704 /*pad_low=*/{0, 0, 0, 0},
705 /*pad_high=*/{0, 0, 0, 0},
706 /*layout=*/{3, 2, 1, 0},
707 /*reducer=*/kMax},
708
709 // With stride.
710 R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
711 /*window_bounds=*/{3, 2, 1, 1},
712 /*strides=*/{2, 4, 1, 1},
713 /*pad_low=*/{0, 0, 0, 0},
714 /*pad_high=*/{0, 0, 0, 0},
715 /*layout=*/{3, 2, 1, 0},
716 /*reducer=*/kAdd},
717
718 // With low padding.
719 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
720 /*window_bounds=*/{3, 2, 1, 1},
721 /*strides=*/{2, 2, 1, 1},
722 /*pad_low=*/{3, 2, 0, 0},
723 /*pad_high=*/{0, 0, 0, 0},
724 /*layout=*/{3, 2, 1, 0},
725 /*reducer=*/kAdd},
726
727 // With high padding.
728 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
729 /*window_bounds=*/{3, 2, 1, 1},
730 /*strides=*/{2, 2, 1, 1},
731 /*pad_low=*/{0, 0, 0, 0},
732 /*pad_high=*/{2, 3, 0, 0},
733 /*layout=*/{3, 2, 1, 0},
734 /*reducer=*/kAdd},
735
736 // Window touches both sides of the padding simultaneously.
737 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
738 /*window_bounds=*/{3, 3, 1, 1},
739 /*strides=*/{1, 1, 1, 1},
740 /*pad_low=*/{1, 1, 0, 0},
741 /*pad_high=*/{1, 1, 0, 0},
742 /*layout=*/{3, 2, 1, 0},
743 /*reducer=*/kAdd},
744
745 // Window is entirely in the padding for some positions.
746 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
747 /*window_bounds=*/{3, 3, 1, 1},
748 /*strides=*/{1, 1, 1, 1},
749 /*pad_low=*/{4, 4, 0, 0},
750 /*pad_high=*/{4, 4, 0, 0},
751 /*layout=*/{3, 2, 1, 0},
752 /*reducer=*/kAdd},
753
754 // Zero base bound with padding edge case.
755 R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
756 /*window_bounds=*/{1, 1, 1, 1},
757 /*strides=*/{1, 1, 1, 1},
758 /*pad_low=*/{0, 1, 0, 0},
759 /*pad_high=*/{0, 0, 0, 0},
760 /*layout=*/{3, 2, 1, 0},
761 /*reducer=*/kAdd},
762
763 // With stride, low padding and high padding.
764 R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
765 /*window_bounds=*/{3, 4, 1, 1},
766 /*strides=*/{3, 1, 1, 1},
767 /*pad_low=*/{10, 1, 0, 0},
768 /*pad_high=*/{2, 3, 0, 0},
769 /*layout=*/{3, 2, 1, 0},
770 /*reducer=*/kAdd},
771
772 // With minor dimension == 129.
773 R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
774 /*window_bounds=*/{1, 1, 1, 1},
775 /*strides=*/{1, 1, 1, 1},
776 /*pad_low=*/{0, 0, 0, 0},
777 /*pad_high=*/{0, 0, 0, 0},
778 /*layout=*/{3, 2, 1, 0},
779 /*reducer=*/kAdd},
780
781 // With minor dims reduction and non-overlapped stride.
782 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
783 /*window_bounds=*/{1, 1, 2, 2},
784 /*strides=*/{1, 1, 2, 2},
785 /*pad_low=*/{0, 0, 0, 0},
786 /*pad_high=*/{0, 0, 0, 0},
787 /*layout=*/{3, 2, 1, 0},
788 /*reducer=*/kAdd},
789
790 // With minor dims reduction and overlapped stride.
791 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
792 /*window_bounds=*/{1, 1, 4, 4},
793 /*strides=*/{1, 1, 2, 2},
794 /*pad_low=*/{0, 0, 0, 0},
795 /*pad_high=*/{1, 0, 0, 0},
796 /*layout=*/{3, 2, 1, 0},
797 /*reducer=*/kAdd},
798
799 R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
800 /*window_bounds=*/{1, 64, 64, 1},
801 /*strides=*/{1, 64, 64, 1},
802 /*pad_low=*/{0, 0, 0, 0},
803 /*pad_high=*/{0, 0, 0, 0},
804 /*layout=*/{3, 0, 2, 1},
805 /*reducer=*/kAdd},
806
807 R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
808 /*window_bounds=*/{112, 112, 1, 8},
809 /*strides=*/{112, 112, 1, 8},
810 /*pad_low=*/{0, 0, 0, 0},
811 /*pad_high=*/{0, 0, 0, 0},
812 /*layout=*/{3, 2, 1, 0},
813 /*reducer=*/kMax},
814
815 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
816 /*window_bounds=*/{2, 3, 4, 5},
817 /*strides=*/{1, 1, 1, 1},
818 /*pad_low=*/{0, 0, 0, 0},
819 /*pad_high=*/{0, 0, 0, 0},
820 /*layout=*/{3, 2, 1, 0},
821 /*reducer=*/kAdd},
822
823 // With 0321 layout.
824 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
825 /*window_bounds=*/{2, 3, 4, 5},
826 /*strides=*/{1, 2, 3, 4},
827 /*pad_low=*/{0, 0, 0, 0},
828 /*pad_high=*/{0, 0, 0, 0},
829 /*layout=*/{0, 3, 2, 1},
830 /*reducer=*/kAdd},
831
832 // With 0123 layout.
833 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
834 /*window_bounds=*/{2, 3, 7, 9},
835 /*strides=*/{1, 2, 5, 8},
836 /*pad_low=*/{0, 0, 0, 0},
837 /*pad_high=*/{0, 0, 0, 0},
838 /*layout=*/{0, 1, 2, 3},
839 /*reducer=*/kAdd},
840 };
841
// Instantiates R4ReduceWindowTest once per entry of kR4ReduceWindowTestValues,
// crossed with every value in use_bfloat16_params. Instance names come from
// R4ReduceWindowTestDataToString.
INSTANTIATE_TEST_CASE_P(
    R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R4ReduceWindowTestDataToString);
847
// Separate fixture for the large/slow R4 cases. It reuses R4ReduceWindowTest's
// DoIt() but gets its own instantiation (kR4ReduceWindowLargeTestValues), so it
// can be disabled on the interpreter backend without touching the regular set.
class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};

XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
851
// Test cases that are large, slow, or have failed on some backend. They run via
// R4ReduceWindowLargeTest, which is disabled on the interpreter backend.
const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
    R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
                           /*window_bounds=*/{3, 3, 1, 5},
                           /*strides=*/{1, 1, 1, 5},
                           /*pad_low=*/{1, 1, 0, 0},
                           /*pad_high=*/{1, 1, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kMax},

    R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
                           /*window_bounds=*/{3, 3, 1, 1},
                           /*strides=*/{2, 2, 1, 1},
                           /*pad_low=*/{0, 0, 0, 0},
                           /*pad_high=*/{1, 1, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    // Dim 2 padded to 32768 (= 1 + (32768 - 3) + 2) and reduced in
    // non-overlapping windows of 4.
    R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
                           /*window_bounds=*/{1, 1, 4, 1},
                           /*strides=*/{1, 1, 4, 1},
                           /*pad_low=*/{0, 0, 1, 0},
                           /*pad_high=*/{0, 0, 2, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kMax},

    // Patterns generated by cumsum/cumprod: a window as long as the scanned
    // dimension, with pad_low of window - 1 (first three cases) or equal to
    // the window (last three cases).
    R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
                           /*window_bounds=*/{1021, 1, 1, 1},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{1020, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    // NOTE(review): base dim 2 is 16 while the window on dim 2 is 1021; the
    // sibling cases window a dimension of size 1021 (cf. {16, 1, 1021, 16}
    // below). Still a valid shape (output dim 2 = 16); confirm it is intended.
    R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
                           /*window_bounds=*/{1, 1, 1021, 1},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 1020, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
                           /*window_bounds=*/{1, 1, 1, 1021},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 0, 1020},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
                           /*window_bounds=*/{1021, 1, 1, 1},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{1021, 0, 0, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
                           /*window_bounds=*/{1, 1, 1021, 1},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 1021, 0},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},

    R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
                           /*window_bounds=*/{1, 1, 1, 1021},
                           /*strides=*/{1, 1, 1, 1},
                           /*pad_low=*/{0, 0, 0, 1021},
                           /*pad_high=*/{0, 0, 0, 0},
                           /*layout=*/{3, 2, 1, 0},
                           /*reducer=*/kAdd},
};
927
// Instantiates the large R4 cases, crossed with every value in
// use_bfloat16_params, reusing the R4 instance-name generator.
INSTANTIATE_TEST_CASE_P(
    R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
    ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R4ReduceWindowTestDataToString);
933
// One rank-3 reduce-window configuration. Unlike the R1/R2/R4 data, padding is
// expressed as a Padding mode (kValid/kSame) rather than explicit low/high
// amounts.
struct R3ReduceWindowTestData {
  int64 base_bounds[3];    // Operand dimensions.
  int64 window_bounds[3];  // Window size per dimension.
  int64 strides[3];        // Window stride per dimension.
  int64 layout[3];         // Operand layout, passed to LayoutUtil::MakeLayout.
  Padding padding;         // Padding mode passed to ReduceWindow.
  Reducer reducer;         // kAdd or kMax; DoIt forces kMax for bf16 runs.
} kR3TestCases[] = {
    {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
     /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
    // Same shape as above with non-default layouts.
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
     /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    // Window equal to the whole operand: reduces everything to one element.
    {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
    // Very long windows on the major dimension.
    {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
985
R3ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R3ReduceWindowTestData,bool>> & data)986 string R3ReduceWindowTestDataToString(
987 const ::testing::TestParamInfo<
988 ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
989 const auto& param = ::testing::get<0>(data.param);
990 string str = absl::StrCat(
991 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
992 absl::StrJoin(param.window_bounds, "x"), "__strides_",
993 absl::StrJoin(param.strides, "x"), "__padding_",
994 param.padding == Padding::kSame ? "same" : "valid", "__layout_",
995 param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
996 param.reducer == kAdd ? "add" : "max");
997 if (::testing::get<1>(data.param)) {
998 absl::StrAppend(&str, "_bfloat16");
999 }
1000 return str;
1001 }
1002
1003 class R3ReduceWindowTest : public ReduceWindowTestBase,
1004 public ::testing::WithParamInterface<
1005 ::testing::tuple<R3ReduceWindowTestData, bool>> {
1006 protected:
R3ReduceWindowTest()1007 R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1008 };
1009
TEST_P(R3ReduceWindowTest,DoIt)1010 TEST_P(R3ReduceWindowTest, DoIt) {
1011 XlaBuilder b(TestName());
1012 const auto& param = ::testing::get<0>(GetParam());
1013
1014 const float kInitValue = 0.0f;
1015 Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
1016 param.base_bounds[2]);
1017 // Choose a prime iota length so that each window sees a unique set of values.
1018 // (Technically, the requirement is that the iota length is relatively prime
1019 // to all of the dimensions involved in the reduce-window.)
1020 input.FillRepeatedIota(0, 137);
1021 Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
1022 input, LayoutUtil::MakeLayout(param.layout));
1023 auto reducer = param.reducer;
1024 if (use_bfloat16()) {
1025 input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
1026
1027 // To avoid numerical issues, force the reducer to be kMax for bf16
1028 // inputs.
1029 reducer = kMax;
1030 }
1031
1032 XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
1033 auto init_value =
1034 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1035
1036 auto computation = reducer == kAdd
1037 ? CreateScalarAddComputation(FloatType(), &b)
1038 : CreateScalarMaxComputation(FloatType(), &b);
1039
1040 ReduceWindow(/*operand=*/parameter,
1041 /*init_value=*/init_value,
1042 /*computation=*/computation,
1043 /*window_dimensions=*/param.window_bounds,
1044 /*window_strides=*/param.strides, /*padding=*/param.padding);
1045
1046 ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
1047 }
1048
// Instantiates R3ReduceWindowTest for every entry of kR3TestCases, crossed with
// every value in use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR3TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R3ReduceWindowTestDataToString);
1054
// One rank-2 reduce-window configuration with explicit per-dimension padding.
struct R2ReduceWindowTestData {
  int64 base_bounds[2];    // Operand dimensions.
  int64 window_bounds[2];  // Window size per dimension.
  int64 strides[2];        // Window stride per dimension.
  int64 pad_low[2];        // Padding before the first element, per dimension.
  int64 pad_high[2];       // Padding after the last element, per dimension.
  int64 layout[2];         // Operand layout, passed to LayoutUtil::MakeLayout.
  Reducer reducer;         // kAdd or kMax.
} kR2TestCases[] = {
    {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
     /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
     /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
// TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
// ptxas bug.
#ifndef XLA_TEST_BACKEND_GPU
    {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
     /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
#endif
    {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
     /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
     /*layout=*/{0, 1},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
     /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    // Regression test for a bug that appeared in Inception (b/34784899).
    {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
     /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    // Regression test for a bug that appeared in Inception (b/34784899).
    {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
     /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    // Regression test for b/73903312: bf16 lacks the precision to store the
    // result of very large windows. Tests a reasonable window larger than 128.
    {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0},
     /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
     /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
    {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
     /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
    // Regression test for b/72234705: bf16 lacks the precision to store
    // incremental results on very large windows. Uses a smaller window with
    // minor dim 128.
    {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
     /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
     /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
};
1132
R2ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R2ReduceWindowTestData,bool>> & data)1133 string R2ReduceWindowTestDataToString(
1134 const ::testing::TestParamInfo<
1135 ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
1136 const auto& param = ::testing::get<0>(data.param);
1137 string str = absl::StrCat(
1138 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
1139 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
1140 "__strides_", absl::StrJoin(param.strides, "x"), //
1141 "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
1142 absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
1143 param.layout[1], //
1144 "__reducer_", param.reducer == kAdd ? "add" : "max");
1145 if (::testing::get<1>(data.param)) {
1146 absl::StrAppend(&str, "_bfloat16");
1147 }
1148 return str;
1149 }
1150
1151 class R2ReduceWindowTest : public ReduceWindowTestBase,
1152 public ::testing::WithParamInterface<
1153 ::testing::tuple<R2ReduceWindowTestData, bool>> {
1154 protected:
R2ReduceWindowTest()1155 R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1156
DoIt()1157 void DoIt() {
1158 XlaBuilder b(TestName());
1159 const auto& param = ::testing::get<0>(GetParam());
1160
1161 const float kInitValue = 0.0f;
1162 Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
1163 Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
1164 input, LayoutUtil::MakeLayout(param.layout));
1165
1166 XlaOp parameter;
1167 auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
1168 &b, ¶meter);
1169 std::vector<std::pair<int64, int64>> padding(2);
1170 for (int i = 0; i < 2; ++i) {
1171 padding[i] = {param.pad_low[i], param.pad_high[i]};
1172 }
1173 auto computation = param.reducer == kAdd
1174 ? CreateScalarAddComputation(FloatType(), &b)
1175 : CreateScalarMaxComputation(FloatType(), &b);
1176 auto init_value =
1177 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1178 ReduceWindowWithGeneralPadding(
1179 /*operand=*/parameter,
1180 /*init_value=*/init_value,
1181 /*computation=*/computation,
1182 /*window_dimensions=*/param.window_bounds,
1183 /*window_strides=*/param.strides,
1184 /*base_dilations=*/{},
1185 /*window_dilations=*/{},
1186 /*padding=*/padding);
1187
1188 auto reduce_func = param.reducer == kAdd
1189 ? +[](float a, float b) { return a + b; }
1190 : +[](float a, float b) { return std::max(a, b); };
1191 auto expected = ReferenceUtil::ReduceWindow2DGeneric(
1192 /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func,
1193 /*window=*/param.window_bounds,
1194 /*stride=*/param.strides, /*padding=*/padding);
1195
1196 ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
1197 {input_arg.get()}, DefaultErrorSpec());
1198 }
1199 };
1200
TEST_P(R2ReduceWindowTest,DoIt)1201 TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
1202
1203 INSTANTIATE_TEST_CASE_P(
1204 R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
1205 ::testing::Combine(::testing::ValuesIn(kR2TestCases),
1206 ::testing::ValuesIn(use_bfloat16_params)),
1207 R2ReduceWindowTestDataToString);
1208
// One rank-1 reduce-window configuration with explicit padding. Many entries
// derive their padding from xla::MakePadding for a kValid/kSame mode so the
// general-padding path is checked against the mode-based helper.
struct R1ReduceWindowTestData {
  int64 base_bounds[1];    // Operand length.
  int64 window_bounds[1];  // Window size.
  int64 strides[1];        // Window stride.
  int64 pad_low[1];        // Padding before the first element.
  int64 pad_high[1];       // Padding after the last element.
  Reducer reducer;         // kAdd or kMax (DoIt CHECKs this).
} kR1TestCases[] = {
    {/*base_bounds=*/{1}, /*window_bounds=*/{1},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{3},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{2},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{5}, /*window_bounds=*/{1},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kMax},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{4},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kMax},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{3},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{128 * 2},
     /*window_bounds=*/{30},
     /*strides=*/{27},
     /*pad_low=*/
     {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{128 * 17},
     /*window_bounds=*/{7},
     /*strides=*/{64},
     /*pad_low=*/
     {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{128 * 2},
     /*window_bounds=*/{32},
     /*strides=*/{56},
     /*pad_low=*/
     {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
     /*pad_high=*/
     {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{3}, /*window_bounds=*/{2},
     /*strides=*/{1},
     /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{5}, /*window_bounds=*/{3},
     /*strides=*/{2},
     /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    {/*base_bounds=*/{16}, /*window_bounds=*/{4},
     /*strides=*/{3},
     /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
     /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
     /*reducer=*/Reducer::kAdd},

    // Padding as large as the whole window, on the high side only.
    {/*base_bounds=*/{5}, /*window_bounds=*/{5},
     /*strides=*/{1},
     /*pad_low=*/{0},
     /*pad_high=*/{5},
     /*reducer=*/Reducer::kAdd},

    // Padding as large as the whole window, on the low side only.
    {/*base_bounds=*/{5}, /*window_bounds=*/{5},
     /*strides=*/{1},
     /*pad_low=*/{5},
     /*pad_high=*/{0},
     /*reducer=*/Reducer::kAdd},

    // The pattern generated by inclusive scan (cumsum/cumprod).
    {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
     /*strides=*/{1},
     /*pad_low=*/{4095},
     /*pad_high=*/{0},
     /*reducer=*/Reducer::kMax},

    // The pattern generated by exclusive scan (cumsum/cumprod).
    {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
     /*strides=*/{1},
     /*pad_low=*/{4095},
     /*pad_high=*/{0},
     /*reducer=*/Reducer::kMax},
};
1324
R1ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R1ReduceWindowTestData,bool>> & data)1325 string R1ReduceWindowTestDataToString(
1326 const ::testing::TestParamInfo<
1327 ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
1328 const auto& param = ::testing::get<0>(data.param);
1329 string str =
1330 absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
1331 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
1332 "__strides_", absl::StrJoin(param.strides, "x"),
1333 "__pad_low_", absl::StrJoin(param.pad_low, "x"),
1334 "__pad_high_", absl::StrJoin(param.pad_high, "x"),
1335 "__reducer_", param.reducer == kAdd ? "add" : "max");
1336 if (::testing::get<1>(data.param)) {
1337 absl::StrAppend(&str, "_bfloat16");
1338 }
1339 return str;
1340 }
1341
1342 class R1ReduceWindowTest : public ReduceWindowTestBase,
1343 public ::testing::WithParamInterface<
1344 ::testing::tuple<R1ReduceWindowTestData, bool>> {
1345 protected:
R1ReduceWindowTest()1346 R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1347 };
1348
TEST_P(R1ReduceWindowTest,DoIt)1349 TEST_P(R1ReduceWindowTest, DoIt) {
1350 XlaBuilder b(TestName());
1351 const auto& param = ::testing::get<0>(GetParam());
1352 CHECK(param.reducer == kAdd || param.reducer == kMax);
1353
1354 const float kInitValue = 0.0f;
1355 std::vector<float> input_vector(param.base_bounds[0]);
1356 std::iota(std::begin(input_vector), std::end(input_vector), 0);
1357 Literal input_literal =
1358 LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
1359 XlaOp parameter;
1360 auto input_arg =
1361 CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter);
1362
1363 std::vector<std::pair<int64, int64>> padding(1);
1364 padding[0] = {param.pad_low[0], param.pad_high[0]};
1365
1366 auto computation = param.reducer == kAdd
1367 ? CreateScalarAddComputation(FloatType(), &b)
1368 : CreateScalarMaxComputation(FloatType(), &b);
1369 auto init_value =
1370 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1371 ReduceWindowWithGeneralPadding(
1372 /*operand=*/parameter,
1373 /*init_value=*/init_value,
1374 /*computation=*/computation,
1375 /*window_dimensions=*/param.window_bounds,
1376 /*window_strides=*/param.strides,
1377 /*base_dilations=*/{},
1378 /*window_dilations=*/{},
1379 /*padding=*/padding);
1380
1381 auto reduce_func = param.reducer == kAdd
1382 ? +[](float a, float b) { return a + b; }
1383 : +[](float a, float b) { return std::max(a, b); };
1384 auto expected = ReferenceUtil::ReduceWindow1DGeneric(
1385 /*operand=*/absl::Span<const float>(input_vector),
1386 /*init=*/kInitValue,
1387 /*reduce_func=*/reduce_func,
1388 /*window=*/param.window_bounds,
1389 /*stride=*/param.strides,
1390 /*padding=*/padding);
1391
1392 ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
1393 {input_arg.get()}, DefaultErrorSpec());
1394 }
1395
// Instantiates R1ReduceWindowTest for every entry of kR1TestCases, crossed with
// every value in use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
    ::testing::Combine(::testing::ValuesIn(kR1TestCases),
                       ::testing::ValuesIn(use_bfloat16_params)),
    R1ReduceWindowTestDataToString);
1401
// Fixture for test cases written as HLO text. Note that RunAndCompare compares
// the results against those of the interpreter backend.
class ReduceWindowTextTest : public HloTestBase {};
1405
XLA_TEST_F(ReduceWindowTextTest,R2General256x384)1406 XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
1407 const string hlo_string = R"(
1408 HloModule R2Window
1409 mul {
1410 lhs = f32[] parameter(0)
1411 rhs = f32[] parameter(1)
1412 ROOT mul = f32[] multiply(lhs, rhs)
1413 }
1414 ENTRY R2Window {
1415 operand = f32[256,384]{1,0} parameter(0)
1416 constant = f32[] constant(1)
1417 ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
1418 }
1419 )";
1420 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1421 }
1422
XLA_TEST_F(ReduceWindowTextTest,R2General256x384Layout01)1423 XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
1424 const string hlo_string = R"(
1425 HloModule R2Window
1426 mul {
1427 lhs = f32[] parameter(0)
1428 rhs = f32[] parameter(1)
1429 ROOT mul = f32[] multiply(lhs, rhs)
1430 }
1431 ENTRY R2Window {
1432 operand = f32[256,384]{0,1} parameter(0)
1433 constant = f32[] constant(1)
1434 ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
1435 }
1436 )";
1437 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1438 }
1439
XLA_TEST_F(ReduceWindowTextTest,R2General2x5)1440 XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
1441 const string hlo_string = R"(
1442 HloModule R2Window
1443 mul {
1444 lhs = f32[] parameter(0)
1445 rhs = f32[] parameter(1)
1446 ROOT mul = f32[] multiply(lhs, rhs)
1447 }
1448 ENTRY R2Window {
1449 operand = f32[2,5]{1,0} parameter(0)
1450 constant = f32[] constant(1)
1451 ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
1452 }
1453 )";
1454 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1455 }
1456
XLA_TEST_F(ReduceWindowTextTest,R2EffectiveScalar)1457 XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
1458 const string hlo_string = R"(
1459 HloModule R2Window
1460 mul {
1461 lhs = f32[] parameter(0)
1462 rhs = f32[] parameter(1)
1463 ROOT mul = f32[] multiply(lhs, rhs)
1464 }
1465 ENTRY R2Window {
1466 operand = f32[1,1]{1,0} parameter(0)
1467 negate = f32[1,1]{1,0} negate(operand)
1468 constant = f32[] constant(1)
1469 ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
1470 }
1471 )";
1472 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1473 }
1474
XLA_TEST_F(ReduceWindowTextTest,R3EffectiveScalar)1475 XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
1476 const string hlo_string = R"(
1477 HloModule R3Window
1478 mul {
1479 lhs = f32[] parameter(0)
1480 rhs = f32[] parameter(1)
1481 ROOT mul = f32[] multiply(lhs, rhs)
1482 }
1483 ENTRY R3Window {
1484 operand = f32[1,1,1]{2,1,0} parameter(0)
1485 negate = f32[1,1,1]{2,1,0} negate(operand)
1486 constant = f32[] constant(1)
1487 ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul
1488 }
1489 )";
1490 EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
1491 }
1492
XLA_TEST_F(HloTestBase,ReduceWindowIdentity)1493 XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
1494 const string hlo_string = R"(
1495 HloModule ReduceWindowIdentity
1496 identity.pad_to_reduce_window {
1497 param0 = f32[] parameter(0)
1498 ROOT param1 = f32[] parameter(1)
1499 }
1500 ENTRY reduce-window-identity {
1501 operand = f32[1,32,64]{2,1,0} parameter(0)
1502 constant.4466 = f32[] constant(0)
1503 ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
1504 }
1505
1506 )";
1507 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
1508 }
1509
XLA_TEST_F(HloTestBase,ReduceWindowS32)1510 XLA_TEST_F(HloTestBase, ReduceWindowS32) {
1511 const string hlo_string = R"(
1512 HloModule reduce-window
1513
1514 %identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
1515 %param0 = s32[] parameter(0)
1516 ROOT %param1 = s32[] parameter(1)
1517 }
1518
1519 ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
1520 %parameter.0 = s32[81,8]{1,0} parameter(0)
1521 %parameter.1 = s32[] parameter(1)
1522 ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
1523 }
1524
1525 )";
1526 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
1527 }
1528
XLA_TEST_F(HloTestBase,ReduceWindowS64)1529 XLA_TEST_F(HloTestBase, ReduceWindowS64) {
1530 const string hlo_string = R"(
1531 HloModule reduce-window
1532
1533 %identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
1534 %param0 = s64[] parameter(0)
1535 ROOT %param1 = s64[] parameter(1)
1536 }
1537
1538 ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
1539 %parameter.0 = s64[81,8]{1,0} parameter(0)
1540 %parameter.1 = s64[] parameter(1)
1541 ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
1542 }
1543
1544 )";
1545 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
1546 }
1547
XLA_TEST_F(HloTestBase,ReduceWindowF16)1548 XLA_TEST_F(HloTestBase, ReduceWindowF16) {
1549 const string hlo_string = R"(
1550 HloModule reduce-window
1551
1552 %identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
1553 %param0 = f16[] parameter(0)
1554 ROOT %param1 = f16[] parameter(1)
1555 }
1556
1557 ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
1558 %parameter.0 = f16[81,8]{1,0} parameter(0)
1559 %parameter.1 = f16[] parameter(1)
1560 ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
1561 }
1562
1563 )";
1564 EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
1565 }
1566
1567 } // namespace
1568 } // namespace xla
1569