1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
10 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/platform/runtime.h>
16
17 #include <gtest/gtest.h>
18
19 using namespace ::testing;
20 using exec_aten::optional;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using executorch::runtime::KernelRuntimeContext;
24 using torch::executor::native::quantized_mixed_mm_out;
25 using torch::executor::testing::TensorFactory;
26
27 class OpQuantizedMixedMMTest : public ::testing::Test {
28 protected:
SetUp()29 void SetUp() override {
30 // Since these tests cause ET_LOG to be called, the PAL must be initialized
31 // first.
32 torch::executor::runtime_init();
33 }
34 };
35
36 template <ScalarType DTYPE>
test_dtype()37 void test_dtype() {
38 TensorFactory<DTYPE> tf;
39 TensorFactory<ScalarType::Char> tf_char;
40
41 Tensor input = tf.make(
42 /*sizes=*/{1, 3},
43 /*data=*/{1.0, 1.5, 2.0});
44 Tensor weight = tf_char.make(
45 /*sizes=*/{3, 2},
46 /*data=*/{5, 4, 3, 2, 1, 1});
47 Tensor weight_scales = tf.make(
48 /*sizes=*/{3},
49 /*data=*/{0.2, 0.4, 0.5});
50 const optional<Tensor> opt_weight_zp{};
51
52 Tensor out = tf.zeros({1, 2});
53
54 Tensor expected = tf.make(
55 /*sizes=*/{1, 2},
56 /*data=*/{3.8, 3.0});
57
58 KernelRuntimeContext ctx{};
59
60 quantized_mixed_mm_out(ctx, input, weight, weight_scales, opt_weight_zp, out);
61
62 EXPECT_TENSOR_CLOSE(out, expected);
63 }
64
// Float activations and float weight scales.
TEST_F(OpQuantizedMixedMMTest, FloatInput) {
  constexpr ScalarType kDtype = ScalarType::Float;
  test_dtype<kDtype>();
}
68
// Half-precision activations and half-precision weight scales.
TEST_F(OpQuantizedMixedMMTest, HalfInput) {
  constexpr ScalarType kDtype = ScalarType::Half;
  test_dtype<kDtype>();
}
72