1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <algorithm>
10 #include <cstdint>
11 #include <cstring>
12
13 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21
squeeze_copy_dim_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,Tensor & out)22 Tensor& squeeze_copy_dim_out(
23 KernelRuntimeContext& ctx,
24 const Tensor& in,
25 int64_t dim,
26 Tensor& out) {
27 (void)ctx;
28
29 ET_KERNEL_CHECK(
30 ctx, check_squeeze_copy_dim_args(in, dim, out), InvalidArgument, out);
31
32 ET_KERNEL_CHECK(
33 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
34
35 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
36
37 if (dim < 0) {
38 dim += nonzero_dim(in);
39 }
40
41 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
42 size_t expected_out_dim = 0;
43 get_squeeze_copy_dim_out_target_size(
44 in, dim, expected_out_size, &expected_out_dim);
45 ET_KERNEL_CHECK(
46 ctx,
47 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
48 InvalidArgument,
49 out);
50
51 if (in.nbytes() > 0) {
52 // Note that this check is important. It's valid for a tensor with numel 0
53 // to have a null data pointer, but in some environments it's invalid to
54 // pass a null pointer to memcpy() even when the size is zero.
55 memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
56 }
57 return out;
58 }
59
squeeze_copy_dims_out(KernelRuntimeContext & ctx,const Tensor & in,exec_aten::ArrayRef<int64_t> dims,Tensor & out)60 Tensor& squeeze_copy_dims_out(
61 KernelRuntimeContext& ctx,
62 const Tensor& in,
63 exec_aten::ArrayRef<int64_t> dims,
64 Tensor& out) {
65 (void)ctx;
66
67 ET_KERNEL_CHECK(
68 ctx, check_squeeze_copy_dims_args(in, dims, out), InvalidArgument, out);
69
70 ET_KERNEL_CHECK(
71 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
72
73 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
74
75 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
76 size_t expected_out_dim = 0;
77 get_squeeze_copy_dims_out_target_size(
78 in, dims, expected_out_size, &expected_out_dim);
79 ET_KERNEL_CHECK(
80 ctx,
81 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
82 InvalidArgument,
83 out);
84
85 if (in.nbytes() > 0) {
86 // Note that this check is important. It's valid for a tensor with numel 0
87 // to have a null data pointer, but in some environments it's invalid to
88 // pass a null pointer to memcpy() even when the size is zero.
89 memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
90 }
91 return out;
92 }
93
94 } // namespace native
95 } // namespace executor
96 } // namespace torch
97