/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
#include <cstdint>
#include <cstring>

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
14
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using TensorList = exec_aten::TensorList;
21
/**
 * Splits the tensor into chunks of size `split_size` along the specified
 * dimension.
 *
 * The last chunk will be smaller if the tensor size along the given dimension
 * dim is not evenly divisible by `split_size`.
 *
 * split_copy.Tensor_out(Tensor input, int split_size, int dim=0, *,
 * Tensor(a!)[] out) -> ()
 */
split_copy_Tensor_out(KernelRuntimeContext & ctx,const Tensor & input,int64_t split_size,int64_t dim,TensorList out)32 void split_copy_Tensor_out(
33 KernelRuntimeContext& ctx,
34 const Tensor& input,
35 int64_t split_size,
36 int64_t dim,
37 TensorList out) {
38 (void)ctx;
39 // Support python-style negative indexing.
40 if (dim < 0) {
41 dim += input.dim();
42 }
43
44 ET_KERNEL_CHECK(
45 ctx,
46 check_split_copy_args(input, split_size, dim, out),
47 InvalidArgument, );
48
49 for (size_t i = 0; i < out.size(); ++i) {
50 ET_KERNEL_CHECK(
51 ctx, tensors_have_same_dim_order(input, out[i]), InvalidArgument, );
52 }
53
54 const size_t leading_dims = getLeadingDims(input, dim);
55 const size_t trailing_dims = getTrailingDims(input, dim);
56 const size_t step = input.size(dim) * trailing_dims;
57
58 ScalarType in_type = input.scalar_type();
59 ScalarType out_type = out[0].scalar_type();
60
61 ET_SWITCH_REAL_TYPES_AND(
62 Bool, in_type, ctx, "split_copy.Tensor_out", CTYPE_IN, [&]() {
63 ET_SWITCH_REAL_TYPES_AND(
64 Bool, out_type, ctx, "split_copy.Tensor_out", CTYPE_OUT, [&]() {
65 const CTYPE_IN* input_data = input.const_data_ptr<CTYPE_IN>();
66 for (size_t i = 0, e = out.size(); i < e; ++i) {
67 size_t out_step = out[i].size(dim) * trailing_dims;
68 if (out_step == 0) {
69 continue;
70 }
71 const CTYPE_IN* src = input_data;
72 CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
73 for (size_t j = 0; j < leading_dims; ++j) {
74 for (size_t k = 0; k < out_step; ++k) {
75 dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
76 }
77 src += step;
78 dest += out_step;
79 }
80 input_data += out_step;
81 }
82 });
83 });
84 }

} // namespace native
} // namespace executor
} // namespace torch