• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests the select-and-scatter XLA operation.
17 
18 // b/194424657: On macs, the compiler hangs when trying to compile this file
19 #if !defined(__APPLE__)
20 
21 #include <memory>
22 #include <vector>
23 
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/local_client.h"
27 #include "tensorflow/compiler/xla/client/padding.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/reference_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/platform/test.h"
39 
40 namespace xla {
41 namespace {
42 
43 struct SelectAndScatterTestParam {
44   std::vector<int64_t> operand_shape;
45   std::vector<int64_t> source_shape;
46   Padding padding_type;
47   std::vector<int64_t> window_dimensions;
48   std::vector<int64_t> window_strides;
49 };
50 
51 class SelectAndScatterTest
52     : public ClientLibraryTestBase,
53       public ::testing::WithParamInterface<SelectAndScatterTestParam> {
54  public:
SelectAndScatterTest()55   SelectAndScatterTest() : builder_(TestName()) {
56     // Create S32 GE and ADD computations for select and scatter respectively.
57     ge_s32_ = CreateScalarGeComputation(S32, &builder_);
58     add_s32_ = CreateScalarAddComputation(S32, &builder_);
59     ge_f32_ = CreateScalarGeComputation(F32, &builder_);
60     add_f32_ = CreateScalarAddComputation(F32, &builder_);
61     max_f32_ = CreateScalarMaxComputation(F32, &builder_);
62     min_f32_ = CreateScalarMinComputation(F32, &builder_);
63   }
64 
65   XlaBuilder builder_;
66   XlaComputation ge_s32_;
67   XlaComputation add_s32_;
68   XlaComputation ge_f32_;
69   XlaComputation add_f32_;
70   XlaComputation max_f32_;
71   XlaComputation min_f32_;
72 };
73 
// Runs SelectAndScatter (GE select, ADD scatter, 0.0f init) over random F32
// operand and source arrays whose shapes, window and padding come from the
// test parameter. The result is checked with ComputeAndCompare; no explicit
// expected literal is supplied.
XLA_TEST_P(SelectAndScatterTest, ParamTest) {
  const SelectAndScatterTestParam& param = GetParam();

  Array<float> operand_values(param.operand_shape);
  operand_values.FillRandom(1.5f);
  const auto operand = ConstantFromArray(&builder_, operand_values);

  Array<float> source_values(param.source_shape);
  source_values.FillRandom(12.0f);
  const auto source = ConstantFromArray(&builder_, source_values);

  const auto init = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(operand, ge_f32_, param.window_dimensions,
                   param.window_strides, param.padding_type, source, init,
                   add_f32_);

  ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5, 1e-5));
}
91 
92 INSTANTIATE_TEST_CASE_P(
93     SelectAndScatterTest_Instantiation, SelectAndScatterTest,
94     ::testing::Values(
95         SelectAndScatterTestParam{{6, 6, 6, 4, 4},
96                                   {3, 3, 3, 4, 4},
97                                   Padding::kSame,
98                                   {3, 3, 3, 1, 1},
99                                   {2, 2, 2, 1, 1}},
100         SelectAndScatterTestParam{{7, 7, 7, 4, 4},
101                                   {3, 3, 3, 4, 4},
102                                   Padding::kValid,
103                                   {3, 3, 3, 1, 1},
104                                   {2, 2, 2, 1, 1}},
105 
106         SelectAndScatterTestParam{{8, 8, 8, 4, 4},
107                                   {1, 3, 3, 4, 4},
108                                   Padding::kValid,
109                                   {8, 4, 4, 1, 1},
110                                   {1, 2, 2, 1, 1}},
111         SelectAndScatterTestParam{{6, 6, 256, 128},
112                                   {3, 3, 256, 128},
113                                   Padding::kSame,
114                                   {3, 3, 1, 1},
115                                   {2, 2, 1, 1}},
116         SelectAndScatterTestParam{{7, 7, 256, 128},
117                                   {3, 3, 256, 128},
118                                   Padding::kValid,
119                                   {3, 3, 1, 1},
120                                   {2, 2, 1, 1}},
121         SelectAndScatterTestParam{{6, 7, 256, 128},
122                                   {3, 3, 256, 128},
123                                   Padding::kValid,
124                                   {2, 3, 1, 1},
125                                   {2, 2, 1, 1}},
126         SelectAndScatterTestParam{{6, 7, 256, 128},
127                                   {2, 3, 256, 128},
128                                   Padding::kValid,
129                                   {2, 3, 1, 1},
130                                   {3, 2, 1, 1}},
131         SelectAndScatterTestParam{{9, 9, 16, 128},
132                                   {3, 3, 16, 128},
133                                   Padding::kValid,
134                                   {3, 3, 1, 1},
135                                   {3, 3, 1, 1}},
136         SelectAndScatterTestParam{{3, 3, 4, 4},
137                                   {1, 1, 4, 4},
138                                   Padding::kValid,
139                                   {3, 3, 1, 1},
140                                   {3, 3, 1, 1}},
141         SelectAndScatterTestParam{{3, 3, 4, 4},
142                                   {1, 1, 4, 4},
143                                   Padding::kValid,
144                                   {3, 3, 1, 1},
145                                   {3, 3, 1, 1}},
146         SelectAndScatterTestParam{{9, 3, 4, 4},
147                                   {3, 1, 4, 4},
148                                   Padding::kValid,
149                                   {3, 3, 1, 1},
150                                   {3, 3, 1, 1}},
151         // Uncovered by b/126212776.
152         SelectAndScatterTestParam{{15, 1, 1, 1},
153                                   {2, 1, 1, 1},
154                                   Padding::kValid,
155                                   {14, 1, 1, 1},
156                                   {1, 1, 1, 1}},
157         SelectAndScatterTestParam{{7, 3, 4, 4},
158                                   {3, 1, 4, 4},
159                                   Padding::kValid,
160                                   {3, 3, 1, 1},
161                                   {2, 3, 1, 1}},
162         SelectAndScatterTestParam{{1, 1, 5, 5},
163                                   {1, 1, 5, 5},
164                                   Padding::kSame,
165                                   {3, 3, 1, 1},
166                                   {3, 3, 1, 1}},
167         SelectAndScatterTestParam{{7, 7, 8, 256},
168                                   {4, 4, 8, 256},
169                                   Padding::kSame,
170                                   {2, 2, 1, 1},
171                                   {2, 2, 1, 1}},
172         SelectAndScatterTestParam{
173             {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
174         SelectAndScatterTestParam{
175             {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
176         SelectAndScatterTestParam{{7, 256, 128},
177                                   {3, 256, 128},
178                                   Padding::kValid,
179                                   {3, 1, 1},
180                                   {2, 1, 1}},
181         SelectAndScatterTestParam{{6, 256, 128},
182                                   {3, 256, 128},
183                                   Padding::kValid,
184                                   {2, 1, 1},
185                                   {2, 1, 1}},
186         SelectAndScatterTestParam{{6, 256, 128},
187                                   {2, 256, 128},
188                                   Padding::kValid,
189                                   {2, 1, 1},
190                                   {3, 1, 1}},
191         SelectAndScatterTestParam{{10, 10, 8, 256},
192                                   {5, 5, 8, 256},
193                                   Padding::kSame,
194                                   {2, 2, 1, 1},
195                                   {2, 2, 1, 1}},
196         SelectAndScatterTestParam{
197             {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
198         SelectAndScatterTestParam{
199             {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
200         SelectAndScatterTestParam{
201             {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
202         SelectAndScatterTestParam{
203             {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
204         SelectAndScatterTestParam{
205             {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
206         SelectAndScatterTestParam{
207             {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
208         SelectAndScatterTestParam{
209             {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}},
210         SelectAndScatterTestParam{{1104}, {551}, Padding::kValid, {3}, {2}},
211         SelectAndScatterTestParam{{1300}, {1171}, Padding::kValid, {130}, {1}},
212         SelectAndScatterTestParam{{3000}, {1701}, Padding::kValid, {1300}, {1}},
213         SelectAndScatterTestParam{{6500}, {5}, Padding::kValid, {1300}, {1300}},
214         SelectAndScatterTestParam{
215             {3000}, {401}, Padding::kValid, {2600}, {1}}));
216 
217 // Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest,R1S0F32)218 XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
219   const auto operand = ConstantR1<float>(&builder_, {});
220   const auto source = ConstantR1<float>(&builder_, {});
221   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
222                    /*window_strides=*/{3}, Padding::kValid, source,
223                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
224   ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
225 }
226 
227 // Test for F32 1D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R1F32)228 XLA_TEST_F(SelectAndScatterTest, R1F32) {
229   const auto operand =
230       ConstantR1<float>(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
231   const auto source = ConstantR1<float>(&builder_, {34.f, 42.f});
232   const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
233   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
234                    /*window_strides=*/{3}, Padding::kValid, source,
235                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
236   ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
237 }
238 
239 // Test for S32 1D array, when windows do not overlap and the init value is 1.
XLA_TEST_F(SelectAndScatterTest,R1S32)240 XLA_TEST_F(SelectAndScatterTest, R1S32) {
241   const auto operand = ConstantR1<int32_t>(&builder_, {-1, 0, 6, 4, -4, 10});
242   const auto source = ConstantR1<int32_t>(&builder_, {-10, 20});
243   const std::vector<int32_t> expected = {1, 1, -9, 1, 1, 21};
244   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
245                    /*window_strides=*/{3}, Padding::kValid, source,
246                    ConstantR0<int32_t>(&builder_, 1), add_s32_);
247   ComputeAndCompareR1<int32_t>(&builder_, expected, {});
248 }
249 
250 // Test for S32 1D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R1S32OverlappingWindow)251 XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
252   const auto operand = ConstantR1<int32_t>(&builder_, {1, 9, 3, 7, 5, 6});
253   const auto source = ConstantR1<int32_t>(&builder_, {34, 42, 53, 19});
254   const std::vector<int32_t> expected = {0, 76, 0, 72, 0, 0};
255   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
256                    /*window_strides=*/{1}, Padding::kValid, source,
257                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
258   ComputeAndCompareR1<int32_t>(&builder_, expected, {});
259 }
260 
261 // Test for S32 2D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R2S32)262 XLA_TEST_F(SelectAndScatterTest, R2S32) {
263   const auto operand =
264       ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
265   const auto source = ConstantR2<int32_t>(&builder_, {{2, 6}});
266   Array2D<int32_t> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
267   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
268                    /*window_strides=*/{2, 3}, Padding::kValid, source,
269                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
270   ComputeAndCompareR2<int32_t>(&builder_, expected, {});
271 }
272 
273 // Test for tie breaking rule in ge_f32_. When a tie is present, the operand
274 // that has the lower lexicographical order (smaller index) should be chosen.
XLA_TEST_F(SelectAndScatterTest,R2F32Tie)275 XLA_TEST_F(SelectAndScatterTest, R2F32Tie) {
276   const auto operand = ConstantR2<float>(
277       &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
278   const auto source = ConstantR2<float>(
279       &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
280   Array2D<float> expected(
281       {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}});
282   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
283                    /*window_strides=*/{1, 1}, Padding::kSame, source,
284                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
285   ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
286 }
287 
288 // Similar to SelectAndScatterTest.R2S32 but the input is transposed.
XLA_TEST_F(SelectAndScatterTest,ReshapeR2S32)289 XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
290   const auto operand = ConstantR2<int32_t>(
291       &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
292   const auto reshape =
293       Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
294   const auto source = ConstantR2<int32_t>(&builder_, {{2, 6}});
295   Array2D<int32_t> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
296   SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
297                    /*window_strides=*/{2, 3}, Padding::kValid, source,
298                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
299   ComputeAndCompareR2<int32_t>(&builder_, expected, {});
300 }
301 
302 // Test for S32 2D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32OverlappingWindow)303 XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
304   const auto operand =
305       ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
306   const auto source = ConstantR2<int32_t>(&builder_, {{2, 6, 4}});
307   Array2D<int32_t> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
308   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
309                    /*window_strides=*/{1, 1}, Padding::kValid, source,
310                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
311   ComputeAndCompareR2<int32_t>(&builder_, expected, {});
312 }
313 
314 // Test for S32 2D array, when the padding is Padding::kSAME.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePadding)315 XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
316   const auto operand =
317       ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
318   const auto source = ConstantR2<int32_t>(&builder_, {{2, 6, 4}});
319   Array2D<int32_t> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
320   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
321                    /*window_strides=*/{2, 2}, Padding::kSame, source,
322                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
323   ComputeAndCompareR2<int32_t>(&builder_, expected, {});
324 }
325 
326 // Test for S32 2D array, when the padding is Padding::kSAME and windows overlap
327 // with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePaddingOverlappingWindow)328 XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
329   const auto operand =
330       ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
331   const auto source =
332       ConstantR2<int32_t>(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
333   Array2D<int32_t> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
334   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
335                    /*window_strides=*/{1, 1}, Padding::kSame, source,
336                    ConstantR0<int32_t>(&builder_, 0), add_s32_);
337   ComputeAndCompareR2<int32_t>(&builder_, expected, {});
338 }
339 
// F32 rank-2 case: overlapping 2x2 windows (unit stride, kValid) over a 3x3
// operand give a 2x2 source. Each window selects a distinct maximum (3.5 or
// 4.5), so every source value lands in its own output position.
XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
  const auto operand = ConstantR2<float>(
      &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
  const auto source =
      ConstantR2<float>(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  const auto zero = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2},
                   /*window_strides=*/{1, 1}, Padding::kValid, source, zero,
                   add_f32_);
  const Array2D<float> expected(
      {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
352 
// R4 F32 case with non-overlapping {2, 3} windows over the two leading
// dimensions and unit windows over the trailing (15 x 220) dimensions.
// The 2D patterns below are written into dims 0-1 of the R4 arrays via
// FillWithPZ (their sizes match dims 0-1); presumably the trailing dims just
// repeat the pattern.
TEST_F(SelectAndScatterTest, R4F32Valid) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  // Each window's maximum (9, 10, 7, 8) receives its source value.
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
                        {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
  Array4D<float> o(4, 6, 15, 220);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 6, 15, 220);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 15, 220);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
377 
// R4 F32 case with overlapping windows: {2, 3} windows at stride {2, 2} over
// the leading 4x5 dimensions, with unit windows over the trailing (17 x 128)
// dimensions. Both top windows select the 9 at (1, 2), so it accumulates
// 2 + 6 = 8. The 2D patterns are written into dims 0-1 via FillWithPZ.
TEST_F(SelectAndScatterTest, R4F32Overlap) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 17, 128);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 17, 128);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 17, 128);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
402 
// Same data as R4F32Overlap, but with trailing dimensions of size 1 x 1, so
// the R4 arrays are effectively the 2D patterns themselves.
TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 1, 1);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 1, 1);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 1, 1);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
427 
// Exercises ReferenceUtil::SelectAndScatter4DGePlus: its output on a fixed R4
// input is used as the expected value for the backend's SelectAndScatter with
// GE select, ADD scatter and kValid padding.
TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array4D<float> o(4, 6, 4, 4);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> s(2, 2, 4, 4);
  s.FillWithPZ(pzs);

  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  // NOTE(review): the trailing `false` presumably selects valid (not same)
  // padding, matching Padding::kValid above — confirm in reference_util.h.
  auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
                                                   {2, 3, 1, 1}, false);
  ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
}
450 
451 // Test for F32 4D array with negative padding on both ends.
XLA_TEST_F(SelectAndScatterTest,R4NegativePaddingOnBothEnds)452 XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingOnBothEnds) {
453   Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
454                         {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
455                         {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
456                         {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
457   Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
458   Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
459                         {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
460                         {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
461                         {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}};
462   Array4D<float> o(4, 6, 4, 4);
463   o.FillWithPZ(pzo);
464   auto operand = ConstantR4FromArray4D(&builder_, o);
465   Array4D<float> e(4, 6, 4, 4);
466   e.FillWithPZ(pze);
467   Array4D<float> s(2, 2, 4, 4);
468   s.FillWithPZ(pzs);
469   auto source = ConstantR4FromArray4D(&builder_, s);
470   s.FillWithPZ(pzs);
471   SelectAndScatterWithGeneralPadding(
472       operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
473       {{0, 0}, {-1, -1}, {0, 0}, {0, 0}}, source,
474       ConstantR0<float>(&builder_, 0.0f), add_f32_);
475   ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
476 }
477 
478 // Test for F32 4D array with positive low padding and negative high padding.
XLA_TEST_F(SelectAndScatterTest,R4PositivePaddingLowAndNegativePaddingHigh)479 XLA_TEST_F(SelectAndScatterTest, R4PositivePaddingLowAndNegativePaddingHigh) {
480   Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
481                         {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
482                         {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
483                         {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
484   Array2D<float> pzs = {{2.0f, 6.0f, 4.0f}, {3.0f, 1.0f, 5.0f}};
485   Array2D<float> pze = {{2.0f, 0.0f, 0.0f, 0.0f, 4.0f, 0.0f},
486                         {0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f},
487                         {3.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f},
488                         {0.0f, 0.0f, 0.0f, 5.0f, 0.0f, 0.0f}};
489   Array4D<float> o(4, 6, 4, 4);
490   o.FillWithPZ(pzo);
491   auto operand = ConstantR4FromArray4D(&builder_, o);
492   Array4D<float> e(4, 6, 4, 4);
493   e.FillWithPZ(pze);
494   Array4D<float> s(2, 3, 4, 4);
495   s.FillWithPZ(pzs);
496   auto source = ConstantR4FromArray4D(&builder_, s);
497   s.FillWithPZ(pzs);
498   SelectAndScatterWithGeneralPadding(
499       operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
500       {{0, 0}, {1, -1}, {0, 0}, {0, 0}}, source,
501       ConstantR0<float>(&builder_, 0.0f), add_f32_);
502   ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
503 }
504 
505 // Test for F32 4D array with negative low padding and positive high padding.
XLA_TEST_F(SelectAndScatterTest,R4NegativePaddingLowAndPositivePaddingHigh)506 XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingLowAndPositivePaddingHigh) {
507   Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
508                         {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
509                         {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
510                         {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
511   Array2D<float> pzs = {{2.0f, 6.0f, 4.0f}, {3.0f, 1.0f, 5.0f}};
512   Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 4.0f},
513                         {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
514                         {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
515                         {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 5.0f}};
516   Array4D<float> o(4, 6, 4, 4);
517   o.FillWithPZ(pzo);
518   auto operand = ConstantR4FromArray4D(&builder_, o);
519   Array4D<float> e(4, 6, 4, 4);
520   e.FillWithPZ(pze);
521   Array4D<float> s(2, 3, 4, 4);
522   s.FillWithPZ(pzs);
523   auto source = ConstantR4FromArray4D(&builder_, s);
524   s.FillWithPZ(pzs);
525   SelectAndScatterWithGeneralPadding(
526       operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
527       {{0, 0}, {-1, 1}, {0, 0}, {0, 0}}, source,
528       ConstantR0<float>(&builder_, 0.0f), add_f32_);
529   ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
530 }
531 
XLA_TEST_F(SelectAndScatterTest,R1F32OverlappingWindowMaxScatter)532 XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
533   const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
534   const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
535   const std::vector<float> expected = {0, 0, 0, 53, 0, 0, 0};
536   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
537                    /*window_strides=*/{1}, Padding::kValid, source,
538                    ConstantR0<float>(&builder_, 0), max_f32_);
539   ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
540 }
541 
XLA_TEST_F(SelectAndScatterTest,R1F32OverlappingWindowMinScatter)542 XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
543   const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
544   const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
545   const float max_float = std::numeric_limits<float>::max();
546   const std::vector<float> expected = {max_float, max_float, max_float, 19,
547                                        max_float, max_float, max_float};
548   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
549                    /*window_strides=*/{1}, Padding::kValid, source,
550                    ConstantR0<float>(&builder_, max_float), min_f32_);
551   ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
552 }
553 
554 }  // namespace
555 }  // namespace xla
556 
557 #endif  // !defined(__APPLE__)
558