1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests the select-and-scatter XLA operation.
17
18 // b/194424657: On macs, the compiler hangs when trying to compile this file
19 #if !defined(__APPLE__)
20
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
39
40 namespace xla {
41 namespace {
42
43 struct SelectAndScatterTestParam {
44 std::vector<int64_t> operand_shape;
45 std::vector<int64_t> source_shape;
46 Padding padding_type;
47 std::vector<int64_t> window_dimensions;
48 std::vector<int64_t> window_strides;
49 };
50
51 class SelectAndScatterTest
52 : public ClientLibraryTestBase,
53 public ::testing::WithParamInterface<SelectAndScatterTestParam> {
54 public:
SelectAndScatterTest()55 SelectAndScatterTest() : builder_(TestName()) {
56 // Create S32 GE and ADD computations for select and scatter respectively.
57 ge_s32_ = CreateScalarGeComputation(S32, &builder_);
58 add_s32_ = CreateScalarAddComputation(S32, &builder_);
59 ge_f32_ = CreateScalarGeComputation(F32, &builder_);
60 add_f32_ = CreateScalarAddComputation(F32, &builder_);
61 max_f32_ = CreateScalarMaxComputation(F32, &builder_);
62 min_f32_ = CreateScalarMinComputation(F32, &builder_);
63 }
64
65 XlaBuilder builder_;
66 XlaComputation ge_s32_;
67 XlaComputation add_s32_;
68 XlaComputation ge_f32_;
69 XlaComputation add_f32_;
70 XlaComputation max_f32_;
71 XlaComputation min_f32_;
72 };
73
// Runs F32 select-and-scatter (GE select, ADD scatter, zero init) on randomly
// filled operand/source arrays whose shapes, window, strides and padding come
// from the test parameter. No explicit expected value is supplied to
// ComputeAndCompare.
XLA_TEST_P(SelectAndScatterTest, ParamTest) {
  const SelectAndScatterTestParam& param = GetParam();

  Array<float> operand_values(param.operand_shape);
  operand_values.FillRandom(1.5f);
  const auto operand = ConstantFromArray(&builder_, operand_values);

  Array<float> source_values(param.source_shape);
  source_values.FillRandom(12.0f);
  const auto source = ConstantFromArray(&builder_, source_values);

  const auto init_value = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(operand, ge_f32_, param.window_dimensions,
                   param.window_strides, param.padding_type, source,
                   init_value, add_f32_);

  ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5, 1e-5));
}
91
92 INSTANTIATE_TEST_CASE_P(
93 SelectAndScatterTest_Instantiation, SelectAndScatterTest,
94 ::testing::Values(
95 SelectAndScatterTestParam{{6, 6, 6, 4, 4},
96 {3, 3, 3, 4, 4},
97 Padding::kSame,
98 {3, 3, 3, 1, 1},
99 {2, 2, 2, 1, 1}},
100 SelectAndScatterTestParam{{7, 7, 7, 4, 4},
101 {3, 3, 3, 4, 4},
102 Padding::kValid,
103 {3, 3, 3, 1, 1},
104 {2, 2, 2, 1, 1}},
105
106 SelectAndScatterTestParam{{8, 8, 8, 4, 4},
107 {1, 3, 3, 4, 4},
108 Padding::kValid,
109 {8, 4, 4, 1, 1},
110 {1, 2, 2, 1, 1}},
111 SelectAndScatterTestParam{{6, 6, 256, 128},
112 {3, 3, 256, 128},
113 Padding::kSame,
114 {3, 3, 1, 1},
115 {2, 2, 1, 1}},
116 SelectAndScatterTestParam{{7, 7, 256, 128},
117 {3, 3, 256, 128},
118 Padding::kValid,
119 {3, 3, 1, 1},
120 {2, 2, 1, 1}},
121 SelectAndScatterTestParam{{6, 7, 256, 128},
122 {3, 3, 256, 128},
123 Padding::kValid,
124 {2, 3, 1, 1},
125 {2, 2, 1, 1}},
126 SelectAndScatterTestParam{{6, 7, 256, 128},
127 {2, 3, 256, 128},
128 Padding::kValid,
129 {2, 3, 1, 1},
130 {3, 2, 1, 1}},
131 SelectAndScatterTestParam{{9, 9, 16, 128},
132 {3, 3, 16, 128},
133 Padding::kValid,
134 {3, 3, 1, 1},
135 {3, 3, 1, 1}},
136 SelectAndScatterTestParam{{3, 3, 4, 4},
137 {1, 1, 4, 4},
138 Padding::kValid,
139 {3, 3, 1, 1},
140 {3, 3, 1, 1}},
141 SelectAndScatterTestParam{{3, 3, 4, 4},
142 {1, 1, 4, 4},
143 Padding::kValid,
144 {3, 3, 1, 1},
145 {3, 3, 1, 1}},
146 SelectAndScatterTestParam{{9, 3, 4, 4},
147 {3, 1, 4, 4},
148 Padding::kValid,
149 {3, 3, 1, 1},
150 {3, 3, 1, 1}},
151 // Uncovered by b/126212776.
152 SelectAndScatterTestParam{{15, 1, 1, 1},
153 {2, 1, 1, 1},
154 Padding::kValid,
155 {14, 1, 1, 1},
156 {1, 1, 1, 1}},
157 SelectAndScatterTestParam{{7, 3, 4, 4},
158 {3, 1, 4, 4},
159 Padding::kValid,
160 {3, 3, 1, 1},
161 {2, 3, 1, 1}},
162 SelectAndScatterTestParam{{1, 1, 5, 5},
163 {1, 1, 5, 5},
164 Padding::kSame,
165 {3, 3, 1, 1},
166 {3, 3, 1, 1}},
167 SelectAndScatterTestParam{{7, 7, 8, 256},
168 {4, 4, 8, 256},
169 Padding::kSame,
170 {2, 2, 1, 1},
171 {2, 2, 1, 1}},
172 SelectAndScatterTestParam{
173 {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
174 SelectAndScatterTestParam{
175 {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
176 SelectAndScatterTestParam{{7, 256, 128},
177 {3, 256, 128},
178 Padding::kValid,
179 {3, 1, 1},
180 {2, 1, 1}},
181 SelectAndScatterTestParam{{6, 256, 128},
182 {3, 256, 128},
183 Padding::kValid,
184 {2, 1, 1},
185 {2, 1, 1}},
186 SelectAndScatterTestParam{{6, 256, 128},
187 {2, 256, 128},
188 Padding::kValid,
189 {2, 1, 1},
190 {3, 1, 1}},
191 SelectAndScatterTestParam{{10, 10, 8, 256},
192 {5, 5, 8, 256},
193 Padding::kSame,
194 {2, 2, 1, 1},
195 {2, 2, 1, 1}},
196 SelectAndScatterTestParam{
197 {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
198 SelectAndScatterTestParam{
199 {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
200 SelectAndScatterTestParam{
201 {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
202 SelectAndScatterTestParam{
203 {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
204 SelectAndScatterTestParam{
205 {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
206 SelectAndScatterTestParam{
207 {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
208 SelectAndScatterTestParam{
209 {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}},
210 SelectAndScatterTestParam{{1104}, {551}, Padding::kValid, {3}, {2}},
211 SelectAndScatterTestParam{{1300}, {1171}, Padding::kValid, {130}, {1}},
212 SelectAndScatterTestParam{{3000}, {1701}, Padding::kValid, {1300}, {1}},
213 SelectAndScatterTestParam{{6500}, {5}, Padding::kValid, {1300}, {1300}},
214 SelectAndScatterTestParam{
215 {3000}, {401}, Padding::kValid, {2600}, {1}}));
216
217 // Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest,R1S0F32)218 XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
219 const auto operand = ConstantR1<float>(&builder_, {});
220 const auto source = ConstantR1<float>(&builder_, {});
221 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
222 /*window_strides=*/{3}, Padding::kValid, source,
223 ConstantR0<float>(&builder_, 0.0f), add_f32_);
224 ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
225 }
226
227 // Test for F32 1D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R1F32)228 XLA_TEST_F(SelectAndScatterTest, R1F32) {
229 const auto operand =
230 ConstantR1<float>(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
231 const auto source = ConstantR1<float>(&builder_, {34.f, 42.f});
232 const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
233 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
234 /*window_strides=*/{3}, Padding::kValid, source,
235 ConstantR0<float>(&builder_, 0.0f), add_f32_);
236 ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
237 }
238
239 // Test for S32 1D array, when windows do not overlap and the init value is 1.
XLA_TEST_F(SelectAndScatterTest,R1S32)240 XLA_TEST_F(SelectAndScatterTest, R1S32) {
241 const auto operand = ConstantR1<int32_t>(&builder_, {-1, 0, 6, 4, -4, 10});
242 const auto source = ConstantR1<int32_t>(&builder_, {-10, 20});
243 const std::vector<int32_t> expected = {1, 1, -9, 1, 1, 21};
244 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
245 /*window_strides=*/{3}, Padding::kValid, source,
246 ConstantR0<int32_t>(&builder_, 1), add_s32_);
247 ComputeAndCompareR1<int32_t>(&builder_, expected, {});
248 }
249
250 // Test for S32 1D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R1S32OverlappingWindow)251 XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
252 const auto operand = ConstantR1<int32_t>(&builder_, {1, 9, 3, 7, 5, 6});
253 const auto source = ConstantR1<int32_t>(&builder_, {34, 42, 53, 19});
254 const std::vector<int32_t> expected = {0, 76, 0, 72, 0, 0};
255 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
256 /*window_strides=*/{1}, Padding::kValid, source,
257 ConstantR0<int32_t>(&builder_, 0), add_s32_);
258 ComputeAndCompareR1<int32_t>(&builder_, expected, {});
259 }
260
261 // Test for S32 2D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R2S32)262 XLA_TEST_F(SelectAndScatterTest, R2S32) {
263 const auto operand =
264 ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
265 const auto source = ConstantR2<int32_t>(&builder_, {{2, 6}});
266 Array2D<int32_t> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
267 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
268 /*window_strides=*/{2, 3}, Padding::kValid, source,
269 ConstantR0<int32_t>(&builder_, 0), add_s32_);
270 ComputeAndCompareR2<int32_t>(&builder_, expected, {});
271 }
272
273 // Test for tie breaking rule in ge_f32_. When a tie is present, the operand
274 // that has the lower lexicographical order (smaller index) should be chosen.
XLA_TEST_F(SelectAndScatterTest,R2F32Tie)275 XLA_TEST_F(SelectAndScatterTest, R2F32Tie) {
276 const auto operand = ConstantR2<float>(
277 &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
278 const auto source = ConstantR2<float>(
279 &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
280 Array2D<float> expected(
281 {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}});
282 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
283 /*window_strides=*/{1, 1}, Padding::kSame, source,
284 ConstantR0<float>(&builder_, 0.0f), add_f32_);
285 ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
286 }
287
288 // Similar to SelectAndScatterTest.R2S32 but the input is transposed.
XLA_TEST_F(SelectAndScatterTest,ReshapeR2S32)289 XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
290 const auto operand = ConstantR2<int32_t>(
291 &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
292 const auto reshape =
293 Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
294 const auto source = ConstantR2<int32_t>(&builder_, {{2, 6}});
295 Array2D<int32_t> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
296 SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
297 /*window_strides=*/{2, 3}, Padding::kValid, source,
298 ConstantR0<int32_t>(&builder_, 0), add_s32_);
299 ComputeAndCompareR2<int32_t>(&builder_, expected, {});
300 }
301
302 // Test for S32 2D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32OverlappingWindow)303 XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
304 const auto operand =
305 ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
306 const auto source = ConstantR2<int32_t>(&builder_, {{2, 6, 4}});
307 Array2D<int32_t> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
308 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
309 /*window_strides=*/{1, 1}, Padding::kValid, source,
310 ConstantR0<int32_t>(&builder_, 0), add_s32_);
311 ComputeAndCompareR2<int32_t>(&builder_, expected, {});
312 }
313
314 // Test for S32 2D array, when the padding is Padding::kSAME.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePadding)315 XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
316 const auto operand =
317 ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
318 const auto source = ConstantR2<int32_t>(&builder_, {{2, 6, 4}});
319 Array2D<int32_t> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
320 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
321 /*window_strides=*/{2, 2}, Padding::kSame, source,
322 ConstantR0<int32_t>(&builder_, 0), add_s32_);
323 ComputeAndCompareR2<int32_t>(&builder_, expected, {});
324 }
325
326 // Test for S32 2D array, when the padding is Padding::kSAME and windows overlap
327 // with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePaddingOverlappingWindow)328 XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
329 const auto operand =
330 ConstantR2<int32_t>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
331 const auto source =
332 ConstantR2<int32_t>(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
333 Array2D<int32_t> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
334 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
335 /*window_strides=*/{1, 1}, Padding::kSame, source,
336 ConstantR0<int32_t>(&builder_, 0), add_s32_);
337 ComputeAndCompareR2<int32_t>(&builder_, expected, {});
338 }
339
// F32 rank-2 case with overlapping 2x2 windows (stride 1) over a 3x3 operand
// and a 2x2 source. Here each window selects a different operand element, so
// every source value lands in its own output position.
XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
  Array2D<float> expected_output(
      {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}});

  const auto input = ConstantR2<float>(
      &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
  const auto updates =
      ConstantR2<float>(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  const auto init_value = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{2, 2},
                   /*window_strides=*/{1, 1}, Padding::kValid, updates,
                   init_value, add_f32_);

  ComputeAndCompareR2<float>(&builder_, expected_output, {}, ErrorSpec(1e-7));
}
352
// F32 rank-4 case with valid padding and non-overlapping windows (window and
// stride are both {2, 3, 1, 1}). The operand, source and expected arrays are
// each filled from a small 2-D pattern via FillWithPZ (presumably replicating
// it across the trailing two dimensions; confirm against Array4D).
//
// NOTE(review): this test uses TEST_F while most siblings use XLA_TEST_F;
// confirm whether that is intentional.
TEST_F(SelectAndScatterTest, R4F32Valid) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
                        {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
  Array4D<float> o(4, 6, 15, 220);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 6, 15, 220);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 15, 220);
  s.FillWithPZ(pzs);
  // `s` is filled once; a second identical fill after building the constant
  // would only rewrite the same values.
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
377
// F32 rank-4 case with valid padding and overlapping windows: the window is
// {2, 3, 1, 1} but the stride is {2, 2, 1, 1}. The operand, source and
// expected arrays are each filled from a small 2-D pattern via FillWithPZ
// (presumably replicating it across the trailing two dimensions; confirm
// against Array4D).
//
// NOTE(review): this test uses TEST_F while most siblings use XLA_TEST_F;
// confirm whether that is intentional.
TEST_F(SelectAndScatterTest, R4F32Overlap) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 17, 128);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 17, 128);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 17, 128);
  s.FillWithPZ(pzs);
  // `s` is filled once; a second identical fill after building the constant
  // would only rewrite the same values.
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
402
// Same patterns, window and stride as R4F32Overlap, but the trailing two
// dimensions of every array have size 1.
//
// NOTE(review): this test uses TEST_F while most siblings use XLA_TEST_F;
// confirm whether that is intentional.
TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 1, 1);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 1, 1);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 1, 1);
  s.FillWithPZ(pzs);
  // `s` is filled once; a second identical fill after building the constant
  // would only rewrite the same values.
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
427
// Cross-checks the builder result against ReferenceUtil: the expected array is
// computed by ReferenceUtil::SelectAndScatter4DGePlus from the same operand
// and source arrays, window and stride, rather than being written out by hand.
//
// NOTE(review): this test uses TEST_F while most siblings use XLA_TEST_F;
// confirm whether that is intentional.
TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
  // This test is testing the Reference Util
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array4D<float> o(4, 6, 4, 4);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> s(2, 2, 4, 4);
  s.FillWithPZ(pzs);

  // `s` is filled once and reused below for the reference computation; a
  // second identical fill would only rewrite the same values.
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  // The trailing `false` presumably selects valid (not same) padding, matching
  // Padding::kValid above; confirm against ReferenceUtil.
  auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
                                                   {2, 3, 1, 1}, false);
  ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
}
450
451 // Test for F32 4D array with negative padding on both ends.
XLA_TEST_F(SelectAndScatterTest,R4NegativePaddingOnBothEnds)452 XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingOnBothEnds) {
453 Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
454 {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
455 {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
456 {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
457 Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
458 Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
459 {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
460 {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
461 {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}};
462 Array4D<float> o(4, 6, 4, 4);
463 o.FillWithPZ(pzo);
464 auto operand = ConstantR4FromArray4D(&builder_, o);
465 Array4D<float> e(4, 6, 4, 4);
466 e.FillWithPZ(pze);
467 Array4D<float> s(2, 2, 4, 4);
468 s.FillWithPZ(pzs);
469 auto source = ConstantR4FromArray4D(&builder_, s);
470 s.FillWithPZ(pzs);
471 SelectAndScatterWithGeneralPadding(
472 operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
473 {{0, 0}, {-1, -1}, {0, 0}, {0, 0}}, source,
474 ConstantR0<float>(&builder_, 0.0f), add_f32_);
475 ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
476 }
477
478 // Test for F32 4D array with positive low padding and negative high padding.
XLA_TEST_F(SelectAndScatterTest,R4PositivePaddingLowAndNegativePaddingHigh)479 XLA_TEST_F(SelectAndScatterTest, R4PositivePaddingLowAndNegativePaddingHigh) {
480 Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
481 {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
482 {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
483 {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
484 Array2D<float> pzs = {{2.0f, 6.0f, 4.0f}, {3.0f, 1.0f, 5.0f}};
485 Array2D<float> pze = {{2.0f, 0.0f, 0.0f, 0.0f, 4.0f, 0.0f},
486 {0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f},
487 {3.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f},
488 {0.0f, 0.0f, 0.0f, 5.0f, 0.0f, 0.0f}};
489 Array4D<float> o(4, 6, 4, 4);
490 o.FillWithPZ(pzo);
491 auto operand = ConstantR4FromArray4D(&builder_, o);
492 Array4D<float> e(4, 6, 4, 4);
493 e.FillWithPZ(pze);
494 Array4D<float> s(2, 3, 4, 4);
495 s.FillWithPZ(pzs);
496 auto source = ConstantR4FromArray4D(&builder_, s);
497 s.FillWithPZ(pzs);
498 SelectAndScatterWithGeneralPadding(
499 operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
500 {{0, 0}, {1, -1}, {0, 0}, {0, 0}}, source,
501 ConstantR0<float>(&builder_, 0.0f), add_f32_);
502 ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
503 }
504
505 // Test for F32 4D array with negative low padding and positive high padding.
XLA_TEST_F(SelectAndScatterTest,R4NegativePaddingLowAndPositivePaddingHigh)506 XLA_TEST_F(SelectAndScatterTest, R4NegativePaddingLowAndPositivePaddingHigh) {
507 Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 3.0f},
508 {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
509 {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
510 {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
511 Array2D<float> pzs = {{2.0f, 6.0f, 4.0f}, {3.0f, 1.0f, 5.0f}};
512 Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 4.0f},
513 {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
514 {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
515 {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 5.0f}};
516 Array4D<float> o(4, 6, 4, 4);
517 o.FillWithPZ(pzo);
518 auto operand = ConstantR4FromArray4D(&builder_, o);
519 Array4D<float> e(4, 6, 4, 4);
520 e.FillWithPZ(pze);
521 Array4D<float> s(2, 3, 4, 4);
522 s.FillWithPZ(pzs);
523 auto source = ConstantR4FromArray4D(&builder_, s);
524 s.FillWithPZ(pzs);
525 SelectAndScatterWithGeneralPadding(
526 operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
527 {{0, 0}, {-1, 1}, {0, 0}, {0, 0}}, source,
528 ConstantR0<float>(&builder_, 0.0f), add_f32_);
529 ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
530 }
531
// F32 rank-1 case with overlapping windows (size 4, stride 1) where the
// scatter function is MAX instead of ADD. Every window selects the element
// holding 100, so that position ends up with the largest source value.
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
  const std::vector<float> expected_output = {0, 0, 0, 53, 0, 0, 0};

  const auto input = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
  const auto updates = ConstantR1<float>(&builder_, {34, 42, 53, 19});
  const auto init_value = ConstantR0<float>(&builder_, 0);
  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, updates,
                   init_value, max_f32_);

  ComputeAndCompareR1<float>(&builder_, expected_output, {}, ErrorSpec(1e-7));
}
541
// F32 rank-1 case with overlapping windows (size 4, stride 1) where the
// scatter function is MIN. The init value is the largest finite float, so
// untouched positions keep it and the selected position (the element holding
// 100) ends up with the smallest source value.
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
  // Requires <limits>; the value is a compile-time constant.
  constexpr float kMaxFloat = std::numeric_limits<float>::max();
  const std::vector<float> expected = {kMaxFloat, kMaxFloat, kMaxFloat, 19,
                                       kMaxFloat, kMaxFloat, kMaxFloat};

  const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
  const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, kMaxFloat), min_f32_);
  ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
553
554 } // namespace
555 } // namespace xla
556
557 #endif // !defined(__APPLE__)
558