• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12 
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 
index_select_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,Tensor & out)22 Tensor& index_select_out(
23     KernelRuntimeContext& ctx,
24     const Tensor& in,
25     int64_t dim,
26     const Tensor& index,
27     Tensor& out) {
28   ET_KERNEL_CHECK(
29       ctx, check_index_select_args(in, dim, index, out), InvalidArgument, out);
30 
31   ET_KERNEL_CHECK(
32       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
33 
34   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
35 
36   if (dim < 0) {
37     dim += nonzero_dim(in);
38   }
39 
40   size_t expected_ndim = 0;
41   Tensor::SizesType expected_size[kTensorDimensionLimit];
42   get_index_select_out_target_size(
43       in, dim, index, expected_size, &expected_ndim);
44 
45   ET_KERNEL_CHECK(
46       ctx,
47       resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
48       InvalidArgument,
49       out);
50 
51   if (in.dim() == 0) {
52     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
53     return out;
54   }
55 
56   size_t leading_dims = getLeadingDims(in, dim);
57   size_t trailing_dims = getTrailingDims(in, dim);
58 
59   if (leading_dims == 0 || trailing_dims == 0) {
60     return out;
61   }
62 
63   size_t out_dim_length = out.size(dim);
64   size_t in_dim_length = in.size(dim);
65 
66   size_t length_per_step = trailing_dims * in.element_size();
67 
68   const char* input_data = in.const_data_ptr<char>();
69   char* out_data = out.mutable_data_ptr<char>();
70 
71   ScalarType ix_type = index.scalar_type();
72 
73   ET_SWITCH_TWO_TYPES(
74       Long, Int, ix_type, ctx, "index_select.out", CTYPE, [&]() {
75         const CTYPE* const index_arr = index.mutable_data_ptr<CTYPE>();
76         for (int i = 0; i < leading_dims; i++) {
77           const char* src = input_data + i * in_dim_length * length_per_step;
78           char* dest = out_data + i * out_dim_length * length_per_step;
79           for (auto j = 0; j < out_dim_length; j++) {
80             const char* copy_src = src + index_arr[j] * length_per_step;
81             memcpy(dest, copy_src, length_per_step);
82             dest += length_per_step;
83           }
84         }
85       });
86 
87   return out;
88 }
89 
90 } // namespace native
91 } // namespace executor
92 } // namespace torch
93