1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/kernels/aten/cpu/util/copy_ops_util.h>
12 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
13
14 namespace torch {
15 namespace executor {
16
17 using Tensor = exec_aten::Tensor;
18
check__to_dim_order_copy_args(const Tensor & input,bool non_blocking,exec_aten::OptionalArrayRef<int64_t> dim_order,Tensor & out)19 bool check__to_dim_order_copy_args(
20 const Tensor& input,
21 bool non_blocking,
22 exec_aten::OptionalArrayRef<int64_t> dim_order,
23 Tensor& out) {
24 // Right now we only support blocking data transfer
25 ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
26
27 // dim_order is set, the target dim_order will be either contiguous or
28 // channels_last memory format
29 if (dim_order.has_value()) {
30 exec_aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
31
32 // dim order size shall equal to input dim
33 ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
34
35 ET_LOG_AND_RETURN_IF_FALSE(
36 is_channels_last_dim_order(
37 dim_order.value().data(), dim_order.value().size()) ||
38 is_contiguous_dim_order(
39 dim_order.value().data(), dim_order.value().size()));
40
41 // Out Aten tensor shall have same memory format stride as dim_order
42 const size_t kMaxNumOfDimensions = 16;
43 ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
44 exec_aten::StridesType target_strides[kMaxNumOfDimensions];
45 dim_order_to_stride_nocheck(
46 out.sizes().data(),
47 dim_order_ref.data(),
48 dim_order_ref.size(),
49 target_strides);
50 ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
51 for (size_t i = 0; i < dim_order_ref.size(); i++) {
52 ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
53 }
54
55 } else { // dim_order is not set, preserve the dim order of input
56
57 auto out_strides = out.strides();
58 auto input_strides = input.strides();
59 ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
60 for (size_t i = 0; i < input_strides.size(); i++) {
61 ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
62 }
63 }
64 return true;
65 }
66
67 } // namespace executor
68 } // namespace torch
69