1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::Scalar;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::TensorFactory;
23
24 class OpMaskedFillTest : public OperatorTest {
25 protected:
op_masked_fill_scalar_out(const Tensor & self,const Tensor & mask,const Scalar & value,Tensor & out)26 Tensor& op_masked_fill_scalar_out(
27 const Tensor& self,
28 const Tensor& mask,
29 const Scalar& value,
30 Tensor& out) {
31 return torch::executor::aten::masked_fill_outf(
32 context_, self, mask, value, out);
33 }
34
35 // Common testing for masked fill of integer Tensor.
36 template <ScalarType DTYPE>
test_integer_masked_fill_scalar_out()37 void test_integer_masked_fill_scalar_out() {
38 TensorFactory<DTYPE> tf;
39 TensorFactory<ScalarType::Bool> tf_bool;
40
41 const std::vector<int32_t> sizes = {2, 2};
42
43 // Destination for the masked_fill.
44 Tensor out = tf.zeros(sizes);
45
46 // Masked fill half of the tensor.
47 op_masked_fill_scalar_out(
48 tf.make(sizes, /*data=*/{23, 29, 31, 37}),
49 tf_bool.make(sizes, /*data=*/{false, true, true, false}),
50 /*value=*/71,
51 out);
52
53 // Check that it matches the expected output.
54 EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{23, 71, 71, 37}));
55 }
56
57 // Common testing for masked fill of floating point Tensor.
58 template <ScalarType DTYPE>
test_floating_point_masked_fill_scalar_out()59 void test_floating_point_masked_fill_scalar_out() {
60 TensorFactory<DTYPE> tf;
61 TensorFactory<ScalarType::Bool> tf_bool;
62
63 const std::vector<int32_t> sizes = {2, 2};
64
65 // Destination for the masked_fill.
66 Tensor out = tf.zeros(sizes);
67
68 // Masked fill half of the tensor.
69 op_masked_fill_scalar_out(
70 tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
71 tf_bool.make(sizes, /*data=*/{true, false, false, true}),
72 /*value=*/3.3,
73 out);
74
75 // Check that it matches the expected output.
76 EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{3.3, 2.2, 4.4, 3.3}));
77 }
78 };
79
TEST_F(OpMaskedFillTest, ByteTensors) {
  // ScalarType::Byte input, via the shared integral helper.
  this->test_integer_masked_fill_scalar_out<ScalarType::Byte>();
}
83
TEST_F(OpMaskedFillTest, CharTensors) {
  // ScalarType::Char input, via the shared integral helper.
  this->test_integer_masked_fill_scalar_out<ScalarType::Char>();
}
87
TEST_F(OpMaskedFillTest, ShortTensors) {
  // ScalarType::Short input, via the shared integral helper.
  this->test_integer_masked_fill_scalar_out<ScalarType::Short>();
}
91
TEST_F(OpMaskedFillTest, IntTensors) {
  // ScalarType::Int input, via the shared integral helper.
  this->test_integer_masked_fill_scalar_out<ScalarType::Int>();
}
95
TEST_F(OpMaskedFillTest, LongTensors) {
  // ScalarType::Long input, via the shared integral helper.
  this->test_integer_masked_fill_scalar_out<ScalarType::Long>();
}
99
TEST_F(OpMaskedFillTest, IntTensorFloatAlphaDies) {
  // Filling an Int tensor with a floating-point value (0.7) is expected to be
  // reported as a kernel failure.
  TensorFactory<ScalarType::Int> tf;

  const std::vector<int32_t> sizes = {2, 2};

  // NOTE(review): this mask is built with the Int factory, not a Bool one.
  // MismatchedMaskDtypeDies expects a non-Bool mask to fail on its own, so
  // this test may pass because of the mask dtype rather than the floating
  // value. Confirm whether the kernel rejects float values for Int inputs
  // before switching this to a Bool mask.
  Tensor mask = tf.ones(sizes);

  // Destination for the op.
  Tensor out = tf.zeros(sizes);

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_masked_fill_scalar_out(tf.ones(sizes), mask, /*value=*/.7, out));
}
116
TEST_F(OpMaskedFillTest, FloatTensors) {
  // ScalarType::Float input, via the shared floating-point helper.
  this->test_floating_point_masked_fill_scalar_out<ScalarType::Float>();
}
120
TEST_F(OpMaskedFillTest, DoubleTensors) {
  // ScalarType::Double input, via the shared floating-point helper.
  this->test_floating_point_masked_fill_scalar_out<ScalarType::Double>();
}
124
TEST_F(OpMaskedFillTest, BoolTensors) {
  TensorFactory<ScalarType::Bool> tf;

  const std::vector<int32_t> sizes = {2, 2};

  // Input and mask are complementary, so filling the masked positions with
  // `true` should yield an all-true result.
  Tensor self = tf.make(sizes, /*data=*/{false, true, false, true});
  Tensor mask = tf.make(sizes, /*data=*/{true, false, true, false});

  // Destination for the masked_fill.
  Tensor out = tf.zeros(sizes);

  op_masked_fill_scalar_out(self, mask, /*value=*/true, out);

  // Bool results are exact, so use EXPECT_TENSOR_EQ (as the integral helper
  // does) rather than a tolerance-based comparison.
  EXPECT_TENSOR_EQ(out, tf.ones(sizes));
}
141
142 // The input tensor and value may not have different dtypes.
TEST_F(OpMaskedFillTest, MismatchedInputAndValueDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;

  const std::vector<int32_t> sizes = {2, 2};

  // Byte input that will be filled with a floating-point value (1.3).
  Tensor self = tf_byte.ones(sizes);
  // NOTE(review): the mask is Char rather than Bool. MismatchedMaskDtypeDies
  // expects a non-Bool mask to be rejected by itself, so this test may not
  // isolate the input/value dtype mismatch it is named for. Confirm the
  // kernel's behavior before switching to a Bool mask.
  Tensor mask = tf_char.ones(sizes);

  // Destination for the fill; matches the dtype of the input.
  Tensor out = tf_byte.zeros(sizes);

  // Filling the tensor with a mismatched scalar should be reported as a
  // kernel failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_masked_fill_scalar_out(self, mask, /*value=*/1.3, out));
}
161
162 // The output tensor may not have a dtype different from the inputs even if it
163 // has the same shape.
TEST_F(OpMaskedFillTest, MismatchedOutputDtypeDies) {
  // The input (Byte) and output (Char) dtypes differ but have the same element
  // size, demonstrating that the ScalarType itself matters, not the size of
  // the tensor elements. The Bool factory is only used to build the mask.
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;

  const std::vector<int32_t> sizes = {2, 2};

  // Input and mask.
  Tensor self = tf_byte.ones(sizes);
  Tensor mask = tf_bool.ones(sizes);

  // Destination with a dtype different from the input.
  Tensor out = tf_char.zeros(sizes);

  // Filling into an output of mismatched dtype should be reported as a kernel
  // failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_masked_fill_scalar_out(self, mask, /*value=*/0, out));
}
186 // The mask tensor type must be bool, even if shapes are the same
TEST_F(OpMaskedFillTest, MismatchedMaskDtypeDies) {
  TensorFactory<ScalarType::Int> int_tf;
  const std::vector<int32_t> shape = {2, 2};

  // Every tensor here is Int, including the mask, whose dtype should have
  // been Bool.
  Tensor mask = int_tf.ones(shape);
  Tensor self = int_tf.ones(shape);
  Tensor out = int_tf.zeros(shape);

  // A non-Bool mask must be reported as a kernel failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_masked_fill_scalar_out(self, mask, /*value=*/0, out));
}
204
205 // Mismatched shape tests.
TEST_F(OpMaskedFillTest, MismatchedInputShapesDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  // Input and mask of shapes {4} and {2}, which cannot be broadcast together.
  Tensor self = tf.ones(/*sizes=*/{4});
  Tensor mask = tf_bool.ones(/*sizes=*/{2});

  // Destination for the masked_fill; matches the shape of the input.
  Tensor out = tf.zeros(/*sizes=*/{4});

  // Masked fill with non-broadcastable input and mask shapes should be
  // reported as a kernel failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_masked_fill_scalar_out(self, mask, /*value=*/0, out));
}
222
TEST_F(OpMaskedFillTest, BroadcastTest) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  // A {2} mask broadcast against a {2, 2} input: the same mask row applies to
  // each row of the input, so column 0 of every row is filled.
  Tensor self = tf.make({2, 2}, /*data=*/{1, 2, 4, 8});
  Tensor mask = tf_bool.make({2}, /*data=*/{true, false});

  // Destination for the masked_fill.
  Tensor out = tf.zeros({2, 2});

  // Masked fill half of the tensor.
  op_masked_fill_scalar_out(self, mask, /*value=*/3, out);

  // Integer results are exact, so compare with EXPECT_TENSOR_EQ (as the
  // integral helper does) rather than a tolerance-based check.
  EXPECT_TENSOR_EQ(out, tf.make({2, 2}, /*data=*/{3, 2, 3, 8}));
}
244
TEST_F(OpMaskedFillTest, MismatchedOutputShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched output shape";
  }
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  const std::vector<int32_t> sizes = {2, 2};

  // Input and mask share the same {2, 2} shape; only the output differs.
  Tensor self = tf.ones(sizes);
  Tensor mask = tf_bool.ones(sizes);

  // Destination with a different shape ({4} instead of {2, 2}).
  Tensor out = tf.zeros(/*sizes=*/{4});

  // Mask filling the tensor into a mismatched output should be reported as a
  // kernel failure.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_masked_fill_scalar_out(self, mask, /*value=*/0, out));
}
266
TEST_F(OpMaskedFillTest, BroadcastDimSizeIsOneAB) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> bool_tf;

  Tensor x = tf.make(
      {3, 2},
      {0.9701170325279236,
       0.4185227155685425,
       0.39851099252700806,
       0.8725584745407104,
       0.714692234992981,
       0.3167606592178345});
  // {1, 2} mask whose leading size-1 dim broadcasts over the 3 rows of x.
  // It is all false, so the output should equal the input unchanged.
  Tensor y = bool_tf.make({1, 2}, {false, false});
  Tensor expected_result = tf.make(
      {3, 2},
      {0.9701170325279236,
       0.4185227155685425,
       0.39851099252700806,
       0.8725584745407104,
       0.714692234992981,
       0.3167606592178345});

  Tensor out = tf.zeros({3, 2});
  op_masked_fill_scalar_out(x, y, Scalar(3.0), out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
293
TEST_F(OpMaskedFillTest, BroadcastDimSizeMissingAB) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> bool_tf;

  Tensor x = tf.make(
      {3, 2},
      {0.9701170325279236,
       0.4185227155685425,
       0.39851099252700806,
       0.8725584745407104,
       0.714692234992981,
       0.3167606592178345});
  // {2} mask with the leading dimension missing entirely; it broadcasts over
  // the 3 rows of x. It is all false, so the output should equal the input.
  Tensor y = bool_tf.make({2}, {false, false});
  Tensor expected_result = tf.make(
      {3, 2},
      {0.9701170325279236,
       0.4185227155685425,
       0.39851099252700806,
       0.8725584745407104,
       0.714692234992981,
       0.3167606592178345});

  Tensor out = tf.zeros({3, 2});
  op_masked_fill_scalar_out(x, y, Scalar(3.0), out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
320
TEST_F(OpMaskedFillTest, DynamicShapeUpperBoundSameAsExpected) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> bool_tf;

  Tensor x = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       0.49675875902175903});
  // Only the last element is masked, so only it becomes 3.0.
  Tensor y = bool_tf.make({3, 2}, {false, false, false, false, false, true});
  Tensor expected_result = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       3.0});

  // DYNAMIC_BOUND output whose upper bound equals the expected {3, 2} shape.
  Tensor out =
      tf.zeros({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_masked_fill_scalar_out(x, y, Scalar(3.0), out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
348
TEST_F(OpMaskedFillTest, DynamicShapeUpperBoundLargerThanExpected) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> bool_tf;

  Tensor x = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       0.49675875902175903});
  // Only the last element is masked, so only it becomes 3.0.
  Tensor y = bool_tf.make({3, 2}, {false, false, false, false, false, true});
  Tensor expected_result = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       3.0});

  // DYNAMIC_BOUND output allocated with a larger {6, 4} upper bound. Comparing
  // against the {3, 2} expected tensor checks that out was resized down.
  Tensor out =
      tf.zeros({6, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_masked_fill_scalar_out(x, y, Scalar(3.0), out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
376
TEST_F(OpMaskedFillTest, DynamicShapeUnbound) {
  // The skip is unconditional, so the body below is currently unreachable; it
  // is kept for when DYNAMIC_UNBOUND outputs are supported.
  GTEST_SKIP() << "Dynamic shape unbound not supported";
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> bool_tf;

  Tensor x = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       0.49675875902175903});
  Tensor y = bool_tf.make({3, 2}, {false, false, false, false, false, true});
  Tensor expected_result = tf.make(
      {3, 2},
      {0.974706768989563,
       0.46383917331695557,
       0.050839245319366455,
       0.26296138763427734,
       0.8404526114463806,
       3.0});

  // DYNAMIC_UNBOUND output starting at {1, 1}; it would need to grow to {3, 2}.
  Tensor out =
      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  op_masked_fill_scalar_out(x, y, Scalar(3.0), out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
405
TEST_F(OpMaskedFillTest, BroadcastDimSizeIsOneBA) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  auto x = tf.make(
      {3, 2},
      {0.38566190004348755,
       0.47776442766189575,
       0.1954779028892517,
       0.6691004633903503,
       0.6580829620361328,
       0.48968571424484253});
  // The mask must carry an explicit size-1 dimension ({1, 2}) for this test
  // to exercise the "dim size is one" broadcast case its name describes; a
  // {2} mask would duplicate BroadcastDimSizeMissingBA. It is all false, so
  // the output should equal the input.
  auto y = tf_bool.make({1, 2}, {false, false});
  auto z = Scalar(3.0);
  Tensor expected_result = tf.make(
      {3, 2},
      {0.38566190004348755,
       0.47776442766189575,
       0.1954779028892517,
       0.6691004633903503,
       0.6580829620361328,
       0.48968571424484253});

  Tensor out = tf.zeros({3, 2});
  op_masked_fill_scalar_out(x, y, z, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
433
TEST_F(OpMaskedFillTest, BroadcastDimSizeMissingBA) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Bool> tf_bool;

  auto x = tf.make(
      {3, 2},
      {0.38566190004348755,
       0.47776442766189575,
       0.1954779028892517,
       0.6691004633903503,
       0.6580829620361328,
       0.48968571424484253});
  // {2} mask with the leading dimension missing; it broadcasts over the 3 rows
  // of x. It is all false, so the output should equal the input.
  auto y = tf_bool.make({2}, {false, false});
  auto z = Scalar(3.0);
  Tensor expected_result = tf.make(
      {3, 2},
      {0.38566190004348755,
       0.47776442766189575,
       0.1954779028892517,
       0.6691004633903503,
       0.6580829620361328,
       0.48968571424484253});

  Tensor out = tf.zeros({3, 2});
  op_masked_fill_scalar_out(x, y, z, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
461