• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <gtest/gtest.h>
11 
12 template <typename T1, typename T2>
13 struct promote_type_with_scalar_type_is_valid
14     : std::integral_constant<
15           bool,
16           (std::is_same<T2, torch::executor::internal::B1>::value ||
17            std::is_same<T2, torch::executor::internal::I8>::value ||
18            std::is_same<T2, torch::executor::internal::F8>::value) &&
19               !std::is_same<T1, exec_aten::BFloat16>::value &&
20               !torch::executor::is_qint_type<T1>::value &&
21               !torch::executor::is_bits_type<T1>::value &&
22               !executorch::runtime::is_bits_type<T2>::value &&
23               !executorch::runtime::is_float8_type<T1>::value &&
24               !executorch::runtime::is_float8_type<T2>::value &&
25               !executorch::runtime::is_barebones_unsigned_type<T1>::value &&
26               !executorch::runtime::is_barebones_unsigned_type<T2>::value> {};
27 
28 template <typename T1, bool half_to_float>
29 struct CompileTimePromoteTypeWithScalarTypeTestCase {
testAllCompileTimePromoteTypeWithScalarTypeTestCase30   static void testAll() {
31 #define CALL_TEST_ONE(cpp_type, scalar_type) \
32   testOne<                                   \
33       cpp_type,                              \
34       promote_type_with_scalar_type_is_valid<T1, cpp_type>::value>();
35     ET_FORALL_SCALAR_TYPES(CALL_TEST_ONE)
36 #undef CALL_TEST_ONE
37   }
38 
39   template <
40       typename T2,
41       bool valid,
42       typename std::enable_if<valid, bool>::type = true>
testOneCompileTimePromoteTypeWithScalarTypeTestCase43   static void testOne() {
44     auto actual = torch::executor::CppTypeToScalarType<
45         typename torch::executor::native::utils::
46             promote_type_with_scalar_type<T1, T2, half_to_float>::type>::value;
47     const auto scalarType1 = torch::executor::CppTypeToScalarType<T1>::value;
48     const auto scalarType2 = torch::executor::CppTypeToScalarType<T2>::value;
49     T2 scalar_value = 0;
50     auto expected = torch::executor::native::utils::promote_type_with_scalar(
51         scalarType1, scalar_value, half_to_float);
52     EXPECT_EQ(actual, expected)
53         << "promoting " << (int)scalarType1 << " with " << (int)scalarType2
54         << " given half_to_float = " << half_to_float << " expected "
55         << (int)expected << " but got " << (int)actual;
56   }
57 
58   template <
59       typename T2,
60       bool valid,
61       typename std::enable_if<!valid, bool>::type = true>
testOneCompileTimePromoteTypeWithScalarTypeTestCase62   static void testOne() {
63     // Skip invalid case
64   }
65 };
66 
// Runs the compile-time vs. runtime scalar-promotion consistency check with
// every ET_FORALL_SCALAR_TYPES type as T1, once with half_to_float disabled and
// once with it enabled.
TEST(ScalarTypeUtilTest, compileTypePromoteTypesTest) {
#define INSTANTIATE_TYPE_TEST(cpp_type, scalar_type)                        \
  CompileTimePromoteTypeWithScalarTypeTestCase<cpp_type, false>::testAll(); \
  CompileTimePromoteTypeWithScalarTypeTestCase<cpp_type, true>::testAll();

  // Each expansion already ends in ';', so no trailing semicolon is needed.
  ET_FORALL_SCALAR_TYPES(INSTANTIATE_TYPE_TEST)
  // Keep the helper macro local to this test, like CALL_TEST_ONE above.
#undef INSTANTIATE_TYPE_TEST
}
74