• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstring>
10 
11 #include <executorch/kernels/aten/cpu/util/copy_ops_util.h>
12 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
13 
14 namespace torch {
15 namespace executor {
16 
17 using Tensor = exec_aten::Tensor;
18 
check__to_dim_order_copy_args(const Tensor & input,bool non_blocking,exec_aten::OptionalArrayRef<int64_t> dim_order,Tensor & out)19 bool check__to_dim_order_copy_args(
20     const Tensor& input,
21     bool non_blocking,
22     exec_aten::OptionalArrayRef<int64_t> dim_order,
23     Tensor& out) {
24   // Right now we only support blocking data transfer
25   ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
26 
27   // dim_order is set, the target dim_order will be either contiguous or
28   // channels_last memory format
29   if (dim_order.has_value()) {
30     exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
31 
32     // dim order size shall equal to input dim
33     ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
34 
35     ET_LOG_AND_RETURN_IF_FALSE(
36         is_channels_last_dim_order(
37             dim_order.value().data(), dim_order.value().size()) ||
38         is_contiguous_dim_order(
39             dim_order.value().data(), dim_order.value().size()));
40 
41     // Out Aten tensor shall have same memory format stride as dim_order
42     const size_t kMaxNumOfDimensions = 16;
43     ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
44     exec_aten::StridesType target_strides[kMaxNumOfDimensions];
45     dim_order_to_stride_nocheck(
46         out.sizes().data(),
47         dim_order_ref.data(),
48         dim_order_ref.size(),
49         target_strides);
50     ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
51     for (size_t i = 0; i < dim_order_ref.size(); i++) {
52       ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
53     }
54 
55   } else { // dim_order is not set, preserve the dim order of input
56 
57     auto out_strides = out.strides();
58     auto input_strides = input.strides();
59     ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
60     for (size_t i = 0; i < input_strides.size(); i++) {
61       ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
62     }
63   }
64   return true;
65 }
66 
67 } // namespace executor
68 } // namespace torch
69