• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
10 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/platform/runtime.h>
16 
17 #include <gtest/gtest.h>
18 
19 using namespace ::testing;
20 using exec_aten::optional;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using executorch::runtime::KernelRuntimeContext;
24 using torch::executor::native::quantized_mixed_mm_out;
25 using torch::executor::testing::TensorFactory;
26 
27 class OpQuantizedMixedMMTest : public ::testing::Test {
28  protected:
SetUp()29   void SetUp() override {
30     // Since these tests cause ET_LOG to be called, the PAL must be initialized
31     // first.
32     torch::executor::runtime_init();
33   }
34 };
35 
36 template <ScalarType DTYPE>
test_dtype()37 void test_dtype() {
38   TensorFactory<DTYPE> tf;
39   TensorFactory<ScalarType::Char> tf_char;
40 
41   Tensor input = tf.make(
42       /*sizes=*/{1, 3},
43       /*data=*/{1.0, 1.5, 2.0});
44   Tensor weight = tf_char.make(
45       /*sizes=*/{3, 2},
46       /*data=*/{5, 4, 3, 2, 1, 1});
47   Tensor weight_scales = tf.make(
48       /*sizes=*/{3},
49       /*data=*/{0.2, 0.4, 0.5});
50   const optional<Tensor> opt_weight_zp{};
51 
52   Tensor out = tf.zeros({1, 2});
53 
54   Tensor expected = tf.make(
55       /*sizes=*/{1, 2},
56       /*data=*/{3.8, 3.0});
57 
58   KernelRuntimeContext ctx{};
59 
60   quantized_mixed_mm_out(ctx, input, weight, weight_scales, opt_weight_zp, out);
61 
62   EXPECT_TENSOR_CLOSE(out, expected);
63 }
64 
// Float activations, scales, and output against the int8 weight.
TEST_F(OpQuantizedMixedMMTest, FloatInput) {
  test_dtype<exec_aten::ScalarType::Float>();
}

// Same check with Half activations, scales, and output.
TEST_F(OpQuantizedMixedMMTest, HalfInput) {
  test_dtype<exec_aten::ScalarType::Half>();
}
72