• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <cstring>
#include <limits>

#include <executorch/kernels/portable/cpu/util/repeat_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 namespace {
19 
calculate_output_size(const exec_aten::ArrayRef<exec_aten::SizesType> & self_sizes,const exec_aten::ArrayRef<int64_t> & repeats,Tensor::SizesType * out_sizes_ptr)20 bool calculate_output_size(
21     const exec_aten::ArrayRef<exec_aten::SizesType>& self_sizes,
22     const exec_aten::ArrayRef<int64_t>& repeats,
23     Tensor::SizesType* out_sizes_ptr) {
24   ET_LOG_AND_RETURN_IF_FALSE(repeats.size() < kTensorDimensionLimit);
25 
26   ET_LOG_MSG_AND_RETURN_IF_FALSE(
27       repeats.size() >= self_sizes.size(),
28       "Repeats vector size is %zu must be >= self_sizes %zu.",
29       repeats.size(),
30       self_sizes.size());
31 
32   int32_t i = 0;
33   for (; i < (repeats.size() - self_sizes.size()); ++i) {
34     out_sizes_ptr[i] = static_cast<exec_aten::SizesType>(repeats[i]);
35   }
36   int32_t j = 0;
37   for (; i < repeats.size(); ++i) {
38     out_sizes_ptr[i] =
39         static_cast<exec_aten::SizesType>(repeats[i]) * self_sizes[j];
40     j++;
41   }
42 
43   return true;
44 }
45 
46 } // namespace
47 
48 using Tensor = exec_aten::Tensor;
49 
50 // repeat.out(Tensor self, int[] repeats, *, Tensor(a!) out) -> Tensor(a!)
repeat_out(KernelRuntimeContext & ctx,const Tensor & self,exec_aten::ArrayRef<int64_t> repeats,Tensor & out)51 Tensor& repeat_out(
52     KernelRuntimeContext& ctx,
53     const Tensor& self,
54     exec_aten::ArrayRef<int64_t> repeats,
55     Tensor& out) {
56   (void)ctx;
57   Tensor::SizesType expected_output_size[kTensorDimensionLimit];
58 
59   ET_KERNEL_CHECK(
60       ctx,
61       calculate_output_size(self.sizes(), repeats, expected_output_size),
62       InvalidArgument,
63       out);
64 
65   ET_KERNEL_CHECK(
66       ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
67 
68   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
69 
70   // Resize for dynamic shape
71   ET_KERNEL_CHECK_MSG(
72       ctx,
73       resize_tensor(out, {expected_output_size, repeats.size()}) == Error::Ok,
74       InvalidArgument,
75       out,
76       "Failed to resize output tensor.");
77 
78   ET_KERNEL_CHECK(
79       ctx,
80       repeat_tensor(self, repeats, out) == Error::Ok,
81       InvalidArgument,
82       out);
83 
84   return out;
85 }
86 
87 } // namespace native
88 } // namespace executor
89 } // namespace torch
90