1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <gtest/gtest.h>
11
12 template <typename T1, typename T2>
13 struct promote_type_with_scalar_type_is_valid
14 : std::integral_constant<
15 bool,
16 (std::is_same<T2, torch::executor::internal::B1>::value ||
17 std::is_same<T2, torch::executor::internal::I8>::value ||
18 std::is_same<T2, torch::executor::internal::F8>::value) &&
19 !std::is_same<T1, exec_aten::BFloat16>::value &&
20 !torch::executor::is_qint_type<T1>::value &&
21 !torch::executor::is_bits_type<T1>::value &&
22 !executorch::runtime::is_bits_type<T2>::value &&
23 !executorch::runtime::is_float8_type<T1>::value &&
24 !executorch::runtime::is_float8_type<T2>::value &&
25 !executorch::runtime::is_barebones_unsigned_type<T1>::value &&
26 !executorch::runtime::is_barebones_unsigned_type<T2>::value> {};
27
28 template <typename T1, bool half_to_float>
29 struct CompileTimePromoteTypeWithScalarTypeTestCase {
testAllCompileTimePromoteTypeWithScalarTypeTestCase30 static void testAll() {
31 #define CALL_TEST_ONE(cpp_type, scalar_type) \
32 testOne< \
33 cpp_type, \
34 promote_type_with_scalar_type_is_valid<T1, cpp_type>::value>();
35 ET_FORALL_SCALAR_TYPES(CALL_TEST_ONE)
36 #undef CALL_TEST_ONE
37 }
38
39 template <
40 typename T2,
41 bool valid,
42 typename std::enable_if<valid, bool>::type = true>
testOneCompileTimePromoteTypeWithScalarTypeTestCase43 static void testOne() {
44 auto actual = torch::executor::CppTypeToScalarType<
45 typename torch::executor::native::utils::
46 promote_type_with_scalar_type<T1, T2, half_to_float>::type>::value;
47 const auto scalarType1 = torch::executor::CppTypeToScalarType<T1>::value;
48 const auto scalarType2 = torch::executor::CppTypeToScalarType<T2>::value;
49 T2 scalar_value = 0;
50 auto expected = torch::executor::native::utils::promote_type_with_scalar(
51 scalarType1, scalar_value, half_to_float);
52 EXPECT_EQ(actual, expected)
53 << "promoting " << (int)scalarType1 << " with " << (int)scalarType2
54 << " given half_to_float = " << half_to_float << " expected "
55 << (int)expected << " but got " << (int)actual;
56 }
57
58 template <
59 typename T2,
60 bool valid,
61 typename std::enable_if<!valid, bool>::type = true>
testOneCompileTimePromoteTypeWithScalarTypeTestCase62 static void testOne() {
63 // Skip invalid case
64 }
65 };
66
// Runs CompileTimePromoteTypeWithScalarTypeTestCase<T1, ...>::testAll() for
// every C++ type T1 in ET_FORALL_SCALAR_TYPES. Each T1 is run once with
// half_to_float = false and once with half_to_float = true.
TEST(ScalarTypeUtilTest, compileTypePromoteTypesTest) {
#define INSTANTIATE_TYPE_TEST(cpp_type, scalar_type)                        \
  CompileTimePromoteTypeWithScalarTypeTestCase<cpp_type, false>::testAll(); \
  CompileTimePromoteTypeWithScalarTypeTestCase<cpp_type, true>::testAll();

  ET_FORALL_SCALAR_TYPES(INSTANTIATE_TYPE_TEST)
  // Keep the helper macro local to this test, matching CALL_TEST_ONE above.
#undef INSTANTIATE_TYPE_TEST
}
74