• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdint>
10 #include <cstring>
11 
12 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 
21 // unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
22 // -> Tensor(a!)
unsqueeze_copy_out(KernelRuntimeContext & ctx,const Tensor & self,int64_t dim,Tensor & out)23 Tensor& unsqueeze_copy_out(
24     KernelRuntimeContext& ctx,
25     const Tensor& self,
26     int64_t dim,
27     Tensor& out) {
28   (void)ctx;
29   Tensor::SizesType expected_output_size[kTensorDimensionLimit];
30   // I think this is safe to do but need to confirm.
31   // If we can do this then subsequent checks that specialize on dim < 0
32   // are not needed
33   if (dim < 0) {
34     dim += out.dim();
35     ET_KERNEL_CHECK(ctx, dim >= 0, InvalidArgument, out);
36   }
37 
38   ET_KERNEL_CHECK(ctx, self.dim() + 1 == out.dim(), InvalidArgument, out);
39   ET_KERNEL_CHECK(ctx, dim <= self.dim(), InvalidArgument, out);
40 
41   for (size_t i = 0; i < out.dim(); ++i) {
42     if (i < dim) {
43       expected_output_size[i] = self.size(i);
44     } else if (i > dim) {
45       expected_output_size[i] = self.size(i - 1);
46     } else {
47       expected_output_size[i] = 1;
48     }
49   }
50 
51   ET_KERNEL_CHECK(
52       ctx,
53       resize_tensor(
54           out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
55           Error::Ok,
56       InvalidArgument,
57       out);
58 
59   ET_KERNEL_CHECK(
60       ctx, check_unsqueeze_copy_args(self, dim, out), InvalidArgument, out);
61 
62   if (self.nbytes() > 0) {
63     // Note that this check is important. It's valid for a tensor with numel 0
64     // to have a null data pointer, but in some environments it's invalid to
65     // pass a null pointer to memcpy() even when the size is zero.
66     memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
67   }
68   return out;
69 }
70 
71 } // namespace native
72 } // namespace executor
73 } // namespace torch
74