/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring> // for std::memcpy

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

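// Rearranges an input of shape [..., C, H * S, W * S] into an output of shape
// [..., C * S * S, H, W], where S = downscale_factor. For example, with S = 2
// an input of shape [1, 2, 4, 6] produces an output of shape [1, 8, 2, 3].
// Conceptually, the input is viewed as [n, C, H, S, W, S], permuted to
// [n, C, S, S, H, W], and flattened back to [n, C * S * S, H, W].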
void pixel_unshuffle_impl(
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out) {
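  // Operate on raw bytes (char* plus element_size) so a single loop handles
  // every supported dtype without dispatching on the scalar type.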
  const char* const in_data =
      reinterpret_cast<const char*>(in.const_data_ptr());
  char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
  const auto elem_size = in.element_size();

  const auto leading_dims = getLeadingDims(in, in.dim() - 3);
  const auto channels = out.size(in.dim() - 3);
  const auto height = out.size(in.dim() - 2);
  const auto width = out.size(in.dim() - 1);

  const auto S = downscale_factor;
  const auto sub_channels = channels / (S * S);
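  // `channels`, `height`, and `width` come from the *output*, so `channels`
  // equals C * S * S and `sub_channels` recovers the input channel count C.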

  // output strides
  const auto stride_n = channels * height * width;
  const auto stride_c = S * S * height * width;
  const auto stride_s1 = S * height * width;
  const auto stride_s2 = height * width;
  const auto stride_h = width;

  // input tensor shape of [n, c, h, s1, w, s2]
  // output tensor shape of [n, c, s1, s2, h, w]
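  // The input is read contiguously (linear index `i`); each element is
  // scattered to its permuted position in the output using the strides above.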
  size_t i = 0;
  for (size_t n = 0; n < leading_dims; n++) {
    for (size_t c = 0; c < sub_channels; c++) {
      for (size_t h = 0; h < height; h++) {
        for (size_t s1 = 0; s1 < S; s1++) {
          for (size_t w = 0; w < width; w++) {
            for (size_t s2 = 0; s2 < S; s2++) {
              size_t output_offset = n * stride_n + c * stride_c +
                  s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
              std::memcpy(
                  out_data + output_offset * elem_size,
                  in_data + i * elem_size,
                  elem_size);
              i++;
            }
          }
        }
      }
    }
  }
}

} // namespace

using SizesType = exec_aten::SizesType;
using Tensor = exec_aten::Tensor;

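// pixel_unshuffle.out variant: validates the arguments, resizes `out` to the
// expected [..., C * S * S, H / S, W / S] shape (for an input of shape
// [..., C, H, W]), and then performs the element copy.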
Tensor& pixel_unshuffle_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_pixel_unshuffle_args(in, downscale_factor, out),
      InvalidArgument,
      out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_pixel_unshuffle_out_target_size(
      in, downscale_factor, expected_out_size, &expected_out_dim);

  // Make sure the output tensor is the right size.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  pixel_unshuffle_impl(in, downscale_factor, out);

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch