1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <cstring>
#include <limits>

#include <executorch/kernels/portable/cpu/util/repeat_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18 namespace {
19
calculate_output_size(const exec_aten::ArrayRef<exec_aten::SizesType> & self_sizes,const exec_aten::ArrayRef<int64_t> & repeats,Tensor::SizesType * out_sizes_ptr)20 bool calculate_output_size(
21 const exec_aten::ArrayRef<exec_aten::SizesType>& self_sizes,
22 const exec_aten::ArrayRef<int64_t>& repeats,
23 Tensor::SizesType* out_sizes_ptr) {
24 ET_LOG_AND_RETURN_IF_FALSE(repeats.size() < kTensorDimensionLimit);
25
26 ET_LOG_MSG_AND_RETURN_IF_FALSE(
27 repeats.size() >= self_sizes.size(),
28 "Repeats vector size is %zu must be >= self_sizes %zu.",
29 repeats.size(),
30 self_sizes.size());
31
32 int32_t i = 0;
33 for (; i < (repeats.size() - self_sizes.size()); ++i) {
34 out_sizes_ptr[i] = static_cast<exec_aten::SizesType>(repeats[i]);
35 }
36 int32_t j = 0;
37 for (; i < repeats.size(); ++i) {
38 out_sizes_ptr[i] =
39 static_cast<exec_aten::SizesType>(repeats[i]) * self_sizes[j];
40 j++;
41 }
42
43 return true;
44 }
45
46 } // namespace
47
48 using Tensor = exec_aten::Tensor;
49
50 // repeat.out(Tensor self, int[] repeats, *, Tensor(a!) out) -> Tensor(a!)
repeat_out(KernelRuntimeContext & ctx,const Tensor & self,exec_aten::ArrayRef<int64_t> repeats,Tensor & out)51 Tensor& repeat_out(
52 KernelRuntimeContext& ctx,
53 const Tensor& self,
54 exec_aten::ArrayRef<int64_t> repeats,
55 Tensor& out) {
56 (void)ctx;
57 Tensor::SizesType expected_output_size[kTensorDimensionLimit];
58
59 ET_KERNEL_CHECK(
60 ctx,
61 calculate_output_size(self.sizes(), repeats, expected_output_size),
62 InvalidArgument,
63 out);
64
65 ET_KERNEL_CHECK(
66 ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
67
68 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);
69
70 // Resize for dynamic shape
71 ET_KERNEL_CHECK_MSG(
72 ctx,
73 resize_tensor(out, {expected_output_size, repeats.size()}) == Error::Ok,
74 InvalidArgument,
75 out,
76 "Failed to resize output tensor.");
77
78 ET_KERNEL_CHECK(
79 ctx,
80 repeat_tensor(self, repeats, out) == Error::Ok,
81 InvalidArgument,
82 out);
83
84 return out;
85 }
86
87 } // namespace native
88 } // namespace executor
89 } // namespace torch
90