1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstdint>
10 #include <cstring>
11
12 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using Tensor = exec_aten::Tensor;
20
21 // unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
22 // -> Tensor(a!)
unsqueeze_copy_out(KernelRuntimeContext & ctx,const Tensor & self,int64_t dim,Tensor & out)23 Tensor& unsqueeze_copy_out(
24 KernelRuntimeContext& ctx,
25 const Tensor& self,
26 int64_t dim,
27 Tensor& out) {
28 (void)ctx;
29 Tensor::SizesType expected_output_size[kTensorDimensionLimit];
30 // I think this is safe to do but need to confirm.
31 // If we can do this then subsequent checks that specialize on dim < 0
32 // are not needed
33 if (dim < 0) {
34 dim += out.dim();
35 ET_KERNEL_CHECK(ctx, dim >= 0, InvalidArgument, out);
36 }
37
38 ET_KERNEL_CHECK(ctx, self.dim() + 1 == out.dim(), InvalidArgument, out);
39 ET_KERNEL_CHECK(ctx, dim <= self.dim(), InvalidArgument, out);
40
41 for (size_t i = 0; i < out.dim(); ++i) {
42 if (i < dim) {
43 expected_output_size[i] = self.size(i);
44 } else if (i > dim) {
45 expected_output_size[i] = self.size(i - 1);
46 } else {
47 expected_output_size[i] = 1;
48 }
49 }
50
51 ET_KERNEL_CHECK(
52 ctx,
53 resize_tensor(
54 out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
55 Error::Ok,
56 InvalidArgument,
57 out);
58
59 ET_KERNEL_CHECK(
60 ctx, check_unsqueeze_copy_args(self, dim, out), InvalidArgument, out);
61
62 if (self.nbytes() > 0) {
63 // Note that this check is important. It's valid for a tensor with numel 0
64 // to have a null data pointer, but in some environments it's invalid to
65 // pass a null pointer to memcpy() even when the size is zero.
66 memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
67 }
68 return out;
69 }
70
71 } // namespace native
72 } // namespace executor
73 } // namespace torch
74