/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using exec_aten::optional;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;

class OpLogitOutTest : public OperatorTest {
 protected:
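  // Forwards to the operator's out-variant entry point, passing the test
  // fixture's kernel runtime context.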
  Tensor& op_logit_out(const Tensor& self, optional<double> eps, Tensor& out) {
    return torch::executor::aten::logit_outf(context_, self, eps, out);
  }

  // Common testing for logit operator
  template <ScalarType DTYPE, ScalarType OUTPUT_DTYPE>
  void test_integer_logit_out() {
    TensorFactory<DTYPE> tf;
    TensorFactory<OUTPUT_DTYPE> tf_out;

    const std::vector<int32_t> sizes = {2, 2};

    // Destination for the logit operator.
    Tensor out = tf_out.zeros(sizes);
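    // logit(x) = log(x / (1 - x)). When eps is given, the input is clamped to
    // [eps, 1 - eps]; with eps = 0 every element here clamps to 1.0, so each
    // output should be logit(1) = +infinity.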
    op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0, out);
    EXPECT_TENSOR_CLOSE(
        out,
        tf_out.make(sizes, /*data=*/{INFINITY, INFINITY, INFINITY, INFINITY}));
  }

  // Common testing for logit operator
  template <ScalarType DTYPE, ScalarType OUTPUT_DTYPE>
  void test_integer_logit_out_eps_set() {
    TensorFactory<DTYPE> tf;
    TensorFactory<OUTPUT_DTYPE> tf_out;

    const std::vector<int32_t> sizes = {2, 2};

    // Destination for the logit operator.
    Tensor out = tf_out.zeros(sizes);
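    // With eps = 0.1 every element clamps to 1 - eps = 0.9, so each output
    // should be logit(0.9) = log(0.9 / 0.1) = log(9) ~= 2.197224.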
    op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0.1, out);

    // Check that it matches (or is close to) the expected output.
    EXPECT_TENSOR_CLOSE(
        out,
        tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224}));
  }

  // Unhandled output dtypes.
  template <ScalarType OUTPUT_DTYPE>
  void test_logit_invalid_output_dtype_dies() {
    TensorFactory<ScalarType::Float> tf;
    TensorFactory<OUTPUT_DTYPE> tf_out;

    const std::vector<int32_t> sizes = {2, 5};

    Tensor in = tf.ones(sizes);
    Tensor out = tf_out.zeros(sizes);

    ET_EXPECT_KERNEL_FAILURE(context_, op_logit_out(in, 0, out));
  }
};
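
// For floating-point inputs the generic integer-based expectations above do
// not apply: values inside (0, 1) produce finite logit outputs, so this
// specialization checks exact values instead.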
template <>
void OpLogitOutTest::
    test_integer_logit_out<ScalarType::Float, ScalarType::Float>() {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Float> tf_out;

  const std::vector<int32_t> sizes = {2, 2};

  // Destination for the logit operator.
  Tensor out = tf_out.zeros(sizes);
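  // logit(x) = log(x / (1 - x)), so logit(0.1) ~= -2.1972,
  // logit(0.2) ~= -1.3863, logit(0.4) ~= -0.4055, and logit(0.8) ~= 1.3863.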
  op_logit_out(tf.make(sizes, /*data=*/{.1, .2, .4, .8}), 0, out);

  // Check that it matches (or is close to) the expected output.
  EXPECT_TENSOR_CLOSE(
      out,
      tf_out.make(
          sizes, /*data=*/{-2.197224, -1.386294, -0.405465, 1.3862943}));
}
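
// Run the common checks over every real input dtype, once with float and
// once with double output dtypes.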
TEST_F(OpLogitOutTest, AllRealInputFloatOutputSupport) {
#define TEST_ENTRY(ctype, dtype) \
  test_integer_logit_out<ScalarType::dtype, ScalarType::Float>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpLogitOutTest, AllRealInputDoubleOutputSupport) {
#define TEST_ENTRY(ctype, dtype) \
  test_integer_logit_out<ScalarType::dtype, ScalarType::Double>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpLogitOutTest, AllRealInputFloatOutputSupportEpsSet) {
#define TEST_ENTRY(ctype, dtype) \
  test_integer_logit_out_eps_set<ScalarType::dtype, ScalarType::Float>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpLogitOutTest, AllRealInputDoubleOutputSupportEpsSet) {
#define TEST_ENTRY(ctype, dtype) \
  test_integer_logit_out_eps_set<ScalarType::dtype, ScalarType::Double>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

// Mismatched shape tests.
TEST_F(OpLogitOutTest, MismatchedShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Float> tf_out;

  Tensor a = tf.ones(/*sizes=*/{4});
  Tensor out = tf_out.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_logit_out(a, 0, out));
}
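
// Only floating-point output dtypes are supported; integral output tensors
// should be rejected by the kernel's dtype checks.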
TEST_F(OpLogitOutTest, AllNonFloatOutputDTypeDies) {
#define TEST_ENTRY(ctype, dtype) \
  test_logit_invalid_output_dtype_dies<ScalarType::dtype>();
  ET_FORALL_INT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpLogitOutTest, SimpleGeneratedCase) {
  TensorFactory<ScalarType::Float> tf;
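
  // Every input element is 1.0; eps = 0.1 clamps them to 0.9, so every
  // output element should be logit(0.9) = log(9) ~= 2.1972243785858154.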
  Tensor x = tf.make({10, 10}, std::vector<float>(100, 1.0));
  Tensor expected_result =
      tf.make({10, 10}, std::vector<float>(100, 2.1972243785858154));

  Tensor out = tf.zeros({10, 10});
  Tensor ret = op_logit_out(x, 0.1, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
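
// Dynamic-shape tests: out is allocated with a dynamism hint, and the kernel
// is expected to resize it to the input's shape before writing.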
TEST_F(OpLogitOutTest, DynamicShapeUpperBoundSameAsExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.9622091054916382,
       0.511866569519043,
       0.15690308809280396,
       0.7423648834228516,
       0.627659797668457,
       0.4892460107803345});
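  // With eps = 0.1 the first element (0.9622...) is clamped to 0.9, giving
  // log(9); the rest are already inside [0.1, 0.9] and pass through.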
  Tensor expected_result = tf.make(
      {3, 2},
      {2.1972243785858154,
       0.04747522622346878,
       -1.6814535856246948,
       1.05829656124115,
       0.5221903324127197,
       -0.043022606521844864});

  Tensor out =
      tf.zeros({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_logit_out(x, 0.1, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpLogitOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.9622091054916382,
       0.511866569519043,
       0.15690308809280396,
       0.7423648834228516,
       0.627659797668457,
       0.4892460107803345});
  Tensor expected_result = tf.make(
      {3, 2},
      {2.1972243785858154,
       0.04747522622346878,
       -1.6814535856246948,
       1.05829656124115,
       0.5221903324127197,
       -0.043022606521844864});
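
  // out is over-allocated at {10, 10}; the kernel should resize it down to
  // the input's {3, 2} shape before the comparison.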
  Tensor out =
      tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_logit_out(x, 0.1, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpLogitOutTest, DynamicShapeUnbound) {
  GTEST_SKIP() << "Dynamic shape unbound not supported";
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.9622091054916382,
       0.511866569519043,
       0.15690308809280396,
       0.7423648834228516,
       0.627659797668457,
       0.4892460107803345});
  Tensor expected_result = tf.make(
      {3, 2},
      {2.1972243785858154,
       0.04747522622346878,
       -1.6814535856246948,
       1.05829656124115,
       0.5221903324127197,
       -0.043022606521844864});

  Tensor out =
      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  Tensor ret = op_logit_out(x, 0.1, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}