• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace torch {
13 namespace executor {
14 namespace native {
15 namespace {
16 
pixel_unshuffle_impl(const Tensor & in,int64_t downscale_factor,Tensor & out)17 void pixel_unshuffle_impl(
18     const Tensor& in,
19     int64_t downscale_factor,
20     Tensor& out) {
21   const char* const in_data =
22       reinterpret_cast<const char*>(in.const_data_ptr());
23   char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
24   const auto elem_size = in.element_size();
25 
26   const auto leading_dims = getLeadingDims(in, in.dim() - 3);
27   const auto channels = out.size(in.dim() - 3);
28   const auto height = out.size(in.dim() - 2);
29   const auto width = out.size(in.dim() - 1);
30 
31   const auto S = downscale_factor;
32   const auto sub_channels = channels / (S * S);
33 
34   // output strides
35   const auto stride_n = channels * height * width;
36   const auto stride_c = S * S * height * width;
37   const auto stride_s1 = S * height * width;
38   const auto stride_s2 = height * width;
39   const auto stride_h = width;
40 
41   // input tensor shape of [n, c, h, s1, w, s2]
42   // output tensor shape of [n, c, s1, s2, h, w]
43   size_t i = 0;
44   for (size_t n = 0; n < leading_dims; n++) {
45     for (size_t c = 0; c < sub_channels; c++) {
46       for (size_t h = 0; h < height; h++) {
47         for (size_t s1 = 0; s1 < S; s1++) {
48           for (size_t w = 0; w < width; w++) {
49             for (size_t s2 = 0; s2 < S; s2++) {
50               size_t output_offset = n * stride_n + c * stride_c +
51                   s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
52               std::memcpy(
53                   out_data + output_offset * elem_size,
54                   in_data + i * elem_size,
55                   elem_size);
56               i++;
57             }
58           }
59         }
60       }
61     }
62   }
63 }
64 
65 } // namespace
66 
67 using SizesType = exec_aten::SizesType;
68 using Tensor = exec_aten::Tensor;
69 
pixel_unshuffle_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t downscale_factor,Tensor & out)70 Tensor& pixel_unshuffle_out(
71     KernelRuntimeContext& ctx,
72     const Tensor& in,
73     int64_t downscale_factor,
74     Tensor& out) {
75   (void)ctx;
76 
77   ET_KERNEL_CHECK(
78       ctx,
79       check_pixel_unshuffle_args(in, downscale_factor, out),
80       InvalidArgument,
81       out);
82 
83   // @lint-ignore CLANGTIDY facebook-hte-CArray
84   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
85   size_t expected_out_dim = 0;
86   get_pixel_unshuffle_out_target_size(
87       in, downscale_factor, expected_out_size, &expected_out_dim);
88 
89   // Make sure the output tensor is the right size.
90   ET_KERNEL_CHECK(
91       ctx,
92       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
93       InvalidArgument,
94       out);
95 
96   pixel_unshuffle_impl(in, downscale_factor, out);
97 
98   return out;
99 }
100 
101 } // namespace native
102 } // namespace executor
103 } // namespace torch
104