• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <algorithm>
10 #include <cstdint>
11 #include <cstring>
12 
13 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 
squeeze_copy_dim_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,Tensor & out)22 Tensor& squeeze_copy_dim_out(
23     KernelRuntimeContext& ctx,
24     const Tensor& in,
25     int64_t dim,
26     Tensor& out) {
27   (void)ctx;
28 
29   ET_KERNEL_CHECK(
30       ctx, check_squeeze_copy_dim_args(in, dim, out), InvalidArgument, out);
31 
32   ET_KERNEL_CHECK(
33       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
34 
35   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
36 
37   if (dim < 0) {
38     dim += nonzero_dim(in);
39   }
40 
41   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
42   size_t expected_out_dim = 0;
43   get_squeeze_copy_dim_out_target_size(
44       in, dim, expected_out_size, &expected_out_dim);
45   ET_KERNEL_CHECK(
46       ctx,
47       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
48       InvalidArgument,
49       out);
50 
51   if (in.nbytes() > 0) {
52     // Note that this check is important. It's valid for a tensor with numel 0
53     // to have a null data pointer, but in some environments it's invalid to
54     // pass a null pointer to memcpy() even when the size is zero.
55     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
56   }
57   return out;
58 }
59 
squeeze_copy_dims_out(KernelRuntimeContext & ctx,const Tensor & in,exec_aten::ArrayRef<int64_t> dims,Tensor & out)60 Tensor& squeeze_copy_dims_out(
61     KernelRuntimeContext& ctx,
62     const Tensor& in,
63     exec_aten::ArrayRef<int64_t> dims,
64     Tensor& out) {
65   (void)ctx;
66 
67   ET_KERNEL_CHECK(
68       ctx, check_squeeze_copy_dims_args(in, dims, out), InvalidArgument, out);
69 
70   ET_KERNEL_CHECK(
71       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
72 
73   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
74 
75   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
76   size_t expected_out_dim = 0;
77   get_squeeze_copy_dims_out_target_size(
78       in, dims, expected_out_size, &expected_out_dim);
79   ET_KERNEL_CHECK(
80       ctx,
81       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
82       InvalidArgument,
83       out);
84 
85   if (in.nbytes() > 0) {
86     // Note that this check is important. It's valid for a tensor with numel 0
87     // to have a null data pointer, but in some environments it's invalid to
88     // pass a null pointer to memcpy() even when the size is zero.
89     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
90   }
91   return out;
92 }
93 
94 } // namespace native
95 } // namespace executor
96 } // namespace torch
97