• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
11 
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 #include <executorch/test/utils/DeathTest.h>
17 
18 #include <gtest/gtest.h>
19 
20 using namespace ::testing;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using executorch::runtime::ArrayRef;
24 using executorch::runtime::testing::TensorFactory;
25 using torch::executor::broadcast_tensor;
26 using torch::executor::delinearize_index;
27 using torch::executor::get_broadcast_target_size;
28 using torch::executor::linearize_access_indexes;
29 using torch::executor::tensor_is_broadcastable_to;
30 using torch::executor::tensors_are_broadcastable_between;
31 
// broadcast_tensor() must expand both a one-element tensor and a tensor that
// already has the (2, 2) target shape into a (2, 2) tensor filled with 2s.
TEST(BroadcastUtilTest, BroadcastTensor) {
  TensorFactory<ScalarType::Int> tf;

  const Tensor one_element = tf.make({1}, {2});
  const Tensor same_shape = tf.make({2, 2}, {2, 2, 2, 2});
  const Tensor target = tf.zeros({2, 2});

  for (const Tensor& src : {one_element, same_shape}) {
    Tensor out = broadcast_tensor(src, target);
    EXPECT_TENSOR_DATA_EQ(out, tf.make({2, 2}, {2, 2, 2, 2}));
    // Every broadcast result is released with free_broadcast_tensor().
    torch::executor::free_broadcast_tensor(out);
  }
}
47 
// Every pair drawn from shapes {(1, 2), (2, 1), (1), (2, 2)} is mutually
// broadcastable. Broadcast compatibility is symmetric, so each unordered pair
// is checked in both argument orders.
TEST(BroadcastUtilTest, BroadcastableBetween) {
  TensorFactory<ScalarType::Int> tf;

  const std::vector<Tensor> tensor_list = {
      tf.zeros({1, 2}), tf.zeros({2, 1}), tf.zeros({1}), tf.zeros({2, 2})};

  // Derive the bounds from the list so adding a shape needs no loop edits.
  for (size_t i = 0; i < tensor_list.size(); ++i) {
    for (size_t j = i + 1; j < tensor_list.size(); ++j) {
      EXPECT_TRUE(
          tensors_are_broadcastable_between(tensor_list[i], tensor_list[j]))
          << "i=" << i << ", j=" << j;
      EXPECT_TRUE(
          tensors_are_broadcastable_between(tensor_list[j], tensor_list[i]))
          << "i=" << j << ", j=" << i;
    }
  }
}
61 
// A (1, 2) row and a (2, 1) column, each filled with 2s, are both
// broadcastable to (2, 2). Broadcasting either must yield a (2, 2) tensor of 2s.
TEST(BroadcastUtilTest, BroadcastableToFrom) {
  TensorFactory<ScalarType::Int> tf;

  const Tensor row = tf.make({1, 2}, {2, 2});
  const Tensor column = tf.make({2, 1}, {2, 2});
  const Tensor target = tf.zeros({2, 2});

  // The loop body sits directly in the test function, so a failing ASSERT
  // still ends the whole test, as in a straight-line sequence.
  for (const Tensor& src : {row, column}) {
    ASSERT_TRUE(tensor_is_broadcastable_to(src, target));
    Tensor expanded = broadcast_tensor(src, target);
    EXPECT_TENSOR_DATA_EQ(expanded, tf.make({2, 2}, {2, 2, 2, 2}));
    torch::executor::free_broadcast_tensor(expanded);
  }
}
79 
// Shape pairs for which tensor_is_broadcastable_to() must return false and
// broadcast_tensor() must abort.
TEST(BroadcastUtilTest, NotBroadcastableTo) {
  TensorFactory<ScalarType::Int> tf;

  // Tensor a is broadcastable to tensor b when, tracing their sizes from back
  // to front, each dimension of a either:
  // 1. equals the corresponding dimension of b; or
  // 2. is 1.
  // b may have extra leading dimensions that a lacks, but not the reverse.
  // If a has more dimensions than b, a is not broadcastable to b.
  Tensor a = tf.make({3}, {2, 2, 2});
  Tensor b = tf.zeros({2, 1});
  Tensor c = tf.zeros({1, 2});

  // a's only dimension (3) neither equals b's last dimension (1) nor is 1.
  ASSERT_FALSE(tensor_is_broadcastable_to(a, b));
  ET_EXPECT_DEATH(broadcast_tensor(a, b), "");

  // b and c are broadcastable *between* each other: both expand to (2, 2).
  // Neither can be broadcast *to* the other's shape, so check both directions.
  ASSERT_FALSE(tensor_is_broadcastable_to(b, c));
  ET_EXPECT_DEATH(broadcast_tensor(b, c), "");
  EXPECT_FALSE(tensor_is_broadcastable_to(c, b));
}
102 
// Two tensors with no common broadcast shape: the "between" relation must be
// false in both argument orders. Consequently neither is broadcastable to the
// other.
TEST(BroadcastUtilTest, NotBroadcastableBetween) {
  TensorFactory<ScalarType::Int> tf;

  // Trailing dimensions 3 and 2 differ and neither is 1. A shape such as
  // (2, 1) would NOT do here: (3) and (2, 1) broadcast together to (2, 3).
  Tensor a = tf.make({3}, {2, 2, 2});
  Tensor b = tf.zeros({2, 2});

  EXPECT_FALSE(tensors_are_broadcastable_between(a, b));
  EXPECT_FALSE(tensors_are_broadcastable_between(b, a));

  EXPECT_FALSE(tensor_is_broadcastable_to(a, b));
  EXPECT_FALSE(tensor_is_broadcastable_to(b, a));
}
111 
// Shapes (2, 1) and (5, 1, 2) must combine into the broadcast target shape
// (5, 2, 2), and get_broadcast_target_size() must report success.
TEST(BroadcastUtilTest, GetBroadcastTargetSize) {
  TensorFactory<ScalarType::Int> tf;

  const Tensor lhs = tf.zeros({2, 1});
  const Tensor rhs = tf.zeros({5, 1, 2});

  // Output buffers filled in by get_broadcast_target_size().
  Tensor::SizesType out_sizes[torch::executor::kTensorDimensionLimit] = {};
  size_t out_dim = 0;

  const executorch::runtime::Error status = get_broadcast_target_size(
      lhs, rhs, out_sizes, torch::executor::kTensorDimensionLimit, &out_dim);
  EXPECT_EQ(status, torch::executor::Error::Ok);

  const Tensor::SizesType want[] = {5, 2, 2};
  const ArrayRef<Tensor::SizesType> got(out_sizes, out_dim);
  EXPECT_TRUE(got.equals(ArrayRef<Tensor::SizesType>(want)));
}
133 
linearize_indexes(size_t * indexes,size_t indexes_len,const Tensor & t)134 size_t linearize_indexes(size_t* indexes, size_t indexes_len, const Tensor& t) {
135   size_t linear_index = 0;
136   size_t acc_loop_counts = 1;
137   for (ssize_t i = indexes_len - 1; i >= 0; --i) {
138     linear_index += indexes[i] * acc_loop_counts;
139     acc_loop_counts *= (size_t)t.sizes()[i];
140   }
141   return linear_index;
142 }
143 
// For every in-range multi-index of a (4, 3, 5) tensor, linearize it with the
// reference helper, delinearize it with delinearize_index(), and check that
// the original per-dimension indexes are recovered exactly.
TEST(BroadcastUtilTest, DelinearizeIndex) {
  TensorFactory<ScalarType::Int> tf;

  constexpr size_t kDims = 3;
  Tensor t = tf.zeros({4, 3, 5});
  // The nested loops and index arrays below assume exactly kDims dimensions.
  ASSERT_EQ(static_cast<size_t>(t.dim()), kDims);
  const auto sizes = t.sizes();

  for (size_t i0 = 0; i0 < static_cast<size_t>(sizes[0]); ++i0) {
    for (size_t i1 = 0; i1 < static_cast<size_t>(sizes[1]); ++i1) {
      for (size_t i2 = 0; i2 < static_cast<size_t>(sizes[2]); ++i2) {
        size_t indexes[kDims] = {i0, i1, i2};
        const size_t linear_index = linearize_indexes(indexes, kDims, t);

        size_t out_indexes[kDims] = {};
        delinearize_index(linear_index, t, out_indexes, kDims);

        // Compare element-wise: a round trip through the linear index alone
        // could hide out-of-range indexes that happen to alias.
        for (size_t d = 0; d < kDims; ++d) {
          EXPECT_EQ(out_indexes[d], indexes[d])
              << "dim " << d << " of linear index " << linear_index;
        }
        EXPECT_EQ(linear_index, linearize_indexes(out_indexes, kDims, t));
      }
    }
  }
}
165 
// Reads a (2, 1, 3, 1) tensor through indexes into its broadcast shape
// (2, 2, 3, 4). Indexes along the broadcast (size-1) dimensions must not
// change the linear index that linearize_access_indexes() computes.
TEST(BroadcastUtilTest, LinearizeIndex) {
  TensorFactory<ScalarType::Int> tf;

  Tensor broadcast_from = tf.zeros({2, 1, 3, 1});
  Tensor broadcast_to = tf.zeros({2, 2, 3, 4});
  const auto to_sizes = broadcast_to.sizes();

  // Sweep every valid index of broadcast_to's last dimension (size 4), which
  // is broadcast from size 1. With all other indexes 0, the linear index into
  // broadcast_from stays 0.
  for (size_t i = 0; i < static_cast<size_t>(to_sizes[3]); ++i) {
    size_t test_indexes[] = {0, 0, 0, i};
    ArrayRef<size_t> broadcast_to_indexes(test_indexes);
    size_t linear_index = linearize_access_indexes(
        broadcast_to_indexes, broadcast_to.dim(), broadcast_from);
    EXPECT_EQ(linear_index, size_t{0}) << "i=" << i;
  }

  // Sweep every valid index of broadcast_to's dimension 1 (size 2), which is
  // also broadcast from size 1, with the other indexes fixed at {0, _, 2, 3}.
  // The result must not depend on i: it is always 2.
  for (size_t i = 0; i < static_cast<size_t>(to_sizes[1]); ++i) {
    size_t test_indexes[] = {0, i, 2, 3};
    ArrayRef<size_t> broadcast_to_indexes(test_indexes);
    size_t linear_index = linearize_access_indexes(
        broadcast_to_indexes, broadcast_to.dim(), broadcast_from);
    EXPECT_EQ(linear_index, size_t{2}) << "i=" << i;
  }
}
192