• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/transpose_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace torch {
13 namespace executor {
14 namespace native {
15 
16 using SizesType = exec_aten::SizesType;
17 using StridesType = exec_aten::StridesType;
18 using Tensor = exec_aten::Tensor;
19 
20 /**
21  * Swaps dimension 'dim0' of 'a' with 'dim1', and copying
22  * that mutation into `out` in a manner such that the data is densely packed
23  * and is_contiguous() would return true (stride dim[size-1] = 1).
24  *
25  * transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out)
26  */
transpose_copy_int_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim0,int64_t dim1,Tensor & out)27 Tensor& transpose_copy_int_out(
28     KernelRuntimeContext& ctx,
29     const Tensor& in,
30     int64_t dim0,
31     int64_t dim1,
32     Tensor& out) {
33   (void)ctx;
34 
35   ET_KERNEL_CHECK(
36       ctx,
37       check_transpose_copy_args(in, dim0, dim1, out),
38       InvalidArgument,
39       out);
40 
41   if (dim0 < 0) {
42     dim0 += nonzero_dim(in);
43   }
44   if (dim1 < 0) {
45     dim1 += nonzero_dim(in);
46   }
47 
48   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
49   size_t expected_out_dim = 0;
50   get_transpose_out_target_size(
51       in, dim0, dim1, expected_out_size, &expected_out_dim);
52 
53   // Resize for dynamic shape
54   ET_KERNEL_CHECK(
55       ctx,
56       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
57       InvalidArgument,
58       out);
59 
60   ET_KERNEL_CHECK(
61       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
62 
63   ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
64     transpose_tensors<CTYPE>(in, dim0, dim1, out);
65   });
66 
67   return out;
68 }
69 
70 } // namespace native
71 } // namespace executor
72 } // namespace torch
73