/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;

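// index.Tensor_out implements the out variant of aten::index.Tensor (advanced
// indexing): it gathers the elements of `in` selected by a list of optional
// index tensors, where a null entry means that dimension is not indexed.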
Tensor& index_Tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
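  // An index "block" is a contiguous run of non-null entries in `indices`;
  // e.g. (null, ix0, ix1, null, ix2) contains two blocks. The count drives
  // both the early exit below and the output-shape rule further down.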
  size_t block_count = count_index_blocks(indices);

  // If the indices list is empty or all indices are null, just copy the input
  // to the output and return early.
  if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
    ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
      const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
      CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
      memcpy(out_data, in_data, in.nbytes());
    });
    return out;
  }

  // The output shape depends on whether all the non-null indices are adjacent
  // or not.
  bool adjacent = (block_count == 1);
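  // Following ATen advanced-indexing semantics: with a single block, the
  // indexed dims are replaced in place by the broadcasted index shape; with
  // multiple blocks, the index dims move to the front of the output. E.g. for
  // in.sizes() = (2, 3, 4, 5) and 1-d indices ix of length 6:
  //   indices = (null, ix, ix)       -> out.sizes() = (2, 6, 5)
  //   indices = (ix, null, ix, null) -> out.sizes() = (6, 3, 5)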

  Tensor::SizesType expected_size[kTensorDimensionLimit];
  size_t expected_ndim = 0;

  ET_KERNEL_CHECK(
      ctx,
      get_index_out_target_size(
          in, indices, adjacent, expected_size, &expected_ndim),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  if (out.numel() == 0) {
    return out;
  }

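  // The maps below let get_in_ix() translate each flat output index back to a
  // flat input index: dim_map and ix_map relate output/input dims to the
  // index tensors, `start` is where the indexed dims sit in the output (after
  // the leading null indices in the adjacent case, at the front otherwise),
  // and `xdim` is the rank of the broadcasted index shape.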
  int32_t dim_map[kTensorDimensionLimit];
  int32_t ix_map[kTensorDimensionLimit];
  size_t start = 0;
  size_t xdim = 0;

  if (adjacent) {
    start = get_num_leading_null_indices(indices);
  }
  xdim = get_indices_broadcast_ndim(indices);
  compute_dim_map(in, indices, dim_map, adjacent);
  compute_index_map(in, indices, ix_map);

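  // Gather element by element: for each flat output index, get_in_ix()
  // computes the matching flat input index and flags failure (e.g. an
  // out-of-range index value) so the kernel can abort with InvalidArgument.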
  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
    const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    for (auto out_ix = 0; out_ix < out.numel(); out_ix++) {
      size_t in_ix = 0;
      bool success = true;
      std::tie(in_ix, success) =
          get_in_ix(in, indices, out, out_ix, start, xdim, dim_map, ix_map);
      // Empty retval argument: on failure the check returns void, as required
      // inside this void lambda.
      ET_KERNEL_CHECK(ctx, success, InvalidArgument, );
      out_data[out_ix] = in_data[in_ix];
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch