• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdint>
10 #include <cstring>
11 
12 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 using TensorList = exec_aten::TensorList;
21 
22 /**
23  * Splits the tensor into chunks of size `split_size` along the specified
24  * dimension.
25  *
26  * The last chunk will be smaller if the tensor size along the given dimension
27  * dim is not evenly divisible by `split_size`.
28  *
29  * split_copy.Tensor_out(Tensor input, int split_size, int dim=0, *,
30  * Tensor(a!)[] out) -> ()
31  */
split_copy_Tensor_out(KernelRuntimeContext & ctx,const Tensor & input,int64_t split_size,int64_t dim,TensorList out)32 void split_copy_Tensor_out(
33     KernelRuntimeContext& ctx,
34     const Tensor& input,
35     int64_t split_size,
36     int64_t dim,
37     TensorList out) {
38   (void)ctx;
39   // Support python-style negative indexing.
40   if (dim < 0) {
41     dim += input.dim();
42   }
43 
44   ET_KERNEL_CHECK(
45       ctx,
46       check_split_copy_args(input, split_size, dim, out),
47       InvalidArgument, );
48 
49   for (size_t i = 0; i < out.size(); ++i) {
50     ET_KERNEL_CHECK(
51         ctx, tensors_have_same_dim_order(input, out[i]), InvalidArgument, );
52   }
53 
54   const size_t leading_dims = getLeadingDims(input, dim);
55   const size_t trailing_dims = getTrailingDims(input, dim);
56   const size_t step = input.size(dim) * trailing_dims;
57 
58   ScalarType in_type = input.scalar_type();
59   ScalarType out_type = out[0].scalar_type();
60 
61   ET_SWITCH_REAL_TYPES_AND(
62       Bool, in_type, ctx, "split_copy.Tensor_out", CTYPE_IN, [&]() {
63         ET_SWITCH_REAL_TYPES_AND(
64             Bool, out_type, ctx, "split_copy.Tensor_out", CTYPE_OUT, [&]() {
65               const CTYPE_IN* input_data = input.const_data_ptr<CTYPE_IN>();
66               for (size_t i = 0, e = out.size(); i < e; ++i) {
67                 size_t out_step = out[i].size(dim) * trailing_dims;
68                 if (out_step == 0) {
69                   continue;
70                 }
71                 const CTYPE_IN* src = input_data;
72                 CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
73                 for (size_t j = 0; j < leading_dims; ++j) {
74                   for (size_t k = 0; k < out_step; ++k) {
75                     dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
76                   }
77                   src += step;
78                   dest += out_step;
79                 }
80                 input_data += out_step;
81               }
82             });
83       });
84 }
85 
86 } // namespace native
87 } // namespace executor
88 } // namespace torch
89