/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;

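// Out-variant of aten::index.Tensor: performs PyTorch-style advanced
// indexing of `in` with a list of optional index tensors and writes the
// result into `out`. A null entry in `indices` behaves like a full slice
// (`:`) over the corresponding dimension.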
Tensor& index_Tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  size_t block_count = count_index_blocks(indices);

  // If the indices list is empty or all indices are null, just copy the input
  // to the output and return early.
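  // For example, indexing x of shape (2, 3) with two null indices (i.e.
  // x[:, :]) or with an empty indices list yields a plain copy of x.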
  if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
    ET_SWITCH_REALHBBF16_TYPES(
        in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
          const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
          CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
          memcpy(out_data, in_data, in.nbytes());
        });
    return out;
  }

  // The output shape depends on whether all the non-null indices are adjacent
  // or not.
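  // For example, with `in` of shape (5, 6, 7) and index tensors i and j of
  // shape (2,): in[:, i, j] has an adjacent index block, so the broadcast
  // index dims replace dims 1 and 2 in place, giving shape (5, 2); in[i, :, j]
  // splits the block, so the index dims move to the front, giving (2, 6).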
  bool adjacent = (block_count == 1);

  Tensor::SizesType expected_size[kTensorDimensionLimit];
  size_t expected_ndim = 0;

  ET_KERNEL_CHECK(
      ctx,
      get_index_out_target_size(
          in, indices, adjacent, expected_size, &expected_ndim),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  if (out.numel() == 0) {
    return out;
  }

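  // Precompute the lookup tables consumed by get_in_ix() in the gather loop
  // below: `start` counts the leading null indices (relevant only when the
  // index block is adjacent), `xdim` is the rank of the broadcasted index
  // tensors, and dim_map/ix_map relate output positions back to input
  // dimensions and index tensors.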
  int32_t dim_map[kTensorDimensionLimit];
  int32_t ix_map[kTensorDimensionLimit];
  size_t start = 0;
  size_t xdim = 0;

  if (adjacent) {
    start = get_num_leading_null_indices(indices);
  }
  xdim = get_indices_broadcast_ndim(indices);
  compute_dim_map(in, indices, dim_map, block_count == 1);
  compute_index_map(in, indices, ix_map);

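  // Gather loop: map each flat output index back to the flat input index it
  // reads from, then copy that element.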
  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
    const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    for (auto out_ix = 0; out_ix < out.numel(); out_ix++) {
      size_t in_ix = 0;
      bool success = true;
      std::tie(in_ix, success) =
          get_in_ix(in, indices, out, out_ix, start, xdim, dim_map, ix_map);
      ET_KERNEL_CHECK(ctx, success, InvalidArgument, );
      out_data[out_ix] = in_data[in_ix];
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch