1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17
18 using Tensor = exec_aten::Tensor;
19
cat_out(KernelRuntimeContext & ctx,exec_aten::ArrayRef<Tensor> tensors,int64_t dim,Tensor & out)20 Tensor& cat_out(
21 KernelRuntimeContext& ctx,
22 exec_aten::ArrayRef<Tensor> tensors,
23 int64_t dim,
24 Tensor& out) {
25 if (dim < 0) {
26 dim += out.dim();
27 }
28
29 ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), InvalidArgument, out);
30
31 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
32 size_t expected_out_dim = 0;
33 get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
34
35 ET_KERNEL_CHECK(
36 ctx,
37 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
38 InvalidArgument,
39 out);
40
41 // Special handling when all inputs are 1D-empty tensors for aten consistency
42 // In that case, just return an 1D-empty tensor without checking dim
43 bool all_1d_empty = true;
44 for (size_t i = 0; i < tensors.size(); ++i) {
45 if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
46 all_1d_empty = false;
47 break;
48 }
49 }
50 if (all_1d_empty) {
51 return out;
52 }
53
54 const size_t outer = getLeadingDims(out, dim);
55 const size_t dim_stride = getTrailingDims(out, dim);
56 const size_t ninputs = tensors.size();
57
58 const auto out_type = out.scalar_type();
59 ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
60 CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
61 for (size_t i = 0; i < outer; ++i) {
62 for (size_t j = 0; j < ninputs; ++j) {
63 const auto in_type = tensors[j].scalar_type();
64 ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
65 if (tensors[j].numel() == 0) {
66 return;
67 }
68 size_t inner = tensors[j].size(dim) * dim_stride;
69 const CTYPE_IN* const in_ptr =
70 tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
71
72 for (size_t k = 0; k < inner; ++k) {
73 out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
74 }
75 out_ptr += inner;
76 });
77 }
78 }
79 });
80
81 return out;
82 }
83
84 } // namespace native
85 } // namespace executor
86 } // namespace torch
87