• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstring>
10 
11 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 
14 namespace torch {
15 namespace executor {
16 namespace native {
17 
18 using Tensor = exec_aten::Tensor;
19 
cat_out(KernelRuntimeContext & ctx,exec_aten::ArrayRef<Tensor> tensors,int64_t dim,Tensor & out)20 Tensor& cat_out(
21     KernelRuntimeContext& ctx,
22     exec_aten::ArrayRef<Tensor> tensors,
23     int64_t dim,
24     Tensor& out) {
25   if (dim < 0) {
26     dim += out.dim();
27   }
28 
29   ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), InvalidArgument, out);
30 
31   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
32   size_t expected_out_dim = 0;
33   get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
34 
35   ET_KERNEL_CHECK(
36       ctx,
37       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
38       InvalidArgument,
39       out);
40 
41   // Special handling when all inputs are 1D-empty tensors for aten consistency
42   // In that case, just return an 1D-empty tensor without checking dim
43   bool all_1d_empty = true;
44   for (size_t i = 0; i < tensors.size(); ++i) {
45     if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
46       all_1d_empty = false;
47       break;
48     }
49   }
50   if (all_1d_empty) {
51     return out;
52   }
53 
54   const size_t outer = getLeadingDims(out, dim);
55   const size_t dim_stride = getTrailingDims(out, dim);
56   const size_t ninputs = tensors.size();
57 
58   const auto out_type = out.scalar_type();
59   ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
60     CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
61     for (size_t i = 0; i < outer; ++i) {
62       for (size_t j = 0; j < ninputs; ++j) {
63         const auto in_type = tensors[j].scalar_type();
64         ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
65           if (tensors[j].numel() == 0) {
66             return;
67           }
68           size_t inner = tensors[j].size(dim) * dim_stride;
69           const CTYPE_IN* const in_ptr =
70               tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
71 
72           for (size_t k = 0; k < inner; ++k) {
73             out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
74           }
75           out_ptr += inner;
76         });
77       }
78     }
79   });
80 
81   return out;
82 }
83 
84 } // namespace native
85 } // namespace executor
86 } // namespace torch
87