1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21
index_select_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,Tensor & out)22 Tensor& index_select_out(
23 KernelRuntimeContext& ctx,
24 const Tensor& in,
25 int64_t dim,
26 const Tensor& index,
27 Tensor& out) {
28 ET_KERNEL_CHECK(
29 ctx, check_index_select_args(in, dim, index, out), InvalidArgument, out);
30
31 ET_KERNEL_CHECK(
32 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
33
34 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
35
36 if (dim < 0) {
37 dim += nonzero_dim(in);
38 }
39
40 size_t expected_ndim = 0;
41 Tensor::SizesType expected_size[kTensorDimensionLimit];
42 get_index_select_out_target_size(
43 in, dim, index, expected_size, &expected_ndim);
44
45 ET_KERNEL_CHECK(
46 ctx,
47 resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
48 InvalidArgument,
49 out);
50
51 if (in.dim() == 0) {
52 memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
53 return out;
54 }
55
56 size_t leading_dims = getLeadingDims(in, dim);
57 size_t trailing_dims = getTrailingDims(in, dim);
58
59 if (leading_dims == 0 || trailing_dims == 0) {
60 return out;
61 }
62
63 size_t out_dim_length = out.size(dim);
64 size_t in_dim_length = in.size(dim);
65
66 size_t length_per_step = trailing_dims * in.element_size();
67
68 const char* input_data = in.const_data_ptr<char>();
69 char* out_data = out.mutable_data_ptr<char>();
70
71 ScalarType ix_type = index.scalar_type();
72
73 ET_SWITCH_TWO_TYPES(
74 Long, Int, ix_type, ctx, "index_select.out", CTYPE, [&]() {
75 const CTYPE* const index_arr = index.mutable_data_ptr<CTYPE>();
76 for (int i = 0; i < leading_dims; i++) {
77 const char* src = input_data + i * in_dim_length * length_per_step;
78 char* dest = out_data + i * out_dim_length * length_per_step;
79 for (auto j = 0; j < out_dim_length; j++) {
80 const char* copy_src = src + index_arr[j] * length_per_step;
81 memcpy(dest, copy_src, length_per_step);
82 dest += length_per_step;
83 }
84 }
85 });
86
87 return out;
88 }
89
90 } // namespace native
91 } // namespace executor
92 } // namespace torch
93