1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <cmath>
#include <cstddef>
#include <cstdint>

#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace impl {
13 namespace reference {
14 namespace native {
15 
16 using executorch::aten::Tensor;
17 using executorch::runtime::getLeadingDims;
18 using executorch::runtime::KernelRuntimeContext;
19 
20 // The quantized matmul. The quantized matmul accumulates in a wider register,
21 // whose type is TA.
22 template <
23     typename TZ,
24     typename TA = float,
25     bool transposed = false,
26     typename TX = TZ,
27     typename TY = TZ>
qmatmul(TZ * __restrict__ Z,int32_t Z_multiplier,int32_t Z_shift,int32_t Z_zero_point,const TX * __restrict__ X,int32_t X_zero_point,const TY * __restrict__ y,int32_t Y_zero_point,size_t m,size_t n,size_t p)28 __attribute__((noinline)) void qmatmul(
29     TZ* __restrict__ Z,
30     int32_t Z_multiplier,
31     int32_t Z_shift,
32     int32_t Z_zero_point,
33     const TX* __restrict__ X,
34     int32_t X_zero_point,
35     const TY* __restrict__ y,
36     int32_t Y_zero_point,
37     size_t m,
38     size_t n,
39     size_t p) {
40   // Compute the Z_scale from Z_multiplier and Z_shift
41   const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift);
42   for (size_t i = 0; i < m; ++i) {
43     for (size_t j = 0; j < p; ++j) {
44       TA sum = 0;
45       for (size_t k = 0; k < n; ++k) {
46         if (transposed) {
47           sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point);
48         } else {
49           sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point);
50         }
51       }
52       Z[i * p + j] = kernels::quantize<TZ>(sum, Z_scale, Z_zero_point);
53     }
54   }
55 }
56 
57 template <typename T>
_typed_quantized_matmul(const Tensor & X,int64_t X_zero_point,const Tensor & Y,int64_t Y_zero_point,const executorch::aten::optional<Tensor> & bias,int64_t out_multiplier,int64_t out_shift,int64_t out_zero_point,bool transposed,Tensor & out)58 void inline _typed_quantized_matmul(
59     const Tensor& X,
60     int64_t X_zero_point,
61     const Tensor& Y,
62     int64_t Y_zero_point,
63     const executorch::aten::optional<Tensor>& bias,
64     int64_t out_multiplier,
65     int64_t out_shift,
66     int64_t out_zero_point,
67     bool transposed,
68     Tensor& out) {
69   size_t batch_size = getLeadingDims(X, X.dim() - 2);
70   size_t leading_dim = X.size(X.dim() - 2);
71   size_t out_dim = Y.size(Y.dim() - 1 - transposed);
72   size_t in_dim = X.size(X.dim() - 1);
73 
74   T* __restrict__ out_data = out.mutable_data_ptr<T>();
75   const T* __restrict__ X_data = X.const_data_ptr<T>();
76   const T* __restrict__ Y_data = Y.const_data_ptr<T>();
77   for (size_t i = 0; i < batch_size; ++i) {
78     const T* x = X_data + i * leading_dim * in_dim;
79     const T* y = Y_data + i * in_dim * out_dim;
80     T* z = out_data + i * leading_dim * out_dim;
81     if (transposed) {
82       qmatmul<T, int32_t, true>(
83           z,
84           static_cast<int32_t>(out_multiplier),
85           static_cast<int32_t>(out_shift),
86           static_cast<int32_t>(out_zero_point),
87           x,
88           static_cast<int32_t>(X_zero_point),
89           y,
90           static_cast<int32_t>(Y_zero_point),
91           leading_dim,
92           in_dim,
93           out_dim);
94     } else {
95       qmatmul<T, int32_t, false>(
96           z,
97           static_cast<int32_t>(out_multiplier),
98           static_cast<int32_t>(out_shift),
99           static_cast<int32_t>(out_zero_point),
100           x,
101           static_cast<int32_t>(X_zero_point),
102           y,
103           static_cast<int32_t>(Y_zero_point),
104           leading_dim,
105           in_dim,
106           out_dim);
107     }
108   }
109 }
110 
quantized_matmul_out(KernelRuntimeContext & ctx,const Tensor & X,int64_t X_zero_point,const Tensor & Y,int64_t Y_zero_point,const executorch::aten::optional<Tensor> & bias,int64_t out_multiplier,int64_t out_shift,int64_t out_zero_point,bool transposed,Tensor & out)111 void quantized_matmul_out(
112     KernelRuntimeContext& ctx,
113     const Tensor& X,
114     int64_t X_zero_point,
115     const Tensor& Y,
116     int64_t Y_zero_point,
117     const executorch::aten::optional<Tensor>& bias,
118     int64_t out_multiplier,
119     int64_t out_shift,
120     int64_t out_zero_point,
121     bool transposed,
122     Tensor& out) {
123   if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
124     _typed_quantized_matmul<uint8_t>(
125         X,
126         X_zero_point,
127         Y,
128         Y_zero_point,
129         bias,
130         out_multiplier,
131         out_shift,
132         out_zero_point,
133         transposed,
134         out);
135   } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
136     _typed_quantized_matmul<int8_t>(
137         X,
138         X_zero_point,
139         Y,
140         Y_zero_point,
141         bias,
142         out_multiplier,
143         out_shift,
144         out_zero_point,
145         transposed,
146         out);
147   } else {
148     ET_CHECK_MSG(
149         false,
150         "Unhandled input dtype %hhd",
151         static_cast<int8_t>(X.scalar_type()));
152   }
153 }
154 
155 }; // namespace native
156 }; // namespace reference
157 }; // namespace impl
158