• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
11 #include <executorch/kernels/portable/cpu/util/repeat_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 #include <sys/types.h>
14 
15 #include <cstring>
16 
17 namespace torch {
18 namespace executor {
19 namespace native {
20 
21 using Tensor = exec_aten::Tensor;
22 using ScalarType = exec_aten::ScalarType;
23 using Scalar = exec_aten::Scalar;
24 using SizesType = exec_aten::SizesType;
25 
// Maximum tensor rank supported by the fixed-size scratch buffers
// (output_sizes / repeats) allocated on the stack in expand_copy_out.
// Note: `constexpr` already implies `const` for variables, so the extra
// qualifier was redundant and has been dropped.
constexpr size_t kTensorDimensionLimit{16};
27 
28 namespace {
29 
map_expand_to_repeats(exec_aten::ArrayRef<SizesType> self_sizes,exec_aten::ArrayRef<int64_t> expand_sizes,int64_t * repeats,const size_t repeats_size)30 size_t map_expand_to_repeats(
31     exec_aten::ArrayRef<SizesType> self_sizes,
32     exec_aten::ArrayRef<int64_t> expand_sizes,
33     int64_t* repeats,
34     const size_t repeats_size) {
35   auto j{expand_sizes.size()};
36   for (size_t i{self_sizes.size()}; i > 0 && j > 0;) {
37     --i;
38     --j;
39 
40     // Default, just copy the expand size to repeat
41     repeats[j] = expand_sizes[j];
42     if (expand_sizes[j] == -1 || expand_sizes[j] == self_sizes[i]) {
43       repeats[j] = 1;
44     }
45   }
46 
47   while (j > 0) {
48     --j;
49     repeats[j] = expand_sizes[j];
50   }
51 
52   return expand_sizes.size();
53 }
54 } // namespace
55 
expand_copy_out(KernelRuntimeContext & ctx,const Tensor & self,ArrayRef<int64_t> expand_sizes,bool implicit,Tensor & out)56 Tensor& expand_copy_out(
57     KernelRuntimeContext& ctx,
58     const Tensor& self,
59     ArrayRef<int64_t> expand_sizes,
60     bool implicit,
61     Tensor& out) {
62   (void)ctx;
63 
64   ET_KERNEL_CHECK(
65       ctx,
66       check_expand_copy_args(self, expand_sizes, implicit, out),
67       InvalidArgument,
68       out);
69 
70   const auto& self_sizes = self.sizes();
71 
72   // Holds the result of converting -1 to the original dim sizes
73   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
74   size_t output_rank = 0;
75   ET_KERNEL_CHECK(
76       ctx,
77       get_expand_copy_out_target_size(
78           self_sizes, expand_sizes, output_sizes, &output_rank),
79       InvalidArgument,
80       out);
81 
82   ET_KERNEL_CHECK(
83       ctx,
84       resize_tensor(out, {output_sizes, output_rank}) == Error::Ok,
85       InvalidArgument,
86       out);
87 
88   ET_KERNEL_CHECK(
89       ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
90   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
91 
92   // Holds the result of expand_sizes converted to repeat sizes
93   int64_t repeats[kTensorDimensionLimit];
94   const auto repeats_size{map_expand_to_repeats(
95       self_sizes, expand_sizes, repeats, kTensorDimensionLimit)};
96 
97   ET_KERNEL_CHECK(
98       ctx,
99       repeat_tensor(self, {repeats, repeats_size}, out) == Error::Ok,
100       InvalidArgument,
101       out);
102 
103   return out;
104 }
105 
106 } // namespace native
107 } // namespace executor
108 } // namespace torch
109