1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18 #include <sys/types.h>
19
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25
26 class OpIndexSelectOutTest : public OperatorTest {
27 protected:
op_index_select_out(const Tensor & self,int64_t dim,const Tensor & index,Tensor & out)28 Tensor& op_index_select_out(
29 const Tensor& self,
30 int64_t dim,
31 const Tensor& index,
32 Tensor& out) {
33 return torch::executor::aten::index_select_outf(
34 context_, self, dim, index, out);
35 }
36
37 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()38 void test_dtype() {
39 TensorFactory<DTYPE> tf;
40 TensorFactory<ScalarType::Long> tfl;
41
42 // test index_select on dimension 0.
43
44 // clang-format off
45 Tensor x = tf.make(
46 {3, 2, 4},
47 {
48 // all ones below are from x,
49 // and all zeros are from y.
50 // [0, :, :]
51 1, 1, 1, 1, // [0, 0, :]
52 0, 0, 0, 0, // [0, 1, :]
53
54 // [1, :, :]
55 1, 1, 1, 1, // [1, 0, :]
56 0, 0, 0, 0, // [1, 1, :]
57
58 // [2, :, :]
59 1, 1, 1, 1, // [2, 0, :]
60 0, 0, 0, 0, // [2, 1, :]
61 });
62 // clang-format on
63
64 // Expected values for out_0 and ret_0 after the test are all ones(3, 4)
65 // based on the above rules. So here we set the default value of out_0 as
66 // zeros(3, 4) on purpose, to eliminate the influence to the final result
67 // from initial value. Same for out_1 and ret_1.
68
69 Tensor out_0 = tf.zeros({3, 1, 4});
70 Tensor out_1 = tf.ones({3, 1, 4});
71 Tensor index_0 = tfl.make({1}, {0});
72 Tensor index_1 = tfl.make({1}, {1});
73 Tensor ret_0 = op_index_select_out(x, /*dim=*/1, /*index=*/index_0, out_0);
74 Tensor ret_1 = op_index_select_out(x, /*dim=*/1, /*index=*/index_1, out_1);
75
76 EXPECT_TENSOR_EQ(ret_0, out_0);
77 EXPECT_TENSOR_EQ(ret_1, out_1);
78
79 EXPECT_TENSOR_EQ(ret_0, tf.ones({3, 1, 4}));
80 EXPECT_TENSOR_EQ(ret_1, tf.zeros({3, 1, 4}));
81 }
82
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)83 void test_dynamic_shape(
84 const std::vector<int32_t>& out_shape,
85 enum torch::executor::TensorShapeDynamism dynamism) {
86 /* %python
87 %rewrite(index_select_template) */
88
89 TensorFactory<ScalarType::Float> tf;
90 TensorFactory<ScalarType::Long> tf_index;
91
92 Tensor input = tf.make(
93 {2, 3, 4},
94 {0.49625658988952637, 0.7682217955589294, 0.08847743272781372,
95 0.13203048706054688, 0.30742281675338745, 0.6340786814689636,
96 0.4900934100151062, 0.8964447379112244, 0.455627977848053,
97 0.6323062777519226, 0.3488934636116028, 0.40171730518341064,
98 0.022325754165649414, 0.16885894536972046, 0.2938884496688843,
99 0.518521785736084, 0.6976675987243652, 0.800011396408081,
100 0.16102945804595947, 0.28226858377456665, 0.6816085577011108,
101 0.9151939749717712, 0.39709991216659546, 0.8741558790206909});
102 Tensor index = tf_index.make({2}, {0, 2});
103 Tensor expected = tf.make(
104 {2, 3, 2},
105 {0.49625658988952637,
106 0.08847743272781372,
107 0.30742281675338745,
108 0.4900934100151062,
109 0.455627977848053,
110 0.3488934636116028,
111 0.022325754165649414,
112 0.2938884496688843,
113 0.6976675987243652,
114 0.16102945804595947,
115 0.6816085577011108,
116 0.39709991216659546});
117 Tensor out = tf.zeros(out_shape, dynamism);
118
119 op_index_select_out(input, 2, index, out);
120 EXPECT_TENSOR_CLOSE(out, expected);
121 }
122
123 // Run the test by selecting Tensor x on given dim and all available indexes
124 // on that dimension
run_test_cases(const Tensor & x,ssize_t dim,const Tensor & index,const Tensor & expected)125 void run_test_cases(
126 const Tensor& x,
127 ssize_t dim,
128 const Tensor& index,
129 const Tensor& expected) {
130 // Generated out tensor sharing same size and dtype with expected tensor
131 TensorFactory<ScalarType::Double> tf;
132
133 const std::vector<int32_t> out_size(
134 expected.sizes().begin(), expected.sizes().end());
135 Tensor out = tf.ones(out_size);
136
137 Tensor ret = op_index_select_out(x, dim, index, out);
138 EXPECT_TENSOR_EQ(out, ret);
139 EXPECT_TENSOR_EQ(ret, expected);
140 }
141 };
142
/// Selects index 0 along the front dimension (dim 0) of a {2, 3, 4} Double
/// input and expects the first {3, 4} slice back with shape {1, 3, 4}.
TEST_F(OpIndexSelectOutTest, SelectFrontDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;
  TensorFactory<ScalarType::Long> tfl;
  // clang-format off
  Tensor x = tf.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Select from the input along the 0th dimension. No output tensor is built
  // here: run_test_cases() allocates its own from the shape of `expected`.
  const std::vector<int32_t> out_size = {1, 3, 4};

  Tensor index = tfl.make({1}, {0});
  // clang-format off
  Tensor expected = tf.make(
      out_size,
      {
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]
      }
  );
  // clang-format on

  run_test_cases(x, /*dim=*/0, /*index=*/index, expected);
}
180
/// Selects indices {0, 2} along the middle dimension (dim 1) of a {2, 3, 4}
/// Double input and expects rows 0 and 2 of each outer slice, shaped
/// {2, 2, 4}.
TEST_F(OpIndexSelectOutTest, SelectMiddleDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;
  TensorFactory<ScalarType::Long> tfl;
  // clang-format off
  Tensor x = tf.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Select from the input along the 1st dimension. No output tensor is built
  // here: run_test_cases() allocates its own from the shape of `expected`.
  const std::vector<int32_t> out_size = {2, 2, 4};

  Tensor index = tfl.make({2}, {0, 2});
  // clang-format off
  Tensor expected = tf.make(
      out_size,
      {
          1.,   2.,   3.,   4., // [0, 0, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -9., -10., -11., -12., // [1, 2, :]
      }
  );
  // clang-format on

  run_test_cases(x, /*dim=*/1, /*index=*/index, expected);
}
220
/// Selects indices {0, 2} along the last dimension (dim 2) of a {2, 3, 4}
/// Double input and expects columns 0 and 2 of every row, shaped {2, 3, 2}.
TEST_F(OpIndexSelectOutTest, SelectEndDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;
  TensorFactory<ScalarType::Long> tfl;
  // clang-format off
  Tensor x = tf.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Select from the input along the 2nd (last) dimension. No output tensor is
  // built here: run_test_cases() allocates its own from the shape of
  // `expected`.
  const std::vector<int32_t> out_size = {2, 3, 2};

  Tensor index = tfl.make({2}, {0, 2});
  // clang-format off
  Tensor expected = tf.make(
      out_size,
      {
          // [0, :, :]
          1.,  3.,
          5.,  7.,
          9., 11.,

          // [1, :, :]
          -1.,  -3.,
          -5.,  -7.,
          -9., -11.,
      }
  );
  // clang-format on
  run_test_cases(x, /*dim=*/2, /*index=*/index, expected);
}
263
264 /// A generic smoke test that works for any dtype that supports ones() and
265 /// zeros().
TEST_F(OpIndexSelectOutTest,AllDtypesSupported)266 TEST_F(OpIndexSelectOutTest, AllDtypesSupported) {
267 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
268 ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
269 #undef TEST_ENTRY
270 // TODO: Also add tests for half, complex, quantized, and other types. Easiest
271 // way to do that would be to make TensorFactory support zeros() and ones()
272 // for those types.
273 }
274
275 //////////////////////////////////////////////////////////////////////////////
// The following tests focus on empty-size tensors and empty tensors.
// We first define the terms:
// empty-size tensor: size is [] but it does have data (e.g. tensor(5))
// empty tensor: size is not [] and the size of at least one
// dim is zero, so it has no data in it (e.g. ones(1, 0, 2, 3))
281
282 // In this test we are gonnna find if our select function support non-empty
283 // tensor input and empty-size tensor output.
TEST_F(OpIndexSelectOutTest,NonEmptyInputEmptyOutputWithMismatchDimDies)284 TEST_F(OpIndexSelectOutTest, NonEmptyInputEmptyOutputWithMismatchDimDies) {
285 if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
286 GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
287 }
288 TensorFactory<ScalarType::Int> tf;
289 TensorFactory<ScalarType::Long> tfl;
290 Tensor x = tf.make({10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
291 Tensor index = tfl.make({1}, {5});
292
293 // Make an empty-size out tensor and demonstrate that it has data.
294 Tensor out = tf.make({}, {0});
295 EXPECT_EQ(out.numel(), 1);
296
297 // pass the empty-size tensor to the function,
298 Tensor expect = tf.make({}, {5});
299 ET_EXPECT_KERNEL_FAILURE(
300 context_, op_index_select_out(x, /*dim=*/0, /*index=*/index, out));
301 }
302
303 // This test focuses on the support for empty tensor (dim() > 0) input and empty
304 // tensor output
TEST_F(OpIndexSelectOutTest,EmptyInputEmptyOutputWithMatchingDimSupported)305 TEST_F(OpIndexSelectOutTest, EmptyInputEmptyOutputWithMatchingDimSupported) {
306 TensorFactory<ScalarType::Int> tf;
307 TensorFactory<ScalarType::Long> tfl;
308
309 Tensor index = tfl.make({1}, {3});
310
311 // Using empty tensors as input.
312 Tensor x = tf.make({3, 0, 10, 3}, {});
313 EXPECT_EQ(x.numel(), 0);
314
315 // Output whose shape is appropriate for selecting along dim(2)
316 Tensor out = tf.make({3, 0, 1, 3}, {});
317 EXPECT_EQ(out.numel(), 0);
318
319 Tensor ret = op_index_select_out(x, /*dim=*/2, /*index=*/index, out);
320 EXPECT_EQ(ret.numel(), 0);
321 // Success if it doesn't assert on the weird-shaped empty input and the
322 // ret is still a empty array
323 }
324
325 ///////////////////////////////////////////////////////////////////////
326
/// Checks that the kernel fails for every `dim` value in a list of values
/// lying beyond the rank of a 3-D input, on both the positive and the negative
/// side.
TEST_F(OpIndexSelectOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor x = tf.ones({1, 1, 1});
  Tensor index = tfl.make({1}, {0});
  Tensor out = tf.zeros({1, 1, 1});

  // x has three dimensions; each value below is expected to be rejected.
  const int32_t bad_dims[] = {3, 4, 5, -4, -5, -6};
  for (const ssize_t bad_dim : bad_dims) {
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_index_select_out(x, bad_dim, /*index=*/index, out));
  }
}
342
/// Checks that the kernel fails when the out tensor has a compatible shape but
/// a different dtype (Float) from the input (Int).
TEST_F(OpIndexSelectOutTest, MismatchedDtypesDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor x = tf_int.zeros({1, 2, 2});
  Tensor index = tf_long.make({1}, {0});

  // Same sizes as the selection result would have, but built by the Float
  // factory while the input is Int.
  Tensor out = tf_float.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/0, /*index=*/index, out));
}
357
/// Checks that the kernel fails when out has the right dtype and element count
/// but is missing the trailing dimension of the input. Skipped for the ATen
/// kernel.
TEST_F(OpIndexSelectOutTest, OutMatchNumelLackDimAtEndDies) {
  const bool is_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor index = tfl.make({1}, {0});
  Tensor x = tf.zeros({1, 2, 2, 1});

  // out shares the dtype and numel of the expected output, but its sizes are
  // mismatched: it is 3-D whereas x is 4-D (out.dim() should always equal
  // x.dim()).
  Tensor out = tf.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/0, /*index=*/index, out));
}
375
/// Checks that the kernel fails when out carries an extra leading dimension
/// compared with the 2-D input. Skipped for the ATen kernel.
TEST_F(OpIndexSelectOutTest, OutMatchNumelExtraDimAtFrontDies) {
  const bool is_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor index = tfl.make({1}, {0});
  Tensor x = tf.zeros({2, 2});

  // out has the same dtype as the expected output, but its sizes are
  // mismatched: {1, 1, 2} is 3-D while x is 2-D.
  Tensor out = tf.ones({1, 1, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/0, /*index=*/index, out));
}
393
/// Checks that the kernel fails when out has the wrong sizes for a selection
/// along dim 2 of a 4-D input. Skipped for the ATen kernel.
TEST_F(OpIndexSelectOutTest, OutSizeMismatchDimDies) {
  const bool is_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor index = tfl.make({1}, {3});
  Tensor x = tf.zeros({2, 4, 7, 5});

  // A matching out for index_select() along dim 2 would be {2, 4, 1, 5};
  // {2, 4, 7} is deliberately wrong.
  Tensor out = tf.zeros({2, 4, 7});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/2, /*index=*/index, out));
}
411
/// Checks that the kernel fails when the index tensor is Float; every other
/// test in this file builds its index with a Long factory.
TEST_F(OpIndexSelectOutTest, IndexWithInvalidDtypeDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Float> tff;

  Tensor x = tf.zeros({2, 4, 7, 5});
  Tensor out = tf.zeros({2, 1, 7, 5});

  // The shape and value are fine for dim 1; only the dtype is the problem.
  Tensor index = tff.make({1}, {3});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/1, /*index=*/index, out));
}
424
/// Checks that the kernel fails when the index tensor is 2-D rather than 1-D.
TEST_F(OpIndexSelectOutTest, IndexWithInvalidDimDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor x = tf.zeros({2, 4, 7, 5});
  Tensor out = tf.zeros({2, 1, 7, 5});

  // A {1, 1} index is a 2-D tensor, which is expected to error out.
  Tensor index = tfl.make({1, 1}, {3});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_index_select_out(x, /*dim=*/1, /*index=*/index, out));
}
438
439 #if !defined(USE_ATEN_LIB)
/// Selects index 0 along dim 0 into an out tensor created with a larger
/// {2, 3, 4} shape and DYNAMIC_BOUND dynamism, then expects the result to
/// compare equal to a {1, 3, 4} tensor.
TEST_F(OpIndexSelectOutTest, UpperBoundOutTensor) {
  TensorFactory<ScalarType::Double> tf;
  TensorFactory<ScalarType::Long> tfl;
  // clang-format off
  Tensor x = tf.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  Tensor index = tfl.make({1}, {0});

  // Shape the result should have after selecting one index along dim 0.
  const std::vector<int32_t> out_size = {1, 3, 4};
  // clang-format off
  Tensor expected = tf.make(
      out_size,
      {
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]
      }
  );
  // clang-format on

  // out starts with the full input shape, i.e. larger than the expected
  // result along dim 0.
  Tensor out =
      tf.zeros({2, 3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  Tensor ret = op_index_select_out(x, 0, index, out);
  EXPECT_TENSOR_EQ(out, ret);
  EXPECT_TENSOR_EQ(ret, expected);
}
480 #endif
481
482 /* %python
483 import torch
484 torch.manual_seed(0)
485 input = torch.rand(2, 3, 4)
486 index = torch.tensor([0, 2])
487 dim = 2
488 expected = torch.index_select(input, dim, index)
489
490 index_select_template = f"""
491 {declare_tensor_factory("ScalarType::Float", "tf")}
492 {declare_tensor_factory("ScalarType::Long", "tf_index")}
493
494 {declare_tensor_make_t("input", "tf")}
495 {declare_tensor_make_t("index", "tf_index")}
496 {declare_tensor_make_t("expected", "tf")}
497 {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
498
499 op_index_select_out(input, $dim$, index, out);
500 EXPECT_TENSOR_CLOSE(out, expected);""" */
501
/// Runs the shared dynamic-shape case with a DYNAMIC_BOUND output whose
/// initial shape already equals the expected {2, 3, 2} result.
TEST_F(OpIndexSelectOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> initial_shape = {2, 3, 2};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
506
/// Runs the shared dynamic-shape case with a DYNAMIC_BOUND output whose
/// initial {10, 10, 10} shape is larger than the expected {2, 3, 2} result.
/// Skipped when output resizing is not supported.
TEST_F(OpIndexSelectOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> initial_shape = {10, 10, 10};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
514
/// Runs the shared dynamic-shape case with a DYNAMIC_UNBOUND output that
/// starts at {1, 1, 1}, smaller than the expected {2, 3, 2} result. Skipped
/// when output resizing is not supported.
TEST_F(OpIndexSelectOutTest, DynamicShapeUnbound) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
522